fl4health.clients.ditto_client module

class DittoClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)[source]

Bases: AdaptiveDriftConstraintClient

__init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)[source]

This client implements the Ditto algorithm from Ditto: Fair and Robust Federated Learning Through Personalization. The idea is to train a personalized version of the global model for each client. So we simultaneously train a global model, which is aggregated on the server side, and use its weights to constrain the training of a local model. The constraint on this local model is identical to the FedProx loss.

NOTE: lambda, the drift loss weight, is initially set and potentially adapted by the server, akin to the heuristic suggested in the original FedProx paper. Adaptation is optional and can be disabled in the corresponding strategy used by the server.

Parameters:
  • data_path (Path) – path to the data to be used to load the data for client-side training

  • metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model

  • device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often “cpu” or “cuda”

  • loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.

  • checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.

  • reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.

  • progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.

  • client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.
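The constraint described above can be sketched without any framework code. The helpers below are illustrative only (they are not part of the fl4health API): the local model minimizes its vanilla loss plus an l2 drift penalty, weighted by lambda, that pulls its weights toward the global model weights received from the server.

```python
# Illustrative sketch of the Ditto local objective (not the fl4health API).
# Weights are represented as flat lists of floats for simplicity.

def drift_penalty(local_weights, global_weights, drift_weight):
    # (lambda / 2) * || w_local - w_global ||^2 over the flattened weights
    squared_l2 = sum((wl - wg) ** 2 for wl, wg in zip(local_weights, global_weights))
    return 0.5 * drift_weight * squared_l2

def ditto_local_objective(vanilla_loss, local_weights, global_weights, drift_weight):
    # the local model trains on its vanilla loss plus the drift penalty
    return vanilla_loss + drift_penalty(local_weights, global_weights, drift_weight)
```

For example, with lambda = 1.0, local weights [1.0, 2.0], and global weights [0.0, 0.0], the penalty is 0.5 * (1 + 4) = 2.5. When the two weight sets coincide, the penalty vanishes and only the vanilla loss remains.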

compute_evaluation_loss(preds, features, target)[source]

Computes evaluation loss given predictions (and potentially features) of the model and ground truth data. For Ditto, the local model's vanilla loss is used for checkpointing. During validation, we also compute the global model's vanilla loss.

Parameters:
  • preds (TorchPredType) – Prediction(s) of the model(s) indexed by name. Anything stored in preds will be used to compute metrics.

  • features (TorchFeatureType) – Feature(s) of the model(s) indexed by name.

  • target (TorchTargetType) – Ground truth data to evaluate predictions against.

Returns:

An instance of EvaluationLosses containing checkpoint loss and additional losses indexed by name.

Return type:

EvaluationLosses

compute_loss_and_additional_losses(preds, features, target)[source]

Computes the local model loss and the global Ditto model loss (stored in the additional losses) for reporting and for training the global model.

Parameters:
  • preds (TorchPredType) – Prediction(s) of the model(s) indexed by name.

  • features (TorchFeatureType) – Feature(s) of the model(s) indexed by name.

  • target (TorchTargetType) – Ground truth data to evaluate predictions against.

Returns:

A tuple with:

  • The tensor for the model loss

  • A dictionary with local_loss, global_loss as additionally reported loss values.

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]
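The return shape can be sketched as follows. This is an illustrative helper, not the actual implementation: the first element of the tuple is the local model's loss, and both losses are reported by name in the additional losses dictionary.

```python
# Illustrative sketch of the return shape of compute_loss_and_additional_losses
# (plain floats stand in for torch tensors).

def loss_and_additional_losses(local_loss, global_loss):
    # the first element is the loss driving the local model; both losses are
    # reported under the keys "local_loss" and "global_loss"
    additional_losses = {"local_loss": local_loss, "global_loss": global_loss}
    return local_loss, additional_losses
```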

compute_training_loss(preds, features, target)[source]

Computes training losses given predictions of the global and local models and ground truth data. For the local model, we add a Ditto penalty to the vanilla loss: the squared l2 norm of the difference between the weights of the local model and the initial global model weights. This penalized loss is stored as the backward loss. The loss used to optimize the global model is stored in the additional losses dictionary under “global_loss”.

Parameters:
  • preds (TorchPredType) – Prediction(s) of the model(s) indexed by name. All predictions included in dictionary will be used to compute metrics.

  • features (TorchFeatureType) – Feature(s) of the model(s) indexed by name.

  • target (TorchTargetType) – Ground truth data to evaluate predictions against.

Returns:

An instance of TrainingLosses containing backward loss and additional losses indexed by name. Additional losses includes each loss component and the global model loss tensor.

Return type:

TrainingLosses
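The composition of the training losses can be sketched with plain floats. The helper and key names below are illustrative (the actual class returns a TrainingLosses instance built from torch tensors): the backward loss for the local model is its vanilla loss plus the Ditto penalty, while the global model's loss is carried separately in the additional losses.

```python
# Illustrative sketch of how compute_training_loss composes its outputs
# (plain floats stand in for torch tensors; key names are assumptions).

def training_losses(local_vanilla_loss, penalty_loss, global_vanilla_loss):
    # backward loss for the local model: vanilla loss plus the Ditto penalty
    backward = local_vanilla_loss + penalty_loss
    additional = {
        "local_loss": local_vanilla_loss,
        "penalty_loss": penalty_loss,
        "global_loss": global_vanilla_loss,  # used to optimize the global model
    }
    return backward, additional
```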

get_global_model(config)[source]

Returns the global model to be used during Ditto training and as a constraint for the local model.

The global model should be the same architecture as the local model so we reuse the get_model call. We explicitly send the model to the desired device. This is idempotent.

Parameters:

config (Config) – The config from the server.

Returns:

The PyTorch model serving as the global model for Ditto

Return type:

nn.Module

get_optimizer(config)[source]

Returns a dictionary with global and local optimizers with string keys “global” and “local” respectively.

Parameters:

config (Config) – The config from the server.

Return type:

dict[str, Optimizer]
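The required return shape is a dictionary with exactly the two keys “global” and “local”. The toy helper below is illustrative (in a real subclass the values would be torch optimizers built over the global and local model parameters); the key check mirrors the validation that set_optimizer performs.

```python
# Toy sketch of the dictionary a DittoClient subclass must return from
# get_optimizer (the optimizer objects here are placeholders).

def build_ditto_optimizers(global_optimizer, local_optimizer):
    optimizers = {"global": global_optimizer, "local": local_optimizer}
    # mirrors the check done by set_optimizer: exactly these two keys
    assert set(optimizers) == {"global", "local"}
    return optimizers
```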

get_parameters(config)[source]

For Ditto, we transfer the GLOBAL model weights to the server to be aggregated. The local model weights stay with the client.

Parameters:

config (Config) – The config is sent by the FL server to allow for customization in the function if desired.

Returns:

GLOBAL model weights to be sent to the server for aggregation

Return type:

NDArrays

initialize_all_model_weights(parameters, config)[source]

If this is the first time we’re initializing the model weights, we initialize both the global and the local weights together.

Parameters:
  • parameters (NDArrays) – Model parameters to be injected into the client model

  • config (Config) – The config is sent by the FL server to allow for customization in the function if desired.

Return type:

None

predict(input)[source]

Computes the predictions for both the GLOBAL and LOCAL models and packs them into the prediction dictionary.

Parameters:

input (TorchInputType) – Inputs to be fed into both models.

Returns:

A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations indexed by name. For Ditto, we only need the predictions, so the second dictionary is simply empty.

Return type:

tuple[TorchPredType, TorchFeatureType]

Raises:

ValueError – Occurs when something other than a tensor or dict of tensors is returned by the model forward.
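The output shape can be sketched as below. This is an illustrative helper, not the actual implementation: predictions from the two models are packed under the “global” and “local” keys, and the features dictionary is left empty.

```python
# Illustrative sketch of the predict output shape (arbitrary objects stand in
# for torch tensors).

def pack_ditto_predictions(global_preds, local_preds):
    preds = {"global": global_preds, "local": local_preds}
    features = {}  # Ditto only needs predictions, so no features are returned
    return preds, features
```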

set_initial_global_tensors()[source]

Saves the initial GLOBAL MODEL weights and detaches them so that we do not compute gradients with respect to these tensors. They are used to form the Ditto local update penalty term.

Return type:

None

set_optimizer(config)[source]

Ditto requires an optimizer for the global model and one for the local model. This function simply ensures that the optimizers set up by the user have the proper keys and that there are exactly two optimizers.

Parameters:

config (Config) – The config from the server.

Return type:

None

set_parameters(parameters, config, fitting_round)[source]

Assumes that the parameters being passed contain model parameters concatenated with a penalty weight. They are unpacked for the client to use in training. The parameters being passed are routed to the global model. In the first fitting round, we assume that both the global and local models are being initialized and use the FullParameterExchanger() to initialize both sets of model weights to the same parameters.

Parameters:
  • parameters (NDArrays) – Parameters have information about model state to be added to the relevant client model (global model for all but the first step of Ditto). These should also include a penalty weight from the server that needs to be unpacked.

  • config (Config) – The config is sent by the FL server to allow for customization in the function if desired.

  • fitting_round (bool) – Boolean that indicates whether the current federated learning round is a fitting round or an evaluation round. This is used to help determine which parameter exchange should be used for pulling parameters. If the current federated learning round is the very first fitting round, then we initialize both the global and local Ditto models with weights sent from the server.

Return type:

None
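The packing convention described above can be sketched with plain lists. This helper is illustrative only (the real client relies on a parameter packer from fl4health, and the weights are NDArrays): the server appends the scalar penalty weight after the model weight arrays, and the client peels it off before loading the weights.

```python
# Illustrative sketch of the assumed packing convention: model weight arrays
# followed by a single scalar penalty weight (lambda) from the server.

def unpack_parameters(packed_parameters):
    *model_weights, penalty_weight = packed_parameters
    return model_weights, penalty_weight
```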

setup_client(config)[source]

Sets dataloaders, optimizers, parameter exchangers, and other attributes derived from these, then sets the initialized attribute to True. In this class, this function simply adds the additional step of setting up the global model.

Parameters:

config (Config) – The config from the server.

Return type:

None

train_step(input, target)[source]

Mechanics of training loop follow from original Ditto implementation: https://github.com/litian96/ditto

As in the implementation there, steps of the global and local models are done in tandem and for the same number of steps.

Parameters:
  • input (TorchInputType) – input tensor to be run through both the global and local models. Here, TorchInputType is simply an alias for the union of torch.Tensor and dict[str, torch.Tensor].

  • target (TorchTargetType) – target tensor used to compute a loss given each model's outputs.

Returns:

Returns relevant loss values from both the global and local model optimization steps. The prediction dictionary contains predictions indexed by “global” and “local”, corresponding to predictions from the global and local Ditto models, for metric evaluations.

Return type:

tuple[TrainingLosses, TorchPredType]
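The tandem update can be sketched with plain gradient steps. The helper below is a toy illustration (not the actual implementation, which uses the torch optimizers returned by get_optimizer): the global model steps on its vanilla loss gradient while the local model steps on the gradient of its penalized backward loss, and both take the same number of steps.

```python
# Toy sketch of the tandem update in train_step: one gradient step for each
# model, taken together (weights and gradients are flat lists of floats).

def tandem_step(global_weights, local_weights, global_grads, local_grads, lr):
    # global model steps on its vanilla loss gradient
    new_global = [w - lr * g for w, g in zip(global_weights, global_grads)]
    # local model steps on the gradient of its penalized backward loss
    new_local = [w - lr * g for w, g in zip(local_weights, local_grads)]
    return new_global, new_local
```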

update_before_train(current_server_round)[source]

Procedures that should occur before proceeding with the training loops for the models. In this case, we save the global model's parameters to be used in constraining the training of the local model.

Parameters:

current_server_round (int) – Indicates which server round we are currently executing.

Return type:

None

validate(include_losses_in_metrics=False)[source]

Validate the current model on the entire validation dataset.

Returns:

The validation loss and a dictionary of metrics from validation.

Return type:

tuple[float, dict[str, Scalar]]