fl4health.clients.perfcl_client module
- class PerFclClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, global_feature_loss_temperature=0.5, local_feature_loss_temperature=0.5, global_feature_contrastive_loss_weight=1.0, local_feature_contrastive_loss_weight=1.0)
Bases: BasicClient
- __init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, global_feature_loss_temperature=0.5, local_feature_loss_temperature=0.5, global_feature_contrastive_loss_weight=1.0, local_feature_contrastive_loss_weight=1.0)
This client performs the client-side training associated with the PerFCL method derived in https://www.sciencedirect.com/science/article/pii/S0031320323002078. The approach attempts to manipulate the training dynamics of a parallel weight split model with a global feature extractor that is aggregated on the server side with FedAvg and a local feature extractor that is only trained locally. This method is related to FENDA, but adds losses on the latent spaces of the local and global feature extractors.
- Parameters:
data_path (Path) – Path to the data used for client-side training.
metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model
device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often ‘cpu’ or ‘cuda’
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.
reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.
progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.
client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.
global_feature_loss_temperature (float, optional) – Temperature to be used in the contrastive loss associated with constraining the global feature extractor in the PerFCL loss. Defaults to 0.5.
local_feature_loss_temperature (float, optional) – Temperature to be used in the contrastive loss associated with constraining the local feature extractor in the PerFCL loss. Defaults to 0.5.
global_feature_contrastive_loss_weight (float, optional) – Weight on the contrastive loss value associated with the global feature extractor. Referred to as mu in the original paper. Defaults to 1.0.
local_feature_contrastive_loss_weight (float, optional) – Weight on the contrastive loss value associated with the local feature extractor. Referred to as gamma in the original paper. Defaults to 1.0.
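To illustrate how a temperature parameter shapes a contrastive loss, here is a minimal sketch in plain Python. It assumes a MOON-style loss built on cosine similarity with one positive and one negative vector; that functional form, the helper names (cosine_similarity, contrastive_loss), and the example vectors are all assumptions for illustration and are not taken from the FL4Health implementation.

```python
import math


def cosine_similarity(u: list[float], v: list[float]) -> float:
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def contrastive_loss(
    features: list[float],
    positive: list[float],
    negative: list[float],
    temperature: float = 0.5,
) -> float:
    """Pull `features` toward `positive` and push it away from `negative`.

    Assumed form: -log(exp(s_pos / T) / (exp(s_pos / T) + exp(s_neg / T))).
    """
    pos = math.exp(cosine_similarity(features, positive) / temperature)
    neg = math.exp(cosine_similarity(features, negative) / temperature)
    return -math.log(pos / (pos + neg))


# Hypothetical latent vectors for a single sample.
z = [1.0, 0.0]
z_positive = [0.9, 0.1]  # nearly aligned with z
z_negative = [0.0, 1.0]  # orthogonal to z

# A lower temperature sharpens the softmax over similarities.
sharp = contrastive_loss(z, z_positive, z_negative, temperature=0.1)
smooth = contrastive_loss(z, z_positive, z_negative, temperature=0.5)
```

In this sketch the features are already closer to the positive than to the negative, so lowering the temperature shrinks the loss. A weight such as global_feature_contrastive_loss_weight would then scale the resulting value before it is added to the criterion loss.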
- compute_evaluation_loss(preds, features, target)
Computes evaluation loss given predictions of the model and ground truth data. Also computes additional loss components associated with the PerFCL loss function.
- Parameters:
preds (dict[str, torch.Tensor]) – Prediction(s) of the model(s) indexed by name. All predictions included in the dictionary will be used to compute metrics.
features (dict[str, torch.Tensor]) – Feature(s) of the model(s) indexed by name.
target (torch.Tensor | dict[str, torch.Tensor]) – Ground truth data to evaluate predictions against.
- Returns:
An instance of EvaluationLosses containing the checkpoint loss and additional losses indexed by name.
- Return type:
EvaluationLosses
- compute_loss_and_additional_losses(preds, features, target)
Computes the loss and any additional losses given predictions of the model and ground truth data. For PerFCL, the total loss is the user-provided criterion loss plus the PerFCL contrastive losses, which aim to manipulate the latent spaces of the local and global feature extractors.
- Parameters:
- Returns:
A tuple with:
- The tensor for the total loss
- A dictionary with loss, total_loss, global_feature_contrastive_loss, and local_feature_contrastive_loss representing the relevant pieces of the loss calculation
- Return type:
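The return value described above can be sketched with plain floats standing in for tensors. The function name combine_losses and the short parameter names mu and gamma are hypothetical. Whether the library stores weighted or unweighted contrastive terms in the dictionary is also an assumption here; only the dictionary keys come from the description above.

```python
def combine_losses(
    criterion_loss: float,
    global_contrastive: float,
    local_contrastive: float,
    mu: float = 1.0,
    gamma: float = 1.0,
) -> tuple[float, dict[str, float]]:
    """Return (total_loss, additional_losses) in the shape described above."""
    # Assumption: each contrastive term is scaled by its weight before summing.
    weighted_global = mu * global_contrastive
    weighted_local = gamma * local_contrastive
    total = criterion_loss + weighted_global + weighted_local
    additional_losses = {
        "loss": criterion_loss,
        "global_feature_contrastive_loss": weighted_global,
        "local_feature_contrastive_loss": weighted_local,
        "total_loss": total,
    }
    return total, additional_losses


# Hypothetical loss values for one batch.
total, parts = combine_losses(0.7, 0.2, 0.1, mu=2.0, gamma=0.5)
```

Keeping the individual terms in the dictionary makes it possible to monitor how much each contrastive loss contributes relative to the criterion loss.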
- get_parameter_exchanger(config)
Sets the parameter exchanger to be used by clients to send parameters to and receive them from the server. For PerFCL clients, a FixedLayerExchanger is used by default. We also require that the model being exchanged is of the PerFclModel type to ensure that the appropriate layers are exchanged.
- Parameters:
config (Config) – Configuration provided by the server.
- Returns:
FixedLayerExchanger meant to exchange only a subset of model layers with the server for aggregation.
- Return type:
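The idea of exchanging only a fixed subset of layers can be sketched with dictionaries standing in for a model's state. This is not the FixedLayerExchanger API; the helper names (select_layers, load_server_layers) and the layer names are hypothetical, chosen only to show that the global feature extractor is shared while the remaining layers stay on the client.

```python
Weights = dict[str, list[float]]


def select_layers(state: Weights, prefix: str) -> Weights:
    """Keep only the layers whose names start with `prefix`."""
    return {name: values for name, values in state.items() if name.startswith(prefix)}


def load_server_layers(state: Weights, server_layers: Weights) -> Weights:
    """Overwrite the exchanged layers and leave all other layers untouched."""
    updated = dict(state)
    updated.update(server_layers)
    return updated


# Hypothetical layer names for a parallel split model.
client_state = {
    "global_extractor.weight": [0.1, 0.2],
    "local_extractor.weight": [0.3, 0.4],
    "head.weight": [0.5],
}

# Only the global feature extractor is sent to the server.
to_server = select_layers(client_state, prefix="global_extractor.")

# Hypothetical aggregated weights returned by the server.
aggregated = {"global_extractor.weight": [0.15, 0.25]}
client_state_next = load_server_layers(client_state, aggregated)
```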
- predict(input)
Computes the prediction(s) and features of the model(s) given the input.
- Parameters:
input (TorchInputType) – Inputs to be fed into the model. TorchInputType is simply an alias for the union of torch.Tensor and dict[str, torch.Tensor].
- Returns:
A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations indexed by name. Specifically, the features of the model, the features of the global model, and the features of the old model are returned. All predictions included in the dictionary will be used to compute metrics.
- Return type:
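The shape of this return value, a pair of dictionaries, can be sketched with lists standing in for tensors. Everything named here is hypothetical: the predict function below is a stand-in rather than the client method, the dictionary keys are illustrative, and the "extractors" are simple scaling functions. Combining local and global features by concatenation is likewise an assumption for the example.

```python
from typing import Callable

Vector = list[float]
Extractor = Callable[[Vector], Vector]


def scale(factor: float) -> Extractor:
    """Build a toy feature extractor that multiplies each input by `factor`."""
    return lambda x: [factor * v for v in x]


def predict(
    x: Vector,
    local_extractor: Extractor,
    global_extractor: Extractor,
    frozen_extractors: dict[str, Extractor],
    head: Callable[[Vector], float],
) -> tuple[dict[str, float], dict[str, Vector]]:
    """Return (predictions, features), each indexed by name."""
    local_features = local_extractor(x)
    global_features = global_extractor(x)
    # Assumption: the head consumes the concatenated local and global features.
    preds = {"prediction": head(local_features + global_features)}
    features = {"local_features": local_features, "global_features": global_features}
    # Frozen copies of earlier modules only contribute features, not predictions.
    for name, extractor in frozen_extractors.items():
        features[name] = extractor(x)
    return preds, features


preds, features = predict(
    [1.0, 2.0],
    local_extractor=scale(2.0),
    global_extractor=scale(3.0),
    frozen_extractors={
        "old_global_features": scale(1.0),
        "initial_global_features": scale(0.5),
    },
    head=sum,
)
```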
- update_after_train(local_steps, loss_dict, config)
This function is called after client-side training concludes. In this case, it is used to save the local and global feature extraction weights/modules to be used in the next round of client-side training.
- update_before_train(current_server_round)
This function is called prior to the start of client-side training, but after the server parameters have been received and injected into the model. In this case, it is used to save the aggregated global feature extractor weights/module representing the initial state of this module BEFORE this iteration of client-side training but AFTER server-side aggregation.
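The ordering of the two hooks above can be sketched as a small snapshot tracker. The class SnapshotTracker and its attribute names are hypothetical, and dictionaries stand in for modules; the sketch only shows which state is copied at which point in a round, not how the client implements it.

```python
import copy


class SnapshotTracker:
    """Toy stand-in showing when each frozen copy is taken during a round."""

    def __init__(self, local_module: dict, global_module: dict) -> None:
        self.local_module = local_module
        self.global_module = global_module
        self.initial_global_module = None
        self.old_local_module = None
        self.old_global_module = None

    def update_before_train(self) -> None:
        # Runs after aggregated server weights are loaded into global_module.
        self.initial_global_module = copy.deepcopy(self.global_module)

    def update_after_train(self) -> None:
        # Runs after local training; the copies are used in the next round.
        self.old_local_module = copy.deepcopy(self.local_module)
        self.old_global_module = copy.deepcopy(self.global_module)


tracker = SnapshotTracker(local_module={"w": [0.0]}, global_module={"w": [1.0]})

tracker.update_before_train()        # snapshot of the aggregated global weights
tracker.global_module["w"][0] = 2.0  # pretend local training changed the weights
tracker.local_module["w"][0] = 5.0
tracker.update_after_train()         # snapshots of the trained modules

tracker.global_module["w"][0] = 9.0  # later changes do not affect the snapshots
```

Deep copies are used so that the saved modules remain frozen while the live modules continue to train.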