fl4health.clients.constrained_fenda_client module
- class ConstrainedFendaClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, loss_container=None)
Bases: FendaClient
- __init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, loss_container=None)
This class extends the functionality of FENDA training to include various kinds of constraints applied during the client-side training of FENDA models.
- Parameters:
data_path (Path) – Path to the data used for client-side training.
metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model.
device (torch.device) – Device indicator for where to send the model, batches, labels, etc. Often “cpu” or “cuda”.
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components, will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.
reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.
progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.
client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.
loss_container (ConstrainedFendaLossContainer | None, optional) – Configuration that determines which losses will be applied during FENDA training. Defaults to None.
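The `loss_meter_type` parameter controls how per-batch losses are reduced into a single value per loss name. As a hedged illustration only, the plain-Python sketch below contrasts an averaging meter (the documented default behaviour) with a summing one. The names `MeterKind` and `LossMeterSketch`, and the existence of a summing alternative, are assumptions for illustration; they are not FL4Health's `LossMeterType` or its loss meter classes.

```python
from enum import Enum


class MeterKind(Enum):
    # Hypothetical stand-ins for meter types; AVERAGE mirrors the documented default.
    AVERAGE = "average"
    ACCUMULATION = "accumulation"


class LossMeterSketch:
    """Hypothetical meter that tracks named losses across batches."""

    def __init__(self, kind: MeterKind = MeterKind.AVERAGE) -> None:
        self.kind = kind
        self.history: list[dict[str, float]] = []

    def update(self, losses: dict[str, float]) -> None:
        # Store a copy of this batch's losses, indexed by name.
        self.history.append(dict(losses))

    def compute(self) -> dict[str, float]:
        if not self.history:
            return {}
        totals: dict[str, float] = {}
        for batch in self.history:
            for name, value in batch.items():
                totals[name] = totals.get(name, 0.0) + value
        if self.kind is MeterKind.AVERAGE:
            # Assumes every batch reports every loss name, so dividing by the batch count is valid.
            return {name: total / len(self.history) for name, total in totals.items()}
        return totals

    def clear(self) -> None:
        # Reset between training or validation passes.
        self.history = []
```

For example, feeding batch losses of 1.0 and 3.0 yields 2.0 with the averaging meter and 4.0 with the summing one.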
- compute_evaluation_loss(preds, features, target)
Computes evaluation loss given predictions of the model and ground truth data. Optionally computes additional loss components such as cosine_similarity_loss, contrastive_loss and perfcl_loss based on client attributes set from server config.
- Parameters:
preds (dict[str, torch.Tensor]) – Prediction(s) of the model(s) indexed by name. All predictions included in the dictionary will be used to compute metrics.
features (dict[str, Tensor]) – Feature(s) of the model(s) indexed by name.
target (Tensor | dict[str, Tensor]) – Ground truth data to evaluate predictions against.
- Returns:
An instance of EvaluationLosses containing checkpoint loss and additional losses indexed by name. Additional losses may include cosine_similarity_loss, contrastive_loss and perfcl_loss.
- Return type:
EvaluationLosses
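The cosine similarity and contrastive components mentioned above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the FL4Health implementation: it assumes the cosine similarity loss is the mean absolute cosine similarity between paired local and global feature vectors (so minimizing it pushes the two extractors toward orthogonal representations), and that the contrastive loss follows a MOON-style formulation with a temperature `tau`. The functions `cosine_similarity_loss` and `moon_contrastive_loss` are hypothetical.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def cosine_similarity(a: Vector, b: Vector, eps: float = 1e-8) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def cosine_similarity_loss(local: Sequence[Vector], global_: Sequence[Vector]) -> float:
    """Hypothetical: mean absolute cosine similarity over a batch of (local, global) feature pairs."""
    sims = [abs(cosine_similarity(l, g)) for l, g in zip(local, global_)]
    return sum(sims) / len(sims)


def moon_contrastive_loss(
    anchor: Sequence[Vector],
    positive: Sequence[Vector],
    negative: Sequence[Vector],
    tau: float = 0.5,
) -> float:
    """Hypothetical MOON-style loss: pull each anchor toward its positive and away from its negative."""
    losses = []
    for z, z_pos, z_neg in zip(anchor, positive, negative):
        pos = math.exp(cosine_similarity(z, z_pos) / tau)
        neg = math.exp(cosine_similarity(z, z_neg) / tau)
        # Negative log-probability of picking the positive pair among {positive, negative}.
        losses.append(-math.log(pos / (pos + neg)))
    return sum(losses) / len(losses)
```

With orthogonal features, `cosine_similarity_loss([[1.0, 0.0]], [[0.0, 1.0]])` is 0.0, while identical or opposite features give 1.0, since the absolute value penalizes alignment in either direction.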
- compute_loss_and_additional_losses(preds, features, target)
Computes the loss and any additional losses given predictions of the model and ground truth data. For FENDA, the loss is the total loss, and the additional losses are the loss, the total loss and, based on client attributes set from server config, the cosine similarity loss, contrastive loss and PerFCL losses.
- Parameters:
- Returns:
A tuple with:
The tensor for the total loss
A dictionary with loss, total_loss and, based on client attributes set from server config, also cos_sim_loss, contrastive_loss, contrastive_loss_minimize and contrastive_loss_maximize keys and their respective calculated values.
- Return type:
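How the total loss and the additional-loss dictionary fit together can be sketched as follows. This is a plain-Python sketch under assumptions: each optional component is added to the base loss scaled by a weight, and a component is skipped when its weight is `None`, standing in for the server-config checks. The function `assemble_losses`, its weight parameters, and the key `contrastive_loss_maximize` for the second PerFCL term are hypothetical, not FL4Health's API.

```python
from typing import Optional


def assemble_losses(
    loss: float,
    cos_sim_loss: float = 0.0,
    contrastive_loss: float = 0.0,
    perfcl_min: float = 0.0,
    perfcl_max: float = 0.0,
    cos_sim_weight: Optional[float] = None,
    contrastive_weight: Optional[float] = None,
    perfcl_min_weight: Optional[float] = None,
    perfcl_max_weight: Optional[float] = None,
) -> tuple[float, dict[str, float]]:
    """Hypothetical sketch: combine the base loss with optional constraint losses.

    A weight of None means the corresponding loss is disabled.
    """
    total = loss
    additional: dict[str, float] = {"loss": loss}

    if cos_sim_weight is not None:
        total += cos_sim_weight * cos_sim_loss
        additional["cos_sim_loss"] = cos_sim_loss

    if contrastive_weight is not None:
        total += contrastive_weight * contrastive_loss
        additional["contrastive_loss"] = contrastive_loss

    # PerFCL contributes two terms with separate weights; here both must be enabled together.
    if perfcl_min_weight is not None and perfcl_max_weight is not None:
        total += perfcl_min_weight * perfcl_min + perfcl_max_weight * perfcl_max
        additional["contrastive_loss_minimize"] = perfcl_min
        additional["contrastive_loss_maximize"] = perfcl_max

    additional["total_loss"] = total
    return total, additional
```

Returning the total separately from the dictionary matches the documented tuple shape: the first element is what gets backpropagated, while the dictionary is only for logging.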
- get_parameter_exchanger(config)
Returns a full parameter exchanger. Subclasses that require a custom parameter exchanger can override this method.
- Parameters:
config (Config) – The config from server.
- Returns:
Used to exchange parameters between server and client.
- Return type:
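A full parameter exchanger, as described above, sends every model parameter to the server and overwrites every local parameter with the values it receives back. The sketch below models parameters as a name-to-list mapping in plain Python so that it stays self-contained; the class `FullExchangeSketch` and its methods are hypothetical and only illustrate the idea, not FL4Health's `ParameterExchanger` interface.

```python
class FullExchangeSketch:
    """Hypothetical full exchanger: every parameter is sent and every parameter is replaced."""

    def push_parameters(self, model_params: dict[str, list[float]]) -> list[list[float]]:
        # Serialize all parameters in a fixed (sorted-by-name) order for transmission.
        return [list(model_params[name]) for name in sorted(model_params)]

    def pull_parameters(self, parameters: list[list[float]], model_params: dict[str, list[float]]) -> None:
        # Overwrite every local parameter, using the same order as push_parameters.
        names = sorted(model_params)
        if len(names) != len(parameters):
            raise ValueError("Parameter count mismatch between server and client")
        for name, values in zip(names, parameters):
            model_params[name] = list(values)
```

A fixed ordering on both sides is what lets the server return a plain list without names; a subclass that only exchanges some layers would filter `names` before packing.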
- predict(input)
Computes the prediction(s) and features of the model(s) given the input.
- Parameters:
input (TorchInputType) – Inputs to be fed into the model. TorchInputType is simply an alias for the union of torch.Tensor and dict[str, torch.Tensor].
- Returns:
A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations indexed by name. Specifically, the features of the model, the features of the global model and the features of the old model are returned. All predictions included in the dictionary will be used to compute metrics.
- Return type:
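The shape of the `predict` return value, and the two-way input union it accepts, can be illustrated with a toy model in plain Python. Everything here is hypothetical: inputs are lists of floats instead of tensors, the two "feature extractors" are simple functions, dictionary inputs are assumed to be concatenated in key order, and the key names (`prediction`, `local_features`, `global_features`, `old_local_features`, `old_global_features`) are placeholders rather than the exact keys FL4Health uses.

```python
from typing import Callable, Optional, Union

Vector = list[float]
InputType = Union[Vector, dict[str, Vector]]  # stands in for TorchInputType
Extractor = Callable[[Vector], Vector]


def local_extractor(x: Vector) -> Vector:
    # Toy "local" feature extractor: scale the input.
    return [2.0 * v for v in x]


def global_extractor(x: Vector) -> Vector:
    # Toy "global" feature extractor: shift the input.
    return [v + 1.0 for v in x]


def predict_sketch(
    inputs: InputType,
    old_local: Optional[Extractor] = None,
    old_global: Optional[Extractor] = None,
) -> tuple[dict[str, Vector], dict[str, Vector]]:
    """Hypothetical: return (predictions indexed by name, features indexed by name)."""
    if isinstance(inputs, dict):
        # Assumption: named inputs are concatenated in sorted key order.
        x = [v for key in sorted(inputs) for v in inputs[key]]
    else:
        x = inputs
    local = local_extractor(x)
    global_ = global_extractor(x)
    # FENDA-style head: consume the concatenation of local and global features.
    preds = {"prediction": [sum(local + global_)]}
    features = {"local_features": local, "global_features": global_}
    # Features from the previous round's modules appear only when those modules were saved.
    if old_local is not None:
        features["old_local_features"] = old_local(x)
    if old_global is not None:
        features["old_global_features"] = old_global(x)
    return preds, features
```

Keeping predictions and features in separate dictionaries means every entry in the first one can be fed to metrics, while the second one only feeds the auxiliary losses.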
- update_after_train(local_steps, loss_dict, config)
This function is called after client-side training concludes. If a contrastive or PerFCL loss function has been defined, it is used to save the local and global feature extraction weights/modules to be used in the next round of client-side training.
- update_before_train(current_server_round)
This function is called prior to the start of client-side training, but after the server parameters have been received and injected into the model. If a PerFCL loss function has been defined, it is used to save the aggregated global feature extractor weights/module representing the initial state of this module BEFORE this iteration of client-side training but AFTER server-side aggregation.
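The `update_before_train` and `update_after_train` hooks bracket each round of client-side training. The sketch below traces their snapshot logic in plain Python: the before-train hook stores a copy of the freshly aggregated global extractor (used by PerFCL), and the after-train hook stores copies of the trained local and global extractors for the next round. The class `SnapshotSketch`, its attributes, and the flags `use_contrastive` and `use_perfcl` are hypothetical stand-ins for the loss-container checks; weights are modeled as plain dictionaries, and the `local_steps`, `loss_dict`, and `config` arguments are omitted.

```python
import copy
from typing import Optional

Weights = dict[str, list[float]]


class SnapshotSketch:
    """Hypothetical client state that mimics the before/after-train snapshots."""

    def __init__(self, use_contrastive: bool, use_perfcl: bool) -> None:
        self.use_contrastive = use_contrastive
        self.use_perfcl = use_perfcl
        self.local_extractor: Weights = {"w": [0.0]}
        self.global_extractor: Weights = {"w": [0.0]}
        self.aggregated_global: Optional[Weights] = None
        self.old_local: Optional[Weights] = None
        self.old_global: Optional[Weights] = None

    def update_before_train(self, current_server_round: int) -> None:
        # Called after server weights are injected: freeze a copy of the aggregated global extractor.
        # current_server_round is kept for signature parity and is unused here.
        if self.use_perfcl:
            self.aggregated_global = copy.deepcopy(self.global_extractor)

    def update_after_train(self) -> None:
        # Called after local training: keep copies of the trained extractors for the next round.
        if self.use_contrastive or self.use_perfcl:
            self.old_local = copy.deepcopy(self.local_extractor)
            self.old_global = copy.deepcopy(self.global_extractor)
```

Deep copies matter here: without them, the snapshots would alias the live weights and silently change as training updates the modules.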