fl4health.clients.gpfl_client module

class GpflClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, lam=0.01, mu=0.01)[source]

Bases: BasicClient

__init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, lam=0.01, mu=0.01)[source]

This client is used to perform client-side training associated with the GPFL method described in https://arxiv.org/abs/2308.10279.

In this approach, the client’s model is sequentially split into a feature extractor and a head module. The client also has two extra modules that are trained alongside the main model: a CoV (Conditional Valve) module and a GCE (Global Category Embedding) module. These sub-modules are trained on the client and shared with the server alongside the feature extractor. In simple terms, the CoV takes the output of the feature extractor (feature_tensor) and maps it into two feature tensors (a personal f_p and a general f_g) computed through affine mappings. f_p is fed into the head module for classification, while f_g is used to train the GCE module. The GCE is a lookup table that stores a global representative embedding for each class and is used to generate two conditional tensors, global_conditional_input and personalized_conditional_input, referred to in the paper as g and p_i, respectively. These conditional inputs are then fed into the CoV module. All components are trained simultaneously via a combined loss.
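
The data flow above can be summarized with a minimal sketch. The module attribute names (feature_extractor, cov, head) and the CoV call signature below are illustrative assumptions, not the exact GpflModel interface:

    import torch.nn as nn

    # Minimal sketch of the GPFL data flow described above. Attribute names and
    # the CoV call signature are illustrative assumptions.
    def gpfl_forward_sketch(model: nn.Module, x, global_conditional_input, personalized_conditional_input):
        feature_tensor = model.feature_extractor(x)
        # CoV maps the shared features into a personal and a general view via
        # affine transforms conditioned on p_i and g, respectively.
        f_p = model.cov(feature_tensor, personalized_conditional_input)
        f_g = model.cov(feature_tensor, global_conditional_input)
        preds = model.head(f_p)  # f_p drives classification through the head module
        return preds, f_g        # f_g is trained against the GCE lookup table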

Parameters:
  • data_path (Path) – Path to the data to be used for client-side training.

  • metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model

  • device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often “cpu” or “cuda”

  • loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.

  • checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.

  • reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.

  • progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.

  • client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.

  • lam (float, optional) – A hyperparameter that controls the weight of the GCE magnitude-level global loss. Defaults to 0.01.

  • mu (float, optional) – A hyperparameter that acts as the weight of the L2 regularization on the GCE and CoV modules. This value is used as the optimizers’ weight decay parameter. It can be set in the get_optimizer function defined by the client user; if it is not set there, it will be set in the set_optimizer method. Defaults to 0.01.

calculate_class_sample_proportions()[source]

Computes the proportion of training samples belonging to each class in the client’s training dataset.

Returns:

A tensor containing the proportion of samples for each class.

Return type:

torch.Tensor
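
As a rough illustration of the proportions being computed, consider the following hypothetical helper (it assumes a 1D tensor of integer class labels and is not the client’s internal implementation):

    import torch

    def class_sample_proportions(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
        # Count the occurrences of each class and normalize to proportions.
        counts = torch.bincount(labels, minlength=num_classes).float()
        return counts / counts.sum()

    # For example, labels [0, 0, 1, 2] with num_classes=3 yield tensor([0.50, 0.25, 0.25]).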

compute_conditional_inputs()[source]

Calculates the conditional inputs (p_i and g) for the CoV module based on the new GCE from the server. The self.global_conditional_input and self.personalized_conditional_input tensors are computed based on a frozen GCE model and the sample per class tensor. These tensors are fixed in each client round, and are recomputed when a new GCE module is shared by the server in every client round.

Return type:

None
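
The construction can be sketched roughly as follows, following the paper’s description of g as a uniform average and p_i as a client-weighted average of the frozen GCE class embeddings (tensor names and shapes are assumptions):

    import torch

    def conditional_inputs_sketch(class_embeddings: torch.Tensor, class_proportions: torch.Tensor):
        # class_embeddings: (num_classes, feature_dim) lookup table from the frozen GCE.
        # class_proportions: (num_classes,) sample-per-class proportions on this client.
        global_conditional_input = class_embeddings.mean(dim=0)                # g: uniform average over classes
        personalized_conditional_input = class_proportions @ class_embeddings  # p_i: weighted by the client's class mix
        return global_conditional_input.detach(), personalized_conditional_input.detach()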

compute_magnitude_level_loss(global_features, target)[source]

Computes the magnitude-level loss, which corresponds to \(\mathcal{L}_i^{\text{mlg}}\) in the paper.

Parameters:
  • global_features (torch.Tensor) – Global features computed on this client.

  • target (TorchTargetType) – Either a tensor of class indices or one-hot encoded tensors.

Returns:

L2 norm loss between the global features and the frozen GCE’s global features.

Return type:

torch.Tensor
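
A hedged sketch of the quantity, assuming the frozen GCE stores one embedding per class and the loss is the mean L2 distance between each sample’s global features and the embedding of its target class (names are illustrative):

    import torch

    def magnitude_level_loss_sketch(
        global_features: torch.Tensor,          # (batch, feature_dim): f_g from the CoV module
        target: torch.Tensor,                   # (batch,): class indices (argmax one-hot targets first if needed)
        frozen_class_embeddings: torch.Tensor,  # (num_classes, feature_dim): frozen GCE lookup table
    ) -> torch.Tensor:
        reference = frozen_class_embeddings[target]
        return torch.norm(global_features - reference, p=2, dim=1).mean()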

compute_training_loss(preds, features, target)[source]

Computes the combined training loss given the predictions, the model’s global features, and ground truth data. The GPFL loss is a combined loss, defined as prediction_loss + gce_softmax_loss + magnitude_level_loss.

Parameters:
  • preds (TorchPredType) – Prediction(s) of the model(s) indexed by name. Anything stored in preds will be used to compute metrics.

  • features (TorchFeatureType) – Feature(s) of the model(s) indexed by name.

  • target (TorchTargetType) – Ground truth data to evaluate predictions against.

Returns:

An instance of TrainingLosses containing backward loss and additional losses indexed by name.

Return type:

TrainingLosses
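
The combination can be sketched as below. The three terms stand in for the client’s actual computations; per the constructor documentation, lam weights the magnitude-level term (shown explicitly here, though it could equally be folded into that term):

    import torch

    def gpfl_training_loss_sketch(
        prediction_loss: torch.Tensor,
        gce_softmax_loss: torch.Tensor,
        magnitude_level_loss: torch.Tensor,
        lam: float = 0.01,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        # The backward loss is the sum used for backpropagation; the individual
        # terms are also returned so they can be reported as additional losses.
        backward_loss = prediction_loss + gce_softmax_loss + lam * magnitude_level_loss
        additional_losses = {
            "prediction_loss": prediction_loss,
            "gce_softmax_loss": gce_softmax_loss,
            "magnitude_level_loss": magnitude_level_loss,
        }
        return backward_loss, additional_losses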

get_optimizer(config)[source]

Returns a dictionary containing the model, GCE, and CoV optimizers under the string keys “model”, “gce”, and “cov”, respectively.

Parameters:

config (Config) – The config from the server.

Returns:

A dictionary of optimizers defined by the user.

Return type:

dict[str, Optimizer]
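
A hedged example of a user-defined override, assuming the model exposes base-module, GCE, and CoV parameter groups under the attribute names shown, and that the constructor’s mu is available as self.mu for use as weight decay (both are assumptions):

    import torch

    from fl4health.clients.gpfl_client import GpflClient

    class MyGpflClient(GpflClient):
        def get_optimizer(self, config):
            # Attribute names on self.model are illustrative assumptions.
            return {
                "model": torch.optim.SGD(self.model.base_module.parameters(), lr=0.01),
                "gce": torch.optim.SGD(self.model.gce.parameters(), lr=0.01, weight_decay=self.mu),
                "cov": torch.optim.SGD(self.model.cov.parameters(), lr=0.01, weight_decay=self.mu),
            }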

get_parameter_exchanger(config)[source]

The GPFL client uses a fixed layer exchanger to exchange the layers of its three shared sub-modules. The sub-modules to be exchanged are defined in the GpflModel class.

Parameters:

config (Config) – The config from the server.

Returns:

FixedLayerExchanger used to exchange a set of fixed and specific layers.

Return type:

ParameterExchanger
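
For reference, exchangers of this kind are typically built from the model’s own list of exchangeable layers; a hedged sketch (the import path and the layers_to_exchange() method are assumptions based on the usual fl4health fixed-layer-exchange pattern):

    from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger

    def build_gpfl_exchanger(model) -> FixedLayerExchanger:
        # Exchange only the layers the GPFL model marks as shareable
        # (the feature extractor, GCE, and CoV sub-modules).
        return FixedLayerExchanger(model.layers_to_exchange())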

set_optimizer(config)[source]

This method simply ensures that the optimizers set up by the user have the proper keys and that there are exactly three of them.

Parameters:

config (Config) – The config from the server.

Return type:

None

setup_client(config)[source]

In addition to the dataloaders, optimizers, and parameter exchangers, a few GPFL-specific parameters are set up in this method. These include the number of classes, the feature dimension, and the sample-per-class tensor. The global and personalized conditional inputs are also initialized.

Parameters:

config (Config) – The config from the server.

Return type:

None

train_step(input, target)[source]

Given a single batch of input and target data, generate predictions, compute the loss, update parameters, and optionally update metrics if they exist (i.e., backprop on a single batch of data). Assumes self.model is already in train mode.

Parameters:
  • input (TorchInputType) – The input to be fed into the model.

  • target (TorchTargetType) – The target corresponding to the input.

Returns:

The losses object from the train step along with a dictionary of any predictions produced by the model.

Return type:

tuple[TrainingLosses, TorchPredType]

transform_input(input)[source]

Extends the input dictionary with the global_conditional_input and personalized_conditional_input tensors. This lets us provide these additional tensors to the GPFL model.

Parameters:

input (TorchInputType) – Input tensor.

Returns:

Transformed input tensor.

Return type:

TorchInputType
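
A sketch of the transformation, assuming the conditional tensors are stored on the client and the input is (or is wrapped into) a dictionary of tensors (the “input” key is a hypothetical name):

    import torch

    def transform_input_sketch(input, global_conditional_input, personalized_conditional_input):
        # Wrap a bare tensor in a dictionary so the conditional tensors can be
        # passed to the GPFL model alongside the original input.
        if isinstance(input, torch.Tensor):
            input = {"input": input}
        input["global_conditional_input"] = global_conditional_input
        input["personalized_conditional_input"] = personalized_conditional_input
        return input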

update_before_train(current_server_round)[source]

Updates the frozen GCE model and computes the conditional inputs before training starts.

Parameters:

current_server_round (int) – The current server round number.

Return type:

None

val_step(input, target)[source]

Before performing validation, we need to transform the input and attach the global and personalized conditional tensors to the input.

Parameters:
  • input (TorchInputType) – The input to be fed into the model.

  • target (TorchTargetType) – The target corresponding to the input.

Returns:

The losses object from the validation step along with a dictionary of the predictions produced by the model.

Return type:

tuple[EvaluationLosses, TorchPredType]