fl4health.clients.gpfl_client module¶
- class GpflClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, lam=0.01, mu=0.01)[source]¶
Bases:
BasicClient
- __init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, lam=0.01, mu=0.01)[source]¶
This client is used to perform client-side training associated with the GPFL method described in https://arxiv.org/abs/2308.10279.
In this approach, the client’s model is sequentially split into a feature extractor and a head module. The client also has two extra modules that are trained alongside the main model: a CoV (Conditional Valve) module and a GCE (Global Category Embedding) module. These sub-modules are trained on the client and shared with the server alongside the feature extractor. In simple terms, CoV takes in the output of the feature extractor (feature_tensor) and maps it into two feature tensors (personal f_p and general f_g) computed through affine mappings. f_p is fed into the head module for classification, while f_g is used to train the GCE module. GCE is a lookup table that stores a global representative embedding for each class. The GCE module is used to generate two conditional tensors: global_conditional_input and personalized_conditional_input, referred to in the paper as g and p_i, respectively. These conditional inputs are then used in the CoV module. All the components are trained simultaneously via a combined loss.
- Parameters:
data_path (Path) – path to the data to be used to load the data for client-side training
metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model
device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often “cpu” or “cuda”
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components, will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.
reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.
progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.
client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.
lam (float, optional) – A hyperparameter that controls the weight of the GCE magnitude-level global loss. Defaults to 0.01.
mu (float, optional) – A hyperparameter that acts as the weight of the L2 regularization on the GCE and CoV modules. This value is used as the optimizers’ weight decay parameter. It can be set in the get_optimizer function defined by the user of the client; if it is not set by the user, it will be set in the set_optimizer method. Defaults to 0.01.
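As a hedged illustration of the data flow described above (not the library’s actual implementation), a toy CoV module that affine-maps the extractor’s features under each conditional input might look like the following; all layer choices and shapes here are assumptions, and the real CoV architecture lives in the library’s GpflModel class:

```python
import torch
import torch.nn as nn


class ToyCoV(nn.Module):
    """Illustrative stand-in for the CoV module: conditioned on an input
    tensor, it applies an affine transform (gamma * f + beta) to features."""

    def __init__(self, feature_dim: int) -> None:
        super().__init__()
        # Hypothetical layers; the real CoV architecture is defined in fl4health.
        self.gamma = nn.Linear(feature_dim, feature_dim)
        self.beta = nn.Linear(feature_dim, feature_dim)

    def forward(self, features: torch.Tensor, conditional: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.gamma(conditional)) * features + self.beta(conditional)


batch, feature_dim = 4, 8
cov = ToyCoV(feature_dim)
features = torch.randn(batch, feature_dim)   # feature extractor output
g = torch.randn(feature_dim)                 # global conditional input
p_i = torch.randn(feature_dim)               # personalized conditional input

f_g = cov(features, g.expand(batch, -1))     # general features, used for the GCE losses
f_p = cov(features, p_i.expand(batch, -1))   # personal features, fed to the head module
```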
- calculate_class_sample_proportions()[source]¶
This method is used to compute the class sample proportions based on the training data. It computes the proportion of samples for each class in the training dataset.
- Returns:
A tensor containing the proportion of samples for each class.
- Return type:
torch.Tensor
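A minimal sketch of this computation, assuming labels arrive as a tensor of class indices (the method’s actual data access goes through the client’s training dataloader):

```python
import torch


def class_sample_proportions(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Count occurrences of each class index and normalize to proportions.
    counts = torch.bincount(labels, minlength=num_classes).float()
    return counts / counts.sum()


labels = torch.tensor([0, 0, 1, 2, 2, 2])
proportions = class_sample_proportions(labels, num_classes=3)
```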
- compute_conditional_inputs()[source]¶
Calculates the conditional inputs (p_i and g) for the CoV module based on the new GCE from the server. The self.global_conditional_input and self.personalized_conditional_input tensors are computed based on a frozen GCE model and the sample-per-class tensor. These tensors are fixed within each client round and are recomputed when a new GCE module is shared by the server at the start of every client round.
- Return type:
None
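Following the paper’s formulation, a sketch with toy embeddings: g averages all frozen class embeddings uniformly, while p_i weights them by the client’s class sample proportions. The frozen-GCE lookup shape and the exact weighting are assumptions here:

```python
import torch


def conditional_inputs(frozen_embeddings: torch.Tensor, proportions: torch.Tensor):
    # frozen_embeddings: (num_classes, feature_dim) lookup table from the frozen GCE.
    g = frozen_embeddings.mean(dim=0)       # uniform average over classes
    p_i = proportions @ frozen_embeddings   # proportion-weighted average
    return g, p_i


frozen_embeddings = torch.eye(3)            # toy: one orthogonal embedding per class
proportions = torch.tensor([0.5, 0.25, 0.25])
g, p_i = conditional_inputs(frozen_embeddings, proportions)
```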
- compute_magnitude_level_loss(global_features, target)[source]¶
Computes the magnitude-level loss, corresponding to \(\mathcal{L}_i^{\text{mlg}}\) in the paper.
- Parameters:
global_features (torch.Tensor) – global features computed in this client.
target (TorchTargetType) – Either a tensor of class indices or one-hot encoded tensors.
- Returns:
L2 norm loss between the global features and the frozen GCE’s global features.
- Return type:
torch.Tensor
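One plausible sketch of this loss (the exact reduction used by the client is an assumption here): look up the frozen GCE embedding for each target class and take the mean L2 distance to the produced global features:

```python
import torch


def magnitude_level_loss(global_features: torch.Tensor,
                         frozen_embeddings: torch.Tensor,
                         target: torch.Tensor) -> torch.Tensor:
    # Select the frozen global embedding for each sample's class and
    # penalize the L2 distance to that sample's global features.
    target_embeddings = frozen_embeddings[target]
    return torch.norm(global_features - target_embeddings, p=2, dim=1).mean()


frozen = torch.zeros(3, 2)          # toy frozen GCE embeddings
feats = torch.ones(4, 2)            # toy global features for a batch of 4
target = torch.tensor([0, 1, 2, 0])
loss = magnitude_level_loss(feats, frozen, target)
```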
- compute_training_loss(preds, features, target)[source]¶
Computes the combined training loss given predictions, global features of the model, and ground truth data. The GPFL loss is a combined loss defined as prediction_loss + gce_softmax_loss + magnitude_level_loss.
- Parameters:
preds (TorchPredType) – Prediction(s) of the model(s) indexed by name. Anything stored in preds will be used to compute metrics.
features (TorchFeatureType) – Feature(s) of the model(s) indexed by name.
target (TorchTargetType) – Ground truth data to evaluate predictions against.
- Returns:
An instance of TrainingLosses containing the backward loss and additional losses indexed by name.
- Return type:
TrainingLosses
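The combination itself is simple addition; where the lam hyperparameter enters is an assumption based on the parameter description above:

```python
def combined_gpfl_loss(prediction_loss: float, gce_softmax_loss: float,
                       magnitude_level_loss: float, lam: float = 0.01) -> float:
    # Backward loss: prediction_loss + gce_softmax_loss + magnitude_level_loss,
    # with the magnitude-level term scaled by lam (an assumed placement).
    return prediction_loss + gce_softmax_loss + lam * magnitude_level_loss


total = combined_gpfl_loss(0.7, 0.2, 5.0)
```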
- get_optimizer(config)[source]¶
Returns a dictionary of the model, GCE, and CoV optimizers, keyed by the strings “model”, “gce”, and “cov”, respectively.
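A hypothetical user-side implementation returning the three required keys; the modules below are placeholders, and applying mu as weight decay on the GCE and CoV optimizers follows the parameter docs above:

```python
import torch

model = torch.nn.Linear(4, 2)     # placeholder for the main GPFL model
gce = torch.nn.Embedding(3, 4)    # placeholder GCE lookup table
cov = torch.nn.Linear(4, 4)       # placeholder CoV module

optimizers = {
    "model": torch.optim.SGD(model.parameters(), lr=0.01),
    "gce": torch.optim.SGD(gce.parameters(), lr=0.01, weight_decay=0.01),  # mu as weight decay
    "cov": torch.optim.SGD(cov.parameters(), lr=0.01, weight_decay=0.01),  # mu as weight decay
}
```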
- get_parameter_exchanger(config)[source]¶
The GPFL client uses a fixed layer exchanger to exchange layers in three sub-modules. The sub-modules to be exchanged are defined in the GpflModel class.
- Parameters:
config (Config) – Config from the server.
- Returns:
FixedLayerExchanger used to exchange a set of fixed and specific layers.
- Return type:
FixedLayerExchanger
- set_optimizer(config)[source]¶
This function simply ensures that the optimizers set up by the user have the proper keys and that there are three optimizers.
- Parameters:
config (Config) – The config from the server.
- Return type:
None
- setup_client(config)[source]¶
In addition to dataloaders, optimizers, and parameter exchangers, a few GPFL-specific parameters are set up in this method. These include the number of classes, the feature dimension, and the sample-per-class tensor. The global and personalized conditional inputs are also initialized.
- Parameters:
config (Config) – The config from the server.
- Return type:
None
- train_step(input, target)[source]¶
Given a single batch of input and target data, generate predictions, compute the loss, update parameters, and optionally update metrics if they exist (i.e., backprop on a single batch of data). Assumes self.model is already in train mode.
- Parameters:
input (TorchInputType) – The input to be fed into the model.
target (TorchTargetType) – The target corresponding to the input.
- Returns:
The losses object from the train step along with a dictionary of any predictions produced by the model.
- Return type:
tuple[TrainingLosses, TorchPredType]
- transform_input(input)[source]¶
Extends the input dictionary with the global_conditional_input and personalized_conditional_input tensors. This lets us provide these additional tensors to the GPFL model.
- Parameters:
input (TorchInputType) – Input tensor.
- Returns:
Transformed input tensor.
- Return type:
TorchInputType
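Assuming the input arrives as a dictionary of tensors (one of the forms TorchInputType can take), the extension amounts to adding two keys; the key names here mirror the attribute names and are illustrative:

```python
import torch


def transform_input(input_dict: dict, g: torch.Tensor, p_i: torch.Tensor) -> dict:
    # Attach the conditional tensors alongside the raw input without mutating it.
    return {**input_dict,
            "global_conditional_input": g,
            "personalized_conditional_input": p_i}


batch = {"input": torch.randn(2, 4)}
extended = transform_input(batch, torch.zeros(4), torch.ones(4))
```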
- update_before_train(current_server_round)[source]¶
Updates the frozen GCE model and computes the conditional inputs before training starts.
- val_step(input, target)[source]¶
Before performing validation, we need to transform the input and attach the global and personalized conditional tensors to the input.
- Parameters:
input (TorchInputType) – Input based on the training data.
target (TorchTargetType) – The target corresponding to the input.
- Returns:
The losses object from the val step along with a dictionary of the predictions produced by the model.
- Return type:
tuple[EvaluationLosses, TorchPredType]