fl4health.clients.moon_client module
- class MoonClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, temperature=0.5, contrastive_weight=1.0, len_old_models_buffer=1)
Bases: BasicClient
- __init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, temperature=0.5, contrastive_weight=1.0, len_old_models_buffer=1)
This client implements the MOON algorithm from Model-Contrastive Federated Learning. The key idea of MOON is to enforce similarity between representations from the global and current local model through a contrastive loss to constrain the local training of individual parties in the non-IID setting.
- Parameters:
data_path (Path) – Path to the data used for client-side training.
metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model.
device (torch.device) – Device indicator for where to send the model, batches, labels, etc. Often ‘cpu’ or ‘cuda’.
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.
reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.
progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.
client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.
temperature (float, optional) – Temperature used in the calculation of the contrastive loss. Defaults to 0.5.
contrastive_weight (float, optional) – Weight placed on the contrastive loss function. Referred to as mu in the original paper. Defaults to 1.0.
len_old_models_buffer (int, optional) – Number of old models to be stored for computation in the contrastive learning loss function. Defaults to 1.
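The following rough, pure-Python sketch shows how temperature and the buffer of old models enter the objective. It computes the model-contrastive term for a single sample, following the formulation in the MOON paper: the global model's features form the positive pair and the stored old models' features form the negative pairs. Three details are assumptions made for illustration: every model in the buffer counts as a separate negative, the function names are hypothetical, and plain lists stand in for tensors. This is not the FL4Health implementation.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def moon_contrastive_loss(
    local_features: Sequence[float],
    global_features: Sequence[float],
    old_features: Sequence[Sequence[float]],
    temperature: float = 0.5,
) -> float:
    """Hypothetical single-sample model-contrastive loss.

    Positive pair: (local, global). Negative pairs: (local, each old model).
    Assumes every model in the old-model buffer contributes one negative.
    """
    positive = math.exp(cosine_similarity(local_features, global_features) / temperature)
    negatives = sum(
        math.exp(cosine_similarity(local_features, old) / temperature)
        for old in old_features
    )
    return -math.log(positive / (positive + negatives))


# Local features aligned with the global model and orthogonal to the single old model.
print(moon_contrastive_loss([1.0, 0.0], [1.0, 0.0], [[0.0, 1.0]], temperature=0.5))
```

In this example the loss equals log(1 + e^-2), about 0.127. It would grow toward log(2) if the local features were equally similar to the global and old features. Under the description above, this unweighted term is multiplied by contrastive_weight before it is added to the criterion loss.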
- compute_evaluation_loss(preds, features, target)
Computes evaluation loss given predictions and features of the model and ground truth data. Loss includes base loss plus a model contrastive loss.
- Parameters:
preds (dict[str, torch.Tensor]) – Prediction(s) of the model(s) indexed by name. All predictions included in the dictionary will be used to compute metrics.
features (dict[str, torch.Tensor]) – Feature(s) of the model(s) indexed by name.
target (torch.Tensor | dict[str, torch.Tensor]) – Ground truth data to evaluate predictions against.
- Returns:
An instance of EvaluationLosses containing the checkpoint loss and additional losses indexed by name.
- Return type:
EvaluationLosses
- compute_loss_and_additional_losses(preds, features, target)
Computes the loss and any additional losses given predictions of the model and ground truth data. For MOON, the returned loss is the total loss (criterion loss plus weighted contrastive loss), and the additional losses are the criterion loss, the unweighted contrastive loss, and the total loss.
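- Parameters:
preds (dict[str, torch.Tensor]) – Prediction(s) of the model(s) indexed by name.
features (dict[str, torch.Tensor]) – Feature(s) of the model(s) indexed by name.
target (torch.Tensor | dict[str, torch.Tensor]) – Ground truth data to evaluate predictions against.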
- Returns:
A tuple whose first element is the tensor for the total loss and whose second element is a dictionary with loss, contrastive_loss, and total_loss keys and their calculated values.
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor]]
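The shape of the returned tuple can be sketched in plain Python. In this sketch the function name is hypothetical, floats stand in for tensors, and the criterion and contrastive values are assumed to have been computed already. Only the combination rule (total = criterion loss + contrastive_weight * unweighted contrastive loss) and the three dictionary keys come from the description above.

```python
def compute_loss_and_additional_losses_sketch(
    criterion_loss: float,
    contrastive_loss: float,
    contrastive_weight: float = 1.0,
) -> tuple[float, dict[str, float]]:
    """Hypothetical float-based mirror of the MOON loss bookkeeping."""
    total_loss = criterion_loss + contrastive_weight * contrastive_loss
    additional_losses = {
        "loss": criterion_loss,  # base criterion loss only
        "contrastive_loss": contrastive_loss,  # unweighted contrastive term
        "total_loss": total_loss,  # same value returned as the main loss
    }
    return total_loss, additional_losses


total, losses = compute_loss_and_additional_losses_sketch(0.5, 0.25, contrastive_weight=2.0)
print(total, losses)
```

Because the contrastive term is reported unweighted, it can be tracked across runs that use different contrastive_weight values.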
- compute_training_loss(preds, features, target)
Computes training loss given predictions and features of the model and ground truth data. Loss includes base loss plus a model contrastive loss.
- Parameters:
preds (dict[str, torch.Tensor]) – Prediction(s) of the model(s) indexed by name. All predictions included in the dictionary will be used to compute metrics.
features (dict[str, torch.Tensor]) – Feature(s) of the model(s) indexed by name.
target (torch.Tensor | dict[str, torch.Tensor]) – Ground truth data to evaluate predictions against.
- Returns:
An instance of TrainingLosses containing the backward loss and additional losses indexed by name.
- Return type:
TrainingLosses
- predict(input)
Computes the prediction(s) and features of the model(s) given the input. This function also produces the necessary features from the global_model (the aggregated model from the server) and the old_models (previous client-side optimized models) so that the contrastive loss can be computed.
- Parameters:
input (TorchInputType) – Inputs to be fed into the model. TorchInputType is simply an alias for the union of torch.Tensor and dict[str, torch.Tensor]. Here, the MOON models require input to simply be of type torch.Tensor.
- Returns:
A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations indexed by name. Specifically, the features of the model, the features of the global model, and the features of the old model(s) are returned. All predictions included in the dictionary will be used to compute metrics.
- Return type:
tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
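The following pure-Python sketch illustrates the data flow described above. One input passes through the local model, the global model, and every model in the old-model buffer, and only the local model's predictions are kept. The toy callables, the function name, and the dictionary keys ("prediction", "features", "global_features", "old_features") are hypothetical placeholders and may not match the keys FL4Health uses.

```python
from typing import Callable, Sequence

# Hypothetical toy model: maps an input vector to (prediction, features).
ToyModel = Callable[[list[float]], tuple[list[float], list[float]]]


def moon_predict_sketch(
    model: ToyModel,
    global_model: ToyModel,
    old_models: Sequence[ToyModel],
    x: list[float],
) -> tuple[dict[str, list[float]], dict[str, object]]:
    # Predictions come from the local model only; they drive the metrics.
    preds, local_features = model(x)
    # The global and old models contribute features only, for the contrastive loss.
    _, global_features = global_model(x)
    old_features = [old_model(x)[1] for old_model in old_models]
    predictions = {"prediction": preds}
    features = {
        "features": local_features,
        "global_features": global_features,
        "old_features": old_features,
    }
    return predictions, features


def local_model(x: list[float]) -> tuple[list[float], list[float]]:
    return [sum(x)], list(x)


def global_model(x: list[float]) -> tuple[list[float], list[float]]:
    return [0.0], [2.0 * v for v in x]


def old_model(x: list[float]) -> tuple[list[float], list[float]]:
    return [0.0], [-v for v in x]


preds, feats = moon_predict_sketch(local_model, global_model, [old_model], [1.0, 2.0])
print(preds, feats)
```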
- update_after_train(local_steps, loss_dict, config)
This function is called immediately after client-side training has completed. It saves the final trained model to the list of old models used in the contrastive loss of subsequent server rounds, keeping at most len_old_models_buffer models.
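Here is a minimal sketch of the old-model bookkeeping. It assumes the buffer is first-in, first-out and capped at len_old_models_buffer; the eviction order is an assumption, not something stated above. Models are represented by plain dictionaries of weights, and the class name is hypothetical.

```python
import copy
from collections import deque


class OldModelBuffer:
    """Hypothetical FIFO store of client models from previous rounds."""

    def __init__(self, len_old_models_buffer: int = 1) -> None:
        # A deque with maxlen silently drops the oldest entry when full.
        self.old_models: deque = deque(maxlen=len_old_models_buffer)

    def update_after_train(self, model_weights: dict) -> None:
        # Deep-copy so later local training cannot mutate the stored snapshot.
        self.old_models.append(copy.deepcopy(model_weights))


buffer = OldModelBuffer(len_old_models_buffer=2)
for round_index in range(3):
    buffer.update_after_train({"w": [float(round_index)]})
print(list(buffer.old_models))
```

After three rounds with a buffer length of 2, only the snapshots from the last two rounds remain.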
- update_before_train(current_server_round)
This function is called before training, immediately after the aggregated server weights are injected into the client model. We clone and freeze the current model to preserve the aggregated server weights (i.e., the initial model before training starts).
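Cloning and freezing can be illustrated with a toy example in which a model is a dictionary of weight lists. Here a read-only snapshot (types.MappingProxyType over tuples) stands in for freezing. In PyTorch the analogous step would typically be deep-copying the module and setting requires_grad to False on its parameters. The function name and the weight layout are hypothetical.

```python
from types import MappingProxyType
from typing import Mapping


def clone_and_freeze(weights: dict[str, list[float]]) -> Mapping[str, tuple[float, ...]]:
    """Return a read-only copy: tuples cannot be edited and the proxy rejects assignment."""
    return MappingProxyType({name: tuple(values) for name, values in weights.items()})


# Aggregated server weights have just been injected into the local model.
model = {"w": [0.5, -0.5]}
global_model = clone_and_freeze(model)

# A local training step changes the local model but not the snapshot.
model["w"][0] -= 0.1
print(model["w"], global_model["w"])
```

The snapshot keeps the server weights for the whole round, so the global model's features stay fixed while the local model trains.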