fl4health.clients.scaffold_client module

class DPScaffoldClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)

Bases: ScaffoldClient, InstanceLevelDpClient

Federated Learning client for the instance-level differentially private Scaffold strategy.

Implemented as specified in https://arxiv.org/abs/2111.09278.

class ScaffoldClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)

Bases: BasicClient

__init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)

Federated Learning client for the Scaffold strategy.

Implementation based on https://arxiv.org/pdf/1910.06378.pdf.

Parameters:
  • data_path (Path) – Path to the data used for client-side training.

  • metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model.

  • device (torch.device) – Device to which the model, batches, labels, etc. are sent. Often “cpu” or “cuda”.

  • loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.

  • checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module and its underlying model and state checkpointing components determine when and how checkpointing is done during client-side training. If not provided, no checkpointing (state or model) is done. Defaults to None.

  • reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.

  • progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.

  • client_name (str | None, optional) – An optional name that uniquely identifies the client. If not provided, a random hash is generated. The client state uses this name as part of its state file name. Defaults to None.

compute_parameters_delta(params_1, params_2)

Computes the element-wise difference of two lists of NDArrays, subtracting each element of params_2 from the corresponding element of params_1.

Each pair of NDArrays at index \(i\) is subtracted as

\[\text{params}_{1, i} - \text{params}_{2, i}\]
Parameters:
  • params_1 (NDArrays) – First set of parameters

  • params_2 (NDArrays) – Second set of parameters

Returns:

\(\text{params}_1 - \text{params}_2\)

Return type:

NDArrays
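As an illustration only, the element-wise subtraction can be sketched with plain NumPy. The NDArrays alias below is a local stand-in for Flower's NDArrays type (a list of NumPy arrays), and the function is a free-standing sketch rather than the client method itself.

```python
import numpy as np

# Local stand-in for Flower's NDArrays type: a list of NumPy arrays.
NDArrays = list[np.ndarray]


def compute_parameters_delta(params_1: NDArrays, params_2: NDArrays) -> NDArrays:
    """Return params_1[i] - params_2[i] for every index i."""
    if len(params_1) != len(params_2):
        raise ValueError("Parameter lists must have the same length.")
    return [p1 - p2 for p1, p2 in zip(params_1, params_2)]


params_1 = [np.array([3.0, 5.0]), np.array([[1.0, 2.0], [3.0, 4.0]])]
params_2 = [np.array([1.0, 1.0]), np.array([[0.5, 0.5], [0.5, 0.5]])]
delta = compute_parameters_delta(params_1, params_2)
# delta[0] -> array([2., 4.]); delta[1] -> array([[0.5, 1.5], [2.5, 3.5]])
```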

compute_updated_control_variates(local_steps, delta_model_weights, delta_control_variates)

Computes the updated local control variates according to option 2 in Equation 4 of the paper. The calculation is

\[c_i^+ = c_i - c + \frac{1}{(K \cdot lr)} \cdot (x - y_i)\]

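where \(lr\) is the local learning rate, \(x\) denotes the server model weights received before local training, and \(y_i\) denotes the locally trained client weights.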

Parameters:
  • local_steps (int) – Number of local steps that were taken during local training (\(K\))

  • delta_model_weights (NDArrays) – Difference \(x - y_i\) between the initial weights prior to local training and the locally trained weights.

  • delta_control_variates (NDArrays) – Difference between the local (\(c_i\)) and server (\(c\)) control variates, \(c_i - c\).

Returns:

Updated client control variates

Return type:

NDArrays
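A minimal sketch of the update above, assuming delta_model_weights holds \(x - y_i\) and delta_control_variates holds \(c_i - c\). Unlike the documented method, which takes three arguments and uses the client's learning_rate attribute, this hypothetical standalone version receives the learning rate explicitly.

```python
import numpy as np

NDArrays = list[np.ndarray]


def compute_updated_control_variates(
    local_steps: int,
    learning_rate: float,
    delta_model_weights: NDArrays,
    delta_control_variates: NDArrays,
) -> NDArrays:
    """c_i^+ = (c_i - c) + (x - y_i) / (K * lr), computed per parameter array.

    Assumes delta_model_weights holds x - y_i and delta_control_variates holds c_i - c.
    """
    scaling = 1.0 / (local_steps * learning_rate)
    return [dcv + scaling * dw for dcv, dw in zip(delta_control_variates, delta_model_weights)]


# Toy example: K = 4 local steps and lr = 0.5, so 1 / (K * lr) = 0.5.
x_minus_y = [np.array([2.0, -4.0])]
ci_minus_c = [np.array([0.1, 0.2])]
c_plus = compute_updated_control_variates(4, 0.5, x_minus_y, ci_minus_c)
# c_plus[0] -> array([ 1.1, -1.8])
```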

get_parameter_exchanger(config)

Returns a full parameter exchanger. Subclasses that require a custom parameter exchanger can override this method.

Parameters:

config (Config) – The config from the server.

Returns:

Used to exchange parameters between server and client.

Return type:

ParameterExchanger

get_parameters(config)

Packs the model parameters and control variates into a single NDArrays object to be sent to the server for aggregation.

Parameters:

config (Config) – The config is sent by the FL server to allow for customization in the function if desired.

Returns:

Model parameters and control variates packed together.

Return type:

NDArrays
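For illustration, packing can be as simple as list concatenation, which matches the layout set_parameters expects (model parameters followed by control variates). The pack_parameters helper here is hypothetical and is not the library's packer.

```python
import numpy as np

NDArrays = list[np.ndarray]


def pack_parameters(model_weights: NDArrays, control_variates: NDArrays) -> NDArrays:
    """Hypothetical packer: concatenate both lists so they travel as one NDArrays."""
    return model_weights + control_variates


weights = [np.ones((2, 2)), np.zeros(3)]
variates = [np.full((2, 2), 0.1), np.full(3, 0.2)]
packed = pack_parameters(weights, variates)
# len(packed) -> 4: two model weight arrays followed by two control variate arrays
```

Concatenation keeps the original array objects (no copies), so the receiver only needs to know how many leading entries belong to the model.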

modify_grad()

Modifies the gradients of the local model to correct for client drift. To be called after the gradients have been computed on a batch of data. The modified gradients are not applied to the parameters until step is called on the optimizer.

Return type:

None
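The cited paper's local update is \(y_i \leftarrow y_i - \eta_l (g_i(y_i) - c_i + c)\), so the drift-corrected gradient is \(g - c_i + c\). The sketch below computes that correction on NumPy arrays standing in for per-parameter gradients; the documented method instead modifies the local model's gradients, and correct_gradients is a hypothetical name.

```python
import numpy as np

NDArrays = list[np.ndarray]


def correct_gradients(grads: NDArrays, client_cv: NDArrays, server_cv: NDArrays) -> NDArrays:
    """Return g - c_i + c for each trainable parameter's gradient."""
    return [g - ci + c for g, ci, c in zip(grads, client_cv, server_cv)]


grads = [np.array([1.0, 2.0])]
client_cv = [np.array([0.5, 0.5])]   # c_i
server_cv = [np.array([0.25, 1.0])]  # c
corrected = correct_gradients(grads, client_cv, server_cv)
# corrected[0] -> array([0.75, 2.5 ])
```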

set_parameters(parameters, config, fitting_round)

Assumes that the parameters being passed contain model parameters concatenated with server control variates. They are unpacked for the client to use in training. If this is the first time the model is being initialized, the full model is assumed to be initialized, and FullParameterExchanger() is used to set all model weights.

Parameters:
  • parameters (NDArrays) – Model state to be loaded into the client model, together with the server control variates (initial or after aggregation).

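  • config (Config) – The config is sent by the FL server to allow for customization in the function if desired.

  • fitting_round (bool) – Whether the current round is a fitting (training) round rather than an evaluation round.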

Return type:

None
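A sketch of the unpacking step, assuming the first model_size entries are model parameters and the remainder are server control variates. unpack_parameters and model_size are hypothetical names introduced for illustration.

```python
import numpy as np

NDArrays = list[np.ndarray]


def unpack_parameters(packed: NDArrays, model_size: int) -> tuple[NDArrays, NDArrays]:
    """Hypothetical unpacker: split a packed list into (model weights, server control variates)."""
    if not 0 <= model_size <= len(packed):
        raise ValueError("model_size must lie between 0 and len(packed).")
    return packed[:model_size], packed[model_size:]


packed = [np.ones(2), np.ones(3), np.zeros(2), np.zeros(3)]
model_weights, server_control_variates = unpack_parameters(packed, model_size=2)
# model_weights -> [array([1., 1.]), array([1., 1., 1.])]
# server_control_variates -> [array([0., 0.]), array([0., 0., 0.])]
```

The split index is passed explicitly because the number of model arrays need not equal the number of control variate arrays.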

setup_client(config)

Sets the dataloaders, optimizers, parameter exchangers, and other attributes derived from these, then sets the initialized attribute to True. Extends the basic client by extracting the learning rate from the optimizer and storing it in the learning_rate attribute, which is used to compute the updated control variates.

Parameters:

config (Config) – The config from the server.

Return type:

None

transform_gradients(losses)

Hook function for model training, called only after the backward pass and before the optimizer step. Used to modify the gradients to correct for client drift in Scaffold.

Parameters:

losses (TrainingLosses) – Not used in this transformation.

Return type:

None
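To show where the hook sits, here is a toy local training loop, a sketch assuming plain gradient descent on the quadratic loss \(\tfrac{1}{2}\lVert y - \text{target}\rVert^2\): the gradient is computed first (standing in for the backward pass), the drift correction is applied next, and only then is the step taken. None of these names come from the library.

```python
import numpy as np


def local_training(
    y: np.ndarray,
    target: np.ndarray,
    c_i: np.ndarray,
    c: np.ndarray,
    lr: float,
    local_steps: int,
) -> np.ndarray:
    """Toy gradient descent with a Scaffold-style hook between gradient and step."""
    for _ in range(local_steps):
        grad = y - target      # "backward pass": gradient of 0.5 * ||y - target||^2
        grad = grad - c_i + c  # hook: drift correction before the step
        y = y - lr * grad      # "optimizer step": only now is the parameter updated
    return y


y_plain = local_training(np.array([0.0]), np.array([1.0]), np.array([0.0]), np.array([0.0]), 0.5, 2)
y_corrected = local_training(np.array([0.0]), np.array([1.0]), np.array([0.2]), np.array([0.0]), 0.5, 2)
```

With c_i equal to c the correction cancels and the loop reduces to ordinary gradient descent.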

update_after_train(local_steps, loss_dict, config)

Called after training with the number of local_steps performed over the FL round and the corresponding loss dictionary.

Parameters:
  • local_steps (int) – Number of local steps that were taken during local training (\(K\))

  • loss_dict (dict[str, float]) – Dictionary of losses computed during training.

  • config (Config) – The config from the server.

Return type:

None

update_control_variates(local_steps)

Updates the local control variates, along with the corresponding updates, according to option 2 in Equation 4 of https://arxiv.org/pdf/1910.06378.pdf.

To be called after the weights of the local model have been updated.

Parameters:

local_steps (int) – Number of local steps performed during training.

Return type:

None
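Putting the pieces together, the following sketch computes \(c_i^+\) and the control variate update \(\Delta c_i = c_i^+ - c_i\), assuming that is what the "corresponding updates" are (it is the \(\Delta c_i\) that clients communicate in the cited paper). The function signature is hypothetical; the documented method takes only local_steps and presumably reads the weights, control variates, and learning rate from client state.

```python
import numpy as np

NDArrays = list[np.ndarray]


def update_control_variates(
    server_weights: NDArrays,  # x: weights received before local training
    client_weights: NDArrays,  # y_i: weights after local training
    client_cv: NDArrays,       # c_i
    server_cv: NDArrays,       # c
    local_steps: int,          # K
    learning_rate: float,      # lr
) -> tuple[NDArrays, NDArrays]:
    """Return (c_i^+, c_i^+ - c_i) using option 2 of Equation 4 in the SCAFFOLD paper."""
    scaling = 1.0 / (local_steps * learning_rate)
    new_cv = [
        ci - c + scaling * (x - y)
        for ci, c, x, y in zip(client_cv, server_cv, server_weights, client_weights)
    ]
    cv_updates = [new - old for new, old in zip(new_cv, client_cv)]
    return new_cv, cv_updates


x = [np.array([1.0, 1.0])]
y = [np.array([0.0, 2.0])]
c_i = [np.array([0.5, 0.5])]
c = [np.array([0.25, 0.25])]
new_cv, cv_updates = update_control_variates(x, y, c_i, c, local_steps=2, learning_rate=0.5)
# new_cv[0] -> array([ 1.25, -0.75]); cv_updates[0] -> array([ 0.75, -1.25])
```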