fl4health.clients.apfl_client module
- class ApflClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)
Bases: BasicClient
- __init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)
Client implementing the Adaptive Personalized Federated Learning (APFL) algorithm (https://arxiv.org/abs/2003.13461). Twin models are trained. One of them is shared globally by all clients and aggregated on the server. The other is trained strictly locally by each client. Predictions are made by a convex combination of the two models.
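As a minimal sketch of the convex combination, the function below mixes a local and a global model output elementwise. It assumes outputs are plain Python lists of floats and that the mixing weight `alpha` multiplies the local model, following the α v + (1 − α) w form in the APFL paper. The function name is hypothetical and this is not fl4health's implementation.

```python
def convex_combination(local_out: list[float], global_out: list[float], alpha: float) -> list[float]:
    """Hypothetical helper: elementwise alpha * local + (1 - alpha) * global."""
    # A convex combination requires the weight to lie in [0, 1].
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1] for the combination to be convex")
    if len(local_out) != len(global_out):
        raise ValueError("local and global outputs must have the same length")
    # alpha = 1 uses only the local model; alpha = 0 uses only the global model.
    return [alpha * loc + (1.0 - alpha) * glob for loc, glob in zip(local_out, global_out)]
```

For example, `convex_combination([1.0, 0.0], [0.0, 1.0], 0.25)` returns `[0.25, 0.75]`.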
- Parameters:
data_path (Path) – Path to the data used for client-side training.
metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model.
device (torch.device) – Device to which the model, batches, labels, etc. are sent. Often ‘cpu’ or ‘cuda’.
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module and its underlying model and state checkpointing components determine when and how checkpointing is done during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.
reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.
progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.
client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.
- compute_loss_and_additional_losses(preds, features, target)
Computes the loss and any additional losses given the predictions of the model and the ground truth data. For APFL, the loss is the personal loss, and the additional losses are the global and local losses.
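- Parameters:
preds (TorchPredType) – Predictions of the model(s), indexed by name.
features (dict[str, torch.Tensor]) – Features produced by the model(s), indexed by name.
target (TorchTargetType) – Ground truth data to evaluate the predictions against.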
- Returns:
The tensor for the personal loss.
A dictionary with global_loss and local_loss keys and their computed values.
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor]]
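The loss computation described above can be sketched as follows. This is a hypothetical illustration, not fl4health's code. It assumes predictions arrive as a dict with the keys "personal", "global" and "local" (an assumption; only the global_loss and local_loss output keys come from this page). It also uses plain float lists and a hand-written mean squared error in place of a torch criterion.

```python
from collections.abc import Callable

Criterion = Callable[[list[float], list[float]], float]


def mse(pred: list[float], target: list[float]) -> float:
    """Mean squared error over two equal-length, non-empty lists."""
    if not pred or len(pred) != len(target):
        raise ValueError("pred and target must be non-empty and the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def compute_apfl_losses(
    preds: dict[str, list[float]], target: list[float], criterion: Criterion = mse
) -> tuple[float, dict[str, float]]:
    # The personal (mixed) prediction drives the main loss.
    personal_loss = criterion(preds["personal"], target)
    # Global and local losses are reported as additional losses.
    additional_losses = {
        "global_loss": criterion(preds["global"], target),
        "local_loss": criterion(preds["local"], target),
    }
    return personal_loss, additional_losses
```

With `preds = {"personal": [0.5, 0.5], "global": [1.0, 0.0], "local": [0.0, 1.0]}` and `target = [1.0, 0.0]`, this returns a personal loss of 0.25, a global loss of 0.0 and a local loss of 1.0.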
- get_optimizer(config)
Returns a dictionary with global and local optimizers with string keys ‘global’ and ‘local’ respectively.
- get_parameter_exchanger(config)
Returns a full parameter exchanger. Subclasses that require a custom parameter exchanger can override this method.
- Parameters:
config (Config) – The config from the server.
- Returns:
Used to exchange parameters between server and client.
- Return type:
ParameterExchanger
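To illustrate what a "full" exchange means, the sketch below sends every named parameter from the client and overwrites every parameter when values come back. The class and its push_parameters/pull_parameters methods are hypothetical simplifications over a plain dict of float lists. They do not reproduce the signatures of the library's exchanger classes.

```python
class FullExchangeSketch:
    """Hypothetical full parameter exchanger over a name -> values mapping."""

    def push_parameters(self, model_state: dict[str, list[float]]) -> list[list[float]]:
        # Every parameter is sent, in the mapping's insertion order, as a copy.
        return [list(values) for values in model_state.values()]

    def pull_parameters(self, parameters: list[list[float]], model_state: dict[str, list[float]]) -> None:
        # Every parameter is overwritten, matched to its name by position.
        if len(parameters) != len(model_state):
            raise ValueError("expected exactly one entry per model parameter")
        for name, values in zip(list(model_state), parameters):
            model_state[name] = list(values)
```

Because nothing is filtered, the order of the mapping is the only contract the two sides must share. A partial exchanger would instead select a subset of names to send and receive.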
- set_optimizer(config)
Method called in the setup_client method to set the optimizer attribute returned by the user-defined get_optimizer. In the simplest case, get_optimizer returns a single optimizer. For more advanced use cases, such as APFL, where a dictionary mapping strings to optimizers is returned, the user must override this method.
- Parameters:
config (Config) – The config from the server.
- Return type:
None
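The get_optimizer/set_optimizer contract described above, a dictionary with the keys ‘global’ and ‘local’ instead of a single optimizer, can be sketched as follows. Everything here is hypothetical. PlaceholderOptimizer stands in for a real optimizer, the "learning_rate" config key is assumed, and the class is not the fl4health ApflClient.

```python
from dataclasses import dataclass
from typing import Any

REQUIRED_OPTIMIZER_KEYS = frozenset({"global", "local"})


@dataclass
class PlaceholderOptimizer:
    """Hypothetical stand-in for a real optimizer such as torch.optim.SGD."""

    lr: float


class ApflOptimizerSketch:
    """Hypothetical client fragment illustrating the optimizer-dictionary contract."""

    def __init__(self) -> None:
        self.optimizers: dict[str, PlaceholderOptimizer] = {}

    def get_optimizer(self, config: dict[str, Any]) -> dict[str, PlaceholderOptimizer]:
        # "learning_rate" is an assumed config key used only for this sketch.
        lr = float(config.get("learning_rate", 0.01))
        return {"global": PlaceholderOptimizer(lr), "local": PlaceholderOptimizer(lr)}

    def set_optimizer(self, config: dict[str, Any]) -> None:
        optimizers = self.get_optimizer(config)
        # A dictionary is expected instead of a single optimizer, so validate its keys.
        if not isinstance(optimizers, dict) or set(optimizers) != REQUIRED_OPTIMIZER_KEYS:
            raise ValueError("get_optimizer must return a dict with exactly the keys 'global' and 'local'")
        self.optimizers = optimizers
```

Validating the keys up front makes a misconfigured subclass fail at setup rather than midway through training.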
- train_step(input, target)
Given a single batch of input and target data, generates predictions, computes the loss, updates the parameters, and optionally updates metrics if they exist (i.e., backpropagation on a single batch of data). Assumes self.model is already in train mode.
- Parameters:
input (TorchInputType) – The input to be fed into the model.
target (TorchTargetType) – The target corresponding to the input.
- Returns:
- The losses object from the train step, along with a dictionary of any predictions produced by the model.
- Return type:
tuple[TrainingLosses, TorchPredType]
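A single APFL training step, following the update scheme in the APFL paper, can be sketched for one scalar example. The global weight takes a gradient step on its own loss. The local weight and the mixing coefficient α take gradient steps on the loss of the personalized (mixed) prediction. This is a simplified, hypothetical illustration with several assumptions: scalar linear models, squared-error loss, all gradients evaluated at the pre-step values, and α clipped to [0, 1]. It is not fl4health's train_step, which operates on PyTorch models and optimizers.

```python
from dataclasses import dataclass


@dataclass
class ScalarApflState:
    """Hypothetical state: global weight w, local weight v, mixing coefficient alpha."""

    w: float
    v: float
    alpha: float


def apfl_train_step(state: ScalarApflState, x: float, y: float, lr: float) -> tuple[float, dict[str, float]]:
    global_pred = state.w * x
    local_pred = state.v * x
    # Personal prediction: convex combination of the local and global outputs.
    personal_pred = state.alpha * local_pred + (1.0 - state.alpha) * global_pred

    global_loss = (global_pred - y) ** 2
    local_loss = (local_pred - y) ** 2
    personal_loss = (personal_pred - y) ** 2

    # Gradients of the squared errors, all at the pre-step parameter values.
    grad_w = 2.0 * (global_pred - y) * x  # global model: its own loss only
    residual = personal_pred - y
    grad_v = 2.0 * residual * state.alpha * x  # local model: through the mixed prediction
    grad_alpha = 2.0 * residual * (local_pred - global_pred)

    state.w -= lr * grad_w
    state.v -= lr * grad_v
    state.alpha = min(1.0, max(0.0, state.alpha - lr * grad_alpha))
    return personal_loss, {"global_loss": global_loss, "local_loss": local_loss}
```

Consider a start from w = v = 0 with α = 0.5, x = 1, y = 2 and lr = 0.1. One step gives a personal loss of 4.0 and moves w to about 0.4 and v to about 0.2. Because the two models agree, the gradient with respect to α is zero and α stays at 0.5. On the next step the global model fits better than the local one, so α decreases, shifting weight toward the global model.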