fl4health.clients.mkmmd_clients.ditto_mkmmd_client module

class DittoMkMmdClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, mkmmd_loss_weight=10.0, feature_extraction_layers=None, feature_l2_norm_weight=0.0, beta_global_update_interval=20, num_accumulating_batches=None)[source]

Bases: DittoClient

__init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, mkmmd_loss_weight=10.0, feature_extraction_layers=None, feature_l2_norm_weight=0.0, beta_global_update_interval=20, num_accumulating_batches=None)[source]

This client implements the MK-MMD loss function in the Ditto framework. The MK-MMD loss is a measure of the distance between the distributions of the features of the local model and initial global model of each round. The MK-MMD loss is added to the local loss to penalize the local model for drifting away from the global model.
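
For intuition, here is a minimal, self-contained sketch of a multi-kernel MMD² estimate between a batch of local-model features and a batch of initial global model features. The kernel bandwidths (sigmas), the beta weights, and the helper names are illustrative assumptions; the library's MK-MMD loss additionally optimizes the betas (see beta_global_update_interval below).

```python
# Illustrative sketch only: a beta-weighted, multi-kernel (RBF) MMD^2 estimate between
# local and global feature batches. Bandwidths and betas here are hypothetical choices.
import torch


def rbf_kernel(x: torch.Tensor, y: torch.Tensor, sigma: float) -> torch.Tensor:
    # Gaussian kernel over pairwise squared distances between rows of x and y.
    squared_dists = torch.cdist(x, y) ** 2
    return torch.exp(-squared_dists / (2.0 * sigma**2))


def multi_kernel_mmd2(
    local_features: torch.Tensor,
    global_features: torch.Tensor,
    sigmas: list[float],
    betas: torch.Tensor,
) -> torch.Tensor:
    # Biased MMD^2 estimate: mean within-local + mean within-global - 2 * mean cross
    # similarity, summed over kernels with weights beta.
    mmd2 = torch.zeros(())
    for beta, sigma in zip(betas, sigmas):
        k_ll = rbf_kernel(local_features, local_features, sigma).mean()
        k_gg = rbf_kernel(global_features, global_features, sigma).mean()
        k_lg = rbf_kernel(local_features, global_features, sigma).mean()
        mmd2 = mmd2 + beta * (k_ll + k_gg - 2.0 * k_lg)
    return mmd2


# Example: flattened features from one extraction layer for a batch of 32 samples.
local_feats, global_feats = torch.randn(32, 64), torch.randn(32, 64)
betas = torch.tensor([0.5, 0.3, 0.2])  # hypothetical kernel weights
print(multi_kernel_mmd2(local_feats, global_feats, sigmas=[0.5, 1.0, 2.0], betas=betas))
```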

Parameters:
  • data_path (Path) – Path to the data used for client-side training.

  • metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model

  • device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often ‘cpu’ or ‘cuda’

  • loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.

  • checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.

  • reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.

  • progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.

  • client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.

  • mkmmd_loss_weight (float, optional) – weight applied to the MK-MMD loss. Defaults to 10.0.

  • feature_extraction_layers (Sequence[str] | None, optional) – List of layers from which to extract and flatten features. Defaults to None.

  • feature_l2_norm_weight (float, optional) – weight applied to the L2 norm of the features. Defaults to 0.0.

  • beta_global_update_interval (int, optional) – Interval at which to update the betas for the MK-MMD loss. If set to a value above 0, the betas are updated based on the whole distribution of the latent features of the data at the given interval. If set to 0, the betas are not updated. If set to -1, the betas are updated after each individual batch based only on that batch. Defaults to 20.

  • num_accumulating_batches (int | None, optional) – Number of batches over which to accumulate features to approximate the whole distribution of the latent features when updating the betas of the MK-MMD loss. This parameter is only used if beta_global_update_interval is set to a value larger than 0. Defaults to None. A usage sketch follows this parameter list.
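
A hedged usage sketch follows: a concrete client derived from DittoMkMmdClient with a toy model and data. The required overrides (get_model, get_data_loaders, get_optimizer, get_criterion), the "global"/"local" optimizer keys, the global_model attribute, and the Accuracy import path are assumptions based on the usual fl4health client pattern and should be checked against the installed version.

```python
# Usage sketch (assumptions noted above): deriving and instantiating a DittoMkMmdClient.
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from fl4health.clients.mkmmd_clients.ditto_mkmmd_client import DittoMkMmdClient
from fl4health.utils.metrics import Accuracy  # assumed import path


class ToyDittoMkMmdClient(DittoMkMmdClient):
    def get_model(self, config):
        return nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2)).to(self.device)

    def get_data_loaders(self, config):
        x, y = torch.randn(64, 10), torch.randint(0, 2, (64,))
        loader = DataLoader(TensorDataset(x, y), batch_size=8)
        return loader, loader  # toy train and validation loaders

    def get_optimizer(self, config):
        # Ditto-style clients optimize both a global and a local model (attribute
        # and key names assumed here).
        return {
            "global": torch.optim.SGD(self.global_model.parameters(), lr=0.01),
            "local": torch.optim.SGD(self.model.parameters(), lr=0.01),
        }

    def get_criterion(self, config):
        return nn.CrossEntropyLoss()


client = ToyDittoMkMmdClient(
    data_path=Path("unused/for/this/toy/example"),
    metrics=[Accuracy()],
    device=torch.device("cpu"),
    mkmmd_loss_weight=10.0,
    feature_extraction_layers=["0"],  # names must match the model's named modules
    beta_global_update_interval=20,
    num_accumulating_batches=4,
)
```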

compute_loss_and_additional_losses(preds, features, target)[source]

Computes the loss and any additional losses given predictions of the model and ground truth data.

Parameters:
  • preds (TorchPredType) – Prediction(s) of the model(s) indexed by name.

  • features (TorchFeatureType) – Feature(s) of the model(s) indexed by name.

  • target (TorchTargetType) – Ground truth data to evaluate predictions against.

Returns:

A tuple with:
  • The tensor for the loss

  • A dictionary of additional losses with their names and values, or None if there are no additional losses.

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]

compute_training_loss(preds, features, target)[source]

Computes training losses given predictions of the global and local models and ground truth data. For the local model, we add to the vanilla loss function by including the Ditto penalty loss, which is the l2 inner product between the initial global model weights and the weights of the local model. This is stored in the backward loss. The loss to optimize the global model is stored in the additional losses dictionary under “global_loss”.

Parameters:
  • preds (TorchPredType) – Prediction(s) of the model(s) indexed by name. All predictions included in dictionary will be used to compute metrics.

  • features (TorchFeatureType) – Feature(s) of the model(s) indexed by name.

  • target (TorchTargetType) – Ground truth data to evaluate predictions against.

Returns:

An instance of TrainingLosses containing the backward loss and additional losses indexed by name. The additional losses include each loss component and the global model loss tensor.

Return type:

TrainingLosses
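
For reference, the Ditto penalty described above is commonly written as (λ/2)‖w_local − w_global‖², computed against the frozen global weights from the start of the round. The sketch below is illustrative only; the λ value, reduction, and function name are assumptions rather than fl4health's exact implementation.

```python
# Illustrative sketch of the Ditto drift penalty (not fl4health's internal code).
import torch
from torch import nn


def ditto_penalty(local_model: nn.Module, initial_global_model: nn.Module, lam: float = 0.1) -> torch.Tensor:
    # (lam / 2) * sum of squared differences between the local weights and the frozen
    # initial global weights from the start of the round.
    penalty = torch.zeros(())
    for p_local, p_global in zip(local_model.parameters(), initial_global_model.parameters()):
        penalty = penalty + torch.sum((p_local - p_global.detach()) ** 2)
    return 0.5 * lam * penalty


local, initial_global = nn.Linear(4, 2), nn.Linear(4, 2)
print(ditto_penalty(local, initial_global, lam=0.1))
```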

predict(input)[source]

Computes the predictions for both the GLOBAL and LOCAL models and packs them into the prediction dictionary.

Parameters:

input (TorchInputType) – Inputs to be fed into the model. If input is of type dict[str, torch.Tensor], it is assumed that the keys of input match the names of the keyword arguments of self.model.forward().

Returns:

A tuple in which the first element contains a dictionary of predictions indexed by name and the second element contains intermediate activations indexed by name. By passing features, we can compute all the losses. All predictions included in the dictionary will, by default, be used to compute metrics separately.

Return type:

tuple[dict[str, Tensor], dict[str, Tensor]]

Raises:
  • TypeError – Occurs when something other than a tensor or dict of tensors is passed in to the model’s forward method.

  • ValueError – Occurs when something other than a tensor or dict of tensors is returned by the model’s forward method.

setup_client(config)[source]

Set dataloaders, optimizers, parameter exchangers, and other attributes derived from these. Then set the initialized attribute to True. In this class, this function simply adds the additional step of setting up the global model.

Parameters:

config (Config) – The config from the server.

Return type:

None

update_after_step(step, current_round=None)[source]

Hook method called after a local training step on the client. step is an integer that represents the local training step that was most recently completed. For example, it is used by the APFL method to update the alpha value after a training step. It is also used by the MOON, FENDA, and Ditto methods to update the optimized beta values for the MK-MMD loss after n steps.

Parameters:
  • step (int) – The step number in local training that was most recently completed. Resets only at the end of the round.

  • current_round (int | None, optional) – The current FL server round. Defaults to None.

Return type:

None
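
The schedule implied by beta_global_update_interval can be summarized by the small sketch below. It illustrates the documented semantics rather than the library's exact control flow; the precise step offset is an assumption.

```python
# Illustrative mapping from beta_global_update_interval to when betas are recomputed.
def should_update_betas(step: int, interval: int) -> bool:
    if interval == -1:
        return True  # recompute betas from each individual batch
    if interval == 0:
        return False  # keep betas fixed throughout training
    return (step + 1) % interval == 0  # periodic update from accumulated features


print([step for step in range(60) if should_update_betas(step, 20)])  # [19, 39, 59]
```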

update_before_train(current_server_round)[source]

Procedures that should occur before proceeding with the training loops for the models. In this case, we save the global model’s parameters to be used in constraining the training of the local model.

Parameters:

current_server_round (int) – Indicates which server round we are currently executing.

Return type:

None

update_buffers(local_model, initial_global_model)[source]

Update the feature buffers holding the local and initial global model features.

Parameters:
  • local_model (torch.nn.Module) – Local model to extract features from.

  • initial_global_model (torch.nn.Module) – Initial global model to extract features from.

Returns:

A tuple containing the features extracted using the local and initial global models.

Return type:

tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
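
To illustrate what extracting and flattening features from named layers might look like (cf. feature_extraction_layers and the buffers updated here), the sketch below uses PyTorch forward hooks. It is a stand-in for intuition, not fl4health's internal feature-extraction mechanism.

```python
# Illustrative sketch: capture and flatten intermediate activations from named layers.
import torch
from torch import nn


def extract_features(model: nn.Module, batch: torch.Tensor, layer_names: list[str]) -> dict[str, torch.Tensor]:
    captured: dict[str, torch.Tensor] = {}
    handles = []
    for name, module in model.named_modules():
        if name in layer_names:
            handles.append(
                module.register_forward_hook(
                    lambda _module, _inputs, output, name=name: captured.__setitem__(
                        name, output.flatten(start_dim=1)
                    )
                )
            )
    with torch.no_grad():
        model(batch)
    for handle in handles:
        handle.remove()
    return captured


model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
features = extract_features(model, torch.randn(8, 10), layer_names=["1"])
print({name: tensor.shape for name, tensor in features.items()})  # {'1': torch.Size([8, 32])}
```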