fl4health.clients.fenda_ditto_client module

class FendaDittoClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, freeze_global_feature_extractor=False)

Bases: DittoClient

__init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, freeze_global_feature_extractor=False)

This client combines the Ditto algorithm from "Ditto: Fair and Robust Federated Learning Through Personalization" with FENDA-FL models. In this implementation, the global Ditto model consists of a feature extractor and a classification head. Its feature extractor has the same architecture as both the global and local feature extractors of the FENDA model being trained. Each client trains a local FENDA model alongside a global model. The global model is aggregated on the server side, and its weights are also used to constrain the training of the local FENDA model. At the beginning of each server round, the feature extractor of the globally aggregated model is injected into the global feature extractor of the FENDA model.

There are two distinct modes of operation:

  • If freeze_global_feature_extractor is True, the global Ditto model feature extractor SETS AND FREEZES the weights of the global FENDA feature extractor. The local components of the FENDA model are trained, and an additional drift loss is computed between the local and global feature extractors of the FENDA model.

  • If freeze_global_feature_extractor is False, the global Ditto model feature extractor INITIALIZES the weights of the FENDA model’s global feature extractor. Both the local and global components of FENDA are trained, and a drift loss is calculated between the Ditto global feature extractor and the FENDA global feature extractor.

The constraint on the FENDA model discussed above is a weight drift loss applied to its feature extraction modules.

NOTE: Unlike standard FENDA-FL, the FENDA model’s global feature extractor is NOT exchanged with the server. Instead, the global Ditto model is exchanged, and its feature extractor is injected into the FENDA global feature extractor at each round. If the global feature extractor is frozen, only the local components of the FENDA network are trained.
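The round-start injection and the choice of which FENDA extractor receives the drift penalty can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the FL4Health implementation: module weights are modeled as hypothetical dictionaries mapping parameter names to lists of floats, and `FendaState` and `inject_global_extractor` are illustrative names.

```python
from dataclasses import dataclass

# Hypothetical stand-in for a module's weights: parameter name -> flat list of floats.
Weights = dict[str, list[float]]


@dataclass
class FendaState:
    """Illustrative container for the two FENDA feature extractors."""

    local_extractor: Weights
    global_extractor: Weights
    global_extractor_frozen: bool = False


def inject_global_extractor(fenda: FendaState, ditto_extractor: Weights, freeze: bool) -> str:
    """Copy the Ditto feature extractor into FENDA and return the penalized extractor's name."""
    # Copy the lists so later Ditto updates do not silently change FENDA's weights.
    fenda.global_extractor = {name: list(values) for name, values in ditto_extractor.items()}
    fenda.global_extractor_frozen = freeze
    # Frozen: the global extractor equals the Ditto weights, so the local extractor is penalized.
    # Not frozen: the global extractor is trained and is pulled back toward the Ditto weights.
    return "local_extractor" if freeze else "global_extractor"
```

In a PyTorch setting, freezing would presumably correspond to disabling gradients on the global extractor (for example with `requires_grad_(False)`); that mapping is an assumption of this sketch.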

Parameters:
  • data_path (Path) – Path to the data used for client-side training.

  • metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model

  • device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often ‘cpu’ or ‘cuda’

  • loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.

  • checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.

  • reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.

  • progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.

  • client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.

  • freeze_global_feature_extractor (bool, optional) – Determines whether we freeze the FENDA global feature extractor during training. If freeze_global_feature_extractor is False, both the global and the local feature extractor in the local FENDA model will be trained. Otherwise, the global feature extractor submodule is frozen. If freeze_global_feature_extractor is True, the Ditto loss will be calculated using the local FENDA feature extractor and the global model. Otherwise, the loss is calculated using the global FENDA feature extractor and the global model. Defaults to False.

compute_training_loss(preds, features, target)

Computes training losses given predictions of the global and local models and ground truth data. For the local model, the vanilla loss is augmented with a Ditto penalty: the squared L2 norm of the difference between the initial global model feature extractor weights and the feature extractor weights of the local model (that is, the L2 inner product of that difference with itself). If the global feature extractor is not frozen, the penalty is computed using the global feature extractor of the local model. If it is frozen, the penalty is computed using the local feature extractor of the local model, because the frozen global feature extractor already holds the global weights. The penalty is included in the “backward” loss. The loss used to optimize the global model is stored in the additional losses dictionary under “global_loss”.
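The penalty can be written out concretely. The sketch below assumes the Ditto paper's form, (lambda/2) * sum ||w - w0||^2, where w are the penalized FENDA feature extractor weights and w0 the initial global feature extractor weights. The lambda/2 scaling and the names `weight_drift_penalty`, `backward_loss`, and `lam` are assumptions, and plain lists of floats stand in for tensors.

```python
# Hypothetical stand-in for a module's weights: parameter name -> flat list of floats.
Weights = dict[str, list[float]]


def weight_drift_penalty(current: Weights, initial_global: Weights, lam: float) -> float:
    """(lam / 2) * sum over parameters of ||current - initial_global||_2^2."""
    if current.keys() != initial_global.keys():
        raise ValueError("Both extractors must expose the same parameter names")
    total = 0.0
    for name, values in current.items():
        reference = initial_global[name]
        if len(values) != len(reference):
            raise ValueError(f"Size mismatch for parameter {name!r}")
        # Inner product of the weight difference with itself = squared L2 norm.
        total += sum((v - r) ** 2 for v, r in zip(values, reference))
    return 0.5 * lam * total


def backward_loss(task_loss: float, current: Weights, initial_global: Weights, lam: float) -> float:
    """Local-model objective: the vanilla task loss plus the drift penalty."""
    return task_loss + weight_drift_penalty(current, initial_global, lam)
```

For example, with `current = {"w": [1.0, 2.0]}`, `initial_global = {"w": [0.0, 0.0]}` and `lam=0.5`, the penalty is 0.5 / 2 * (1 + 4) = 1.25.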

Parameters:
  • preds (dict[str, torch.Tensor]) – Prediction(s) of the model(s) indexed by name. All predictions included in the dictionary will be used to compute metrics.

  • features (dict[str, torch.Tensor]) – Feature(s) of the model(s) indexed by name.

  • target (torch.Tensor) – Ground truth data to evaluate predictions against.

Returns:

An instance of TrainingLosses containing the backward loss and additional losses indexed by name. Additional losses include each loss component and the global model loss tensor.

Return type:

TrainingLosses

get_global_model(config)

User-defined method that returns the global (Ditto) model, a sequentially split model whose feature extractor is compatible with the local FENDA model.

Parameters:

config (Config) – The config from the server.

Returns:

The global (Ditto) model.

Return type:

SequentiallySplitModel

Raises:

NotImplementedError – To be defined in child class.
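The class description requires the global model's feature extractor to share the architecture of the FENDA global feature extractor. A minimal check of that requirement is sketched below. It assumes each extractor is summarized as a hypothetical name-to-shape map. With PyTorch, such a map could be built as `{k: tuple(v.shape) for k, v in module.state_dict().items()}`. The function name is illustrative.

```python
# Hypothetical description of a module: parameter name -> tensor shape.
Shapes = dict[str, tuple[int, ...]]


def check_extractor_compatibility(ditto_extractor: Shapes, fenda_global_extractor: Shapes) -> None:
    """Raise ValueError unless weights can be copied one-to-one between the extractors."""
    # Names present in one extractor but not the other.
    mismatched_names = set(ditto_extractor) ^ set(fenda_global_extractor)
    if mismatched_names:
        raise ValueError(f"Parameter names differ: {sorted(mismatched_names)}")
    # Same names: every tensor must also have the same shape.
    for name, shape in ditto_extractor.items():
        if fenda_global_extractor[name] != shape:
            raise ValueError(f"Shape mismatch for {name!r}: {shape} vs {fenda_global_extractor[name]}")
```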

get_model(config)

User-defined method that returns the local FENDA model.

Parameters:

config (Config) – The config from the server.

Returns:

The client FENDA model.

Return type:

FendaModel

Raises:

NotImplementedError – To be defined in child class.

get_parameters(config)

For FendaDitto, the GLOBAL Ditto model weights are transferred to the server for aggregation, while the local FENDA model weights stay with the client, since the local FENDA model has a different architecture than the GLOBAL model. If the client is asked for initialization parameters, it likewise sends only the GLOBAL model, which synchronizes the GLOBAL models across clients AND the global feature extractors of the local FENDA models.

Parameters:

config (Config) – The config is sent by the FL server to allow for customization in the function if desired.

Returns:

GLOBAL model weights to be sent to the server for aggregation

Return type:

NDArrays
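Only the GLOBAL model crosses the network. The sketch below mirrors that in plain Python. The global model's parameters are serialized in their fixed order, and the inverse maps aggregated arrays back by position. The function names and the list-of-floats stand-in for NDArrays are hypothetical. Positional ordering is assumed to match how a full parameter exchange serializes a model's state.

```python
# Hypothetical stand-in for a module's weights: parameter name -> flat list of floats.
Weights = dict[str, list[float]]


def global_model_to_ndarrays(global_model: Weights) -> list[list[float]]:
    """Serialize only the global (Ditto) model, in the model's own parameter order."""
    # The local FENDA model never appears here; it stays on the client.
    return [list(values) for values in global_model.values()]


def ndarrays_to_global_model(template: Weights, arrays: list[list[float]]) -> Weights:
    """Map aggregated arrays from the server back onto parameter names by position."""
    if len(arrays) != len(template):
        raise ValueError("Received a different number of arrays than the model has parameters")
    return {name: list(values) for name, values in zip(template, arrays)}
```

Because the FENDA model never enters this path, its architecture can differ from the global model's without affecting server-side aggregation.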

predict(input)

Computes the predictions for both the GLOBAL and LOCAL models and packs them into the prediction dictionary.

Parameters:

input (TorchInputType) – Inputs to be fed into both models.

Returns:

A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations indexed by name. For Ditto+FENDA, only the predictions are needed, so the second dictionary is empty.

Return type:

tuple[TorchPredType, TorchFeatureType]

Raises:
  • ValueError – Occurs when something other than a tensor or dict of tensors is returned by the model’s forward.
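Packing both models' outputs can be sketched as follows. The dictionary keys "global" and "local" are an assumption about naming. Plain callables over lists of floats stand in for the two torch models, and `predict_sketch` is an illustrative name.

```python
from typing import Callable

# Hypothetical stand-in for a torch.Tensor.
Tensor = list[float]


def predict_sketch(
    global_model: Callable[[Tensor], Tensor],
    local_model: Callable[[Tensor], Tensor],
    batch: Tensor,
) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
    """Run the same input through both models and key the outputs by model role."""
    preds = {"global": global_model(batch), "local": local_model(batch)}
    # No intermediate activations are needed for Ditto + FENDA, so features stay empty.
    return preds, {}
```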

set_initial_global_tensors()
Return type:

None

set_parameters(parameters, config, fitting_round)

The parameters being passed are routed to the global (Ditto) model, copied to the global feature extractor of the local FENDA model, and saved as the initial global model tensors used in the penalty term when training the local model. We assume that both the global and local models are being initialized and use a FullParameterExchanger() to set the weights of the global model. The global model’s feature extractor weights are then copied to the global feature extractor of the local FENDA model.

Parameters:
  • parameters (List[ndarray[Any, dtype[Any]]]) – Parameters have information about model state to be added to the relevant client model.

  • config (Config) – The config is sent by the FL server to allow for customization in the function if desired.

  • fitting_round (bool) – Boolean that indicates whether the current federated learning round is a fitting round or an evaluation round. This is used to help determine which parameter exchange should be used for pulling parameters. A full parameter exchanger is only used if the current federated learning round is the very first fitting round.

Return type:

None
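The routing described above can be sketched for the first fitting round, when a full exchange is used. The sketch overwrites the global model, mirrors its extractor into FENDA, and keeps an independent snapshot for the penalty. The `feature_extractor.` prefix used to locate the Ditto extractor's parameters, the function name, and the list-of-floats weights are hypothetical. The actual module layout is defined by the user's get_global_model and get_model.

```python
import copy

# Hypothetical stand-in for a module's weights: parameter name -> flat list of floats.
Weights = dict[str, list[float]]

# Hypothetical naming convention: Ditto feature-extractor parameters share this prefix.
EXTRACTOR_PREFIX = "feature_extractor."


def route_server_parameters(
    global_model: Weights, fenda_global_extractor: Weights, arrays: list[list[float]]
) -> Weights:
    """Load server weights into the Ditto model, mirror its extractor into FENDA,
    and return a snapshot of the extractor weights for the drift penalty."""
    if len(arrays) != len(global_model):
        raise ValueError("Expected one array per global-model parameter")
    # Full exchange: overwrite every global-model parameter in state-dict order.
    for name, values in zip(list(global_model), arrays):
        global_model[name] = list(values)
    # Copy the Ditto extractor weights into the FENDA global feature extractor.
    extractor = {
        name[len(EXTRACTOR_PREFIX):]: list(values)
        for name, values in global_model.items()
        if name.startswith(EXTRACTOR_PREFIX)
    }
    fenda_global_extractor.update(extractor)
    # Keep an independent copy as the reference point for the penalty term.
    return copy.deepcopy(extractor)
```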

setup_client(config)

Set dataloaders, optimizers, parameter exchangers and other attributes derived from these, then set the initialized attribute to True. Beyond the parent class setup, this method additionally enforces compatibility between the global and local models.

Parameters:

config (Config) – The config from the server.

Return type:

None

update_before_train(current_server_round)

Procedures that should occur before proceeding with the training loops for the models. In this case, we save the global model’s parameters to be used in constraining the training of the local model.

Parameters:

current_server_round (int) – Indicates which server round we are currently executing.

Return type:

None
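Saving the parameters before training only works as a constraint if the saved values are a copy, not references to the live weights. PyTorch optimizers update those live parameters in place, and a tensor analog of the copy would be `.detach().clone()`. The sketch below shows why this matters, using hypothetical list-of-floats weights and illustrative helper names. An aliased "snapshot" moves along with the weights, so the penalty stays at zero.

```python
import copy

# Hypothetical stand-in for a module's weights: parameter name -> flat list of floats.
Weights = dict[str, list[float]]


def squared_drift(current: Weights, reference: Weights) -> float:
    """Sum of squared differences between matching parameters."""
    return sum(
        (c - r) ** 2 for name in current for c, r in zip(current[name], reference[name])
    )


def sgd_step(weights: Weights, grads: Weights, lr: float) -> None:
    """Update weights in place, as an optimizer step does on live parameters."""
    for name, values in weights.items():
        for i, g in enumerate(grads[name]):
            values[i] -= lr * g


params = {"w": [1.0, 1.0]}
aliased = params                  # Wrong: shares the lists that training will mutate.
snapshot = copy.deepcopy(params)  # Right: an independent copy taken before training.

sgd_step(params, {"w": [1.0, 0.0]}, lr=0.5)
print(squared_drift(params, aliased))   # 0.0: the penalty would never constrain anything
print(squared_drift(params, snapshot))  # 0.25
```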