fl4health.clients.mr_mtl_client module

class MrMtlClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)[source]

Bases: AdaptiveDriftConstraintClient

__init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)[source]

This client implements the MR-MTL algorithm from MR-MTL: On Privacy and Personalization in Cross-Silo Federated Learning. The idea is to train a personalized version of the global model for each client. However, instead of using a separate solver for the global model, as in Ditto, the initial global model is updated with the aggregated local models on the server side, and those weights are used to constrain the training of the local model. The constraint on this local model is identical to the FedProx loss. The key difference is that the local model is never replaced with the aggregated weights; it always remains local.

NOTE: lambda, the drift loss weight, is initially set and potentially adapted by the server, akin to the heuristic suggested in the original FedProx paper. Adaptation is optional and can be disabled in the corresponding strategy used by the server.
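
For reference, the per-client objective implied by this description can be written as follows (notation ours, not quoted from the paper):

    \min_{w_k} \; f_k(w_k) + \frac{\lambda}{2} \lVert w_k - \bar{w} \rVert_2^2

where f_k is client k's local training loss, w_k is its personal (always local) model, and \bar{w} is the server-side aggregate of the client models, refreshed each round but never copied into w_k.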

Parameters:
  • data_path (Path) – Path to the data used for client-side training.

  • metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model

  • device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often ‘cpu’ or ‘cuda’

  • loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.

  • checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components, will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.

  • reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.

  • progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.

  • client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.
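
A minimal usage sketch is shown below. The subclass, its network (MyNet), and the load_my_dataset helper are hypothetical, and the Accuracy import path may differ across fl4health versions; the hook methods implemented here (get_data_loaders, get_model, get_optimizer, get_criterion) follow the base client interface this class inherits.

    from pathlib import Path

    import torch
    import torch.nn as nn

    from fl4health.clients.mr_mtl_client import MrMtlClient
    from fl4health.utils.metrics import Accuracy  # import path may vary by version

    class MyMrMtlClient(MrMtlClient):
        # Hook implementations required by the inherited base client.
        def get_data_loaders(self, config):
            # load_my_dataset is a hypothetical helper returning (train, val) loaders.
            return load_my_dataset(self.data_path)

        def get_model(self, config):
            return MyNet().to(self.device)  # MyNet is a hypothetical nn.Module

        def get_optimizer(self, config):
            return torch.optim.SGD(self.model.parameters(), lr=0.01)

        def get_criterion(self, config):
            return nn.CrossEntropyLoss()

    client = MyMrMtlClient(
        data_path=Path("/path/to/client/data"),
        metrics=[Accuracy()],
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )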

compute_training_loss(preds, features, target)[source]

Computes training losses given the predictions of the model(s) and the ground truth data. We add to the vanilla loss function a Mean Regularized (MR) penalty loss: the squared l2 distance between the initial global model weights and the weights of the current model, scaled by the drift weight lambda.

Parameters:
  • preds (TorchPredType) – Prediction(s) of the model(s) indexed by name. All predictions included in the dictionary will be used to compute metrics.

  • features (TorchFeatureType) – Feature(s) of the model(s) indexed by name.

  • target (TorchTargetType) – Ground truth data to evaluate predictions against.

Returns:

An instance of TrainingLosses containing the backward loss and additional losses indexed by name. Additional losses include each loss component of the total loss.

Return type:

TrainingLosses
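
The penalty described above can be sketched conceptually as follows. This is an illustration of the computation, not the library's implementation; initial_global_tensors and drift_penalty_weight are stand-ins for whatever the client stores as the frozen per-round global weights and lambda.

    import torch

    def mr_penalty(local_model, initial_global_tensors, drift_penalty_weight):
        # (lambda / 2) * squared l2 distance between the current local weights
        # and the global weights received at the start of the round.
        penalty = 0.0
        for w, w_bar in zip(local_model.parameters(), initial_global_tensors):
            penalty = penalty + torch.sum((w - w_bar) ** 2)
        return 0.5 * drift_penalty_weight * penalty

    # training_loss = vanilla_loss + mr_penalty(model, saved_global_tensors, lam)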

set_parameters(parameters, config, fitting_round)[source]

The parameters passed here are routed to the initial global model, which is used to form the penalty term when training the local model. Unlike the usual FL setup, the aggregated model is never loaded into the LOCAL model. Instead, the aggregated model is used only to form the MR-MTL penalty term.

NOTE: In MR-MTL, unlike Ditto, the local model weights are never synced across clients to the initial global model, not even in the FIRST ROUND.

Parameters:
  • parameters (NDArrays) – Parameters carrying the model state to be routed to the relevant client model. They also contain the penalty weight sent by the server at each round (possibly adapted).

  • config (Config) – The config is sent by the FL server to allow for customization in the function if desired.

  • fitting_round (bool) – Boolean that indicates whether the current federated learning round is a fitting round or an evaluation round. Not used here.

Return type:

None
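
To picture the routing described above, the sketch below assumes the server appends the (possibly adapted) penalty weight to the model weight arrays; in the actual client this unpacking is delegated to a parameter exchanger rather than done by hand.

    import numpy as np

    def unpack_server_payload(parameters):
        # Illustrative split only; a parameter exchanger handles this in practice.
        *global_weights, packed_lambda = parameters
        drift_penalty_weight = float(np.asarray(packed_lambda).item())
        # global_weights update the INITIAL GLOBAL model used in the penalty;
        # the local model keeps its own weights across rounds.
        return global_weights, drift_penalty_weight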

setup_client(config)[source]

Set the dataloaders, optimizers, parameter exchangers, and other attributes derived from these, then set the initialized attribute to True.

Parameters:

config (Config) – The config from the server.

Return type:

None

update_before_train(current_server_round)[source]

Hook method called before training, given the current server round number. NOTE: This method is called immediately AFTER the aggregated parameters are received from the server. It is used, for example, by MOON and FENDA to save global modules after aggregation.

Parameters:

current_server_round (int) – The current server round number.

Return type:

None
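
Subclasses can extend this hook. A brief sketch (the logging is illustrative; MyMrMtlClient is the hypothetical subclass from the earlier sketch, and deferring to the parent keeps the class's own post-aggregation bookkeeping intact):

    class LoggingMrMtlClient(MyMrMtlClient):
        def update_before_train(self, current_server_round: int) -> None:
            # Illustrative: log, then defer to the parent implementation, which
            # runs the client's own logic after aggregated weights arrive.
            print(f"Starting local training for server round {current_server_round}")
            super().update_before_train(current_server_round)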

validate(include_losses_in_metrics=False)[source]

Validate the current model on the entire validation dataset.

Returns:

The validation loss and a dictionary of metrics from validation.

Return type:

tuple[float, dict[str, Scalar]]