fl4health.clients.moon_client module
- class MoonClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, temperature=0.5, contrastive_weight=1.0, len_old_models_buffer=1)
Bases: BasicClient
- __init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, temperature=0.5, contrastive_weight=1.0, len_old_models_buffer=1)
This client implements the MOON algorithm from Model-Contrastive Federated Learning. The key idea of MOON is to enforce similarity between representations from the global and current local model through a contrastive loss to constrain the local training of individual parties in the non-IID setting.
- Parameters:
data_path (Path) – Path to the data used for client-side training.
metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model.
device (torch.device) – Device indicator for where to send the model, batches, labels, etc. Often “cpu” or “cuda”.
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.
reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.
progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.
client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.
temperature (float, optional) – Temperature used in the calculation of the contrastive loss. Defaults to 0.5.
contrastive_weight (float, optional) – Weight placed on the contrastive loss function. Referred to as mu in the original paper. Defaults to 1.0.
len_old_models_buffer (int, optional) – Number of old models to be stored for computation in the contrastive learning loss function. Defaults to 1.
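To make the roles of temperature and the old-model buffer concrete, here is a minimal sketch of a model-contrastive loss for a single sample, following the formulation in the MOON paper. It is written in plain Python on lists of floats and is not the library's implementation; the function names `cosine_similarity` and `moon_contrastive_loss` are hypothetical. The global model's features act as the positive pair, and the features from each buffered old local model act as negatives.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def moon_contrastive_loss(local_feat, global_feat, old_feats, temperature=0.5):
    """Illustrative contrastive loss for one sample (hypothetical helper).

    local_feat:  features from the model currently being trained
    global_feat: features from the aggregated server model (positive)
    old_feats:   list of features from previous local models (negatives)
    """
    pos = math.exp(cosine_similarity(local_feat, global_feat) / temperature)
    negs = sum(
        math.exp(cosine_similarity(local_feat, old) / temperature)
        for old in old_feats
    )
    # Small when local features align with the global model,
    # large when they align with the old local models instead.
    return -math.log(pos / (pos + negs))
```

A lower temperature sharpens the contrast between the positive and negative pairs, and a longer buffer simply adds more negative terms to the denominator.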
- compute_evaluation_loss(preds, features, target)
Computes evaluation loss given predictions and features of the model and ground truth data. Loss includes base loss plus a model contrastive loss.
- Parameters:
preds (dict[str, torch.Tensor]) – Prediction(s) of the model(s) indexed by name. All predictions included in dictionary will be used to compute metrics.
features (dict[str, torch.Tensor]) – Feature(s) of the model(s) indexed by name.
target (torch.Tensor | dict[str, torch.Tensor]) – Ground truth data to evaluate predictions against.
- Returns:
An instance of EvaluationLosses containing checkpoint loss and additional losses indexed by name.
- Return type:
EvaluationLosses
- compute_loss_and_additional_losses(preds, features, target)
Computes the loss and any additional losses given predictions of the model and ground truth data. For MOON, the loss is the total loss (criterion and weighted contrastive loss) and the additional losses are the loss, (unweighted) contrastive loss, and total loss.
- Parameters:
- Returns:
A tuple with:
The tensor for the total loss
A dictionary with loss, contrastive_loss and total_loss keys and their calculated values.
- Return type:
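The way the returned tuple is assembled can be sketched as follows. This is an illustration using plain floats rather than tensors, and `combine_moon_losses` is a hypothetical name, not a function of the library. It assumes the weighting described above: the contrastive term is scaled by contrastive_weight in the total, while the dictionary reports it unweighted.

```python
def combine_moon_losses(criterion_loss, contrastive_loss, contrastive_weight=1.0):
    """Illustrative combination of the base and contrastive losses (hypothetical helper)."""
    # Total loss used for the backward pass: criterion plus weighted contrastive term
    total_loss = criterion_loss + contrastive_weight * contrastive_loss
    additional_losses = {
        "loss": criterion_loss,                # base criterion loss
        "contrastive_loss": contrastive_loss,  # reported without the weight applied
        "total_loss": total_loss,
    }
    return total_loss, additional_losses
```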
- compute_training_loss(preds, features, target)
Computes training loss given predictions and features of the model and ground truth data. Loss includes base loss plus a model contrastive loss.
- Parameters:
preds (dict[str, torch.Tensor]) – Prediction(s) of the model(s) indexed by name. All predictions included in dictionary will be used to compute metrics.
features (dict[str, torch.Tensor]) – Feature(s) of the model(s) indexed by name.
target (torch.Tensor | dict[str, torch.Tensor]) – Ground truth data to evaluate predictions against.
- Returns:
An instance of TrainingLosses containing backward loss and additional losses indexed by name.
- Return type:
TrainingLosses
- predict(input)
Computes the prediction(s) and features of the model(s) given the input. This function also produces the necessary features from the global_model (aggregated model from server) and old_models (previous client-side optimized models) in order to be able to compute the appropriate contrastive loss.
- Parameters:
input (TorchInputType) – Inputs to be fed into the model. TorchInputType is simply an alias for the union of torch.Tensor and dict[str, torch.Tensor]. Here, the MOON models require input to simply be of type torch.Tensor.
- Returns:
A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations indexed by name. Specifically, the features of the model, the features of the global model, and the features of the old model(s) are returned. All predictions included in the dictionary will be used to compute metrics.
- Return type:
- update_after_train(local_steps, loss_dict, config)
This function is called immediately after client-side training has completed. It saves the final trained model to the list of old models to be used in subsequent server rounds.
- update_before_train(current_server_round)
This function is called before training, immediately after injecting the aggregated server weights into the client model. We clone and freeze the current model to preserve the aggregated server weights state (i.e. the initial model before training starts).
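The two hooks together maintain a snapshot of the global model and a bounded history of local models. The following sketch models only that bookkeeping, using plain dictionaries as stand-ins for model weights. `OldModelBuffer` is a hypothetical class for illustration; in the actual client the stored objects are PyTorch modules, and freezing also means disabling gradient updates, which this sketch does not represent.

```python
import copy
from collections import deque


class OldModelBuffer:
    """Hypothetical helper mirroring the snapshot and buffer bookkeeping."""

    def __init__(self, len_old_models_buffer=1):
        self.global_snapshot = None
        # A deque with maxlen drops the oldest entry once the buffer is full
        self.old_models = deque(maxlen=len_old_models_buffer)

    def update_before_train(self, model_state):
        # Deep copy so that subsequent local training cannot alter the snapshot
        self.global_snapshot = copy.deepcopy(model_state)

    def update_after_train(self, model_state):
        # Store the locally trained weights for use as negatives in later rounds
        self.old_models.append(copy.deepcopy(model_state))
```

With len_old_models_buffer=1, only the most recently trained local model is kept, which matches the default described in the parameter list.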