from collections.abc import Sequence
from logging import WARNING
from pathlib import Path
import torch
from flwr.common.logger import log
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
from fl4health.clients.basic_client import BasicClient, Config
from fl4health.losses.contrastive_loss import MoonContrastiveLoss
from fl4health.model_bases.sequential_split_models import SequentiallySplitModel
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.utils.client import clone_and_freeze_model
from fl4health.utils.losses import EvaluationLosses, LossMeterType, TrainingLosses
from fl4health.utils.metrics import Metric
from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType
class MoonClient(BasicClient):
    """
    Client implementing the MOON algorithm from Model-Contrastive Federated Learning.

    MOON constrains local training in the non-IID setting by adding a contrastive loss that pulls the
    current local model's representations toward those of the global (aggregated) model and pushes them
    away from representations of previous local models.
    """

    def __init__(
        self,
        data_path: Path,
        metrics: Sequence[Metric],
        device: torch.device,
        loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
        checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
        reporters: Sequence[BaseReporter] | None = None,
        progress_bar: bool = False,
        client_name: str | None = None,
        temperature: float = 0.5,
        contrastive_weight: float = 1.0,
        len_old_models_buffer: int = 1,
    ) -> None:
        """
        Construct a MOON client.

        Args:
            data_path (Path): Path to the data to be used to load the data for client-side training.
            metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the
                client model.
            device (torch.device): Device indicator for where to send the model, batches, labels etc. Often
                'cpu' or 'cuda'.
            loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses
                over each batch. Defaults to LossMeterType.AVERAGE.
            checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to
                handle both checkpointing and state saving. The module, and its underlying model and state
                checkpointing components, determine when and how checkpointing is done during client-side
                training. No checkpointing (state or model) is done if not provided. Defaults to None.
            reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the
                client should send data to. Defaults to None.
            progress_bar (bool, optional): Whether or not to display a tqdm progress bar during client
                training and validation. Defaults to False.
            client_name (str | None, optional): An optional client name that uniquely identifies a client.
                If not passed, a hash is randomly generated. Client state will use this as part of its state
                file name. Defaults to None.
            temperature (float, optional): Temperature used in the calculation of the contrastive loss.
                Defaults to 0.5.
            contrastive_weight (float, optional): Weight placed on the contrastive loss function. Referred
                to as mu in the original paper. Defaults to 1.0.
            len_old_models_buffer (int, optional): Number of old models to be stored for computation in the
                contrastive learning loss function. Defaults to 1.
        """
        super().__init__(
            data_path=data_path,
            metrics=metrics,
            device=device,
            loss_meter_type=loss_meter_type,
            checkpoint_and_state_module=checkpoint_and_state_module,
            reporters=reporters,
            progress_bar=progress_bar,
            client_name=client_name,
        )
        # Contrastive-loss hyperparameters (mu and tau in the MOON paper).
        self.temperature = temperature
        self.contrastive_weight = contrastive_weight
        if self.contrastive_weight == 0:
            log(WARNING, "Contrastive loss weight is set to 0, thus Contrastive loss will not be computed.")
        self.contrastive_loss_function = MoonContrastiveLoss(self.device, temperature=temperature)
        # Buffers holding frozen copies of previous local models and the latest global model, refreshed each
        # communication round, which the contrastive loss compares against.
        self.len_old_models_buffer = len_old_models_buffer
        self.old_models_list: list[torch.nn.Module] = []
        self.global_model: torch.nn.Module | None = None
[docs]
def predict(self, input: TorchInputType) -> tuple[TorchPredType, TorchFeatureType]:
"""
Computes the prediction(s) and features of the model(s) given the input. This function also produces the
necessary features from the global_model (aggregated model from server) and old_models (previous client-side
optimized models) in order to be able to compute the appropriate contrastive loss.
Args:
input (TorchInputType): Inputs to be fed into the model. TorchInputType is simply an alias
for the union of torch.Tensor and dict[str, torch.Tensor]. Here, the MOON models require input to
simply be of type torch.Tensor
Returns:
tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: A tuple in which the first element
contains predictions indexed by name and the second element contains intermediate activations
index by name. Specifically the features of the model, features of the global model and features of
the old model are returned. All predictions included in dictionary will be used to compute metrics.
"""
assert isinstance(input, torch.Tensor)
preds, features = self.model(input)
assert "features" in features, "Model must produce a features dictionary with a 'features' key"
# If there are no models in the old_models_list, we don't compute the features for the contrastive loss
if len(self.old_models_list) > 0:
# Features from each of the old models are packed into a single tensor
old_features = torch.zeros(len(self.old_models_list), *features["features"].size()).to(self.device)
for i, old_model in enumerate(self.old_models_list):
_, old_model_features = old_model(input)
old_features[i] = old_model_features["features"]
features.update({"old_features": old_features})
if self.global_model is not None:
_, global_model_features = self.global_model(input)
features.update({"global_features": global_model_features["features"]})
return preds, features
[docs]
def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None:
"""
This function is called immediately after client-side training has completed. This function saves the final
trained model to the list of old models to be used in subsequent server rounds
Args:
local_steps (int): Number of local steps performed during training
loss_dict (dict[str, float]): Loss dictionary associated with training.
config (Config): The config from the server
"""
assert isinstance(self.model, SequentiallySplitModel)
# Save the parameters of the old LOCAL model
old_model = clone_and_freeze_model(self.model)
# Current model is appended to the back of the list
self.old_models_list.append(old_model)
# If the list is longer than desired, the element at the front of the list is removed.
if len(self.old_models_list) > self.len_old_models_buffer:
self.old_models_list.pop(0)
super().update_after_train(local_steps, loss_dict, config)
[docs]
def update_before_train(self, current_server_round: int) -> None:
"""
This function is called before training, immediately after injecting the aggregated server weights into the
client model. We clone and free the current model to preserve the aggregated server weights state (i.e. the
initial model before training starts.)
Args:
current_server_round (int): Current federated training round being executed.
"""
# Save the parameters of the global model
self.global_model = clone_and_freeze_model(self.model)
super().update_before_train(current_server_round)
[docs]
def compute_loss_and_additional_losses(
self,
preds: TorchPredType,
features: TorchFeatureType,
target: TorchTargetType,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""
Computes the loss and any additional losses given predictions of the model and ground truth data.
For MOON, the loss is the total loss (criterion and weighted contrastive loss) and the additional losses are
the loss, (unweighted) contrastive loss, and total loss.
Args:
preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name.
features (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name.
target (torch.Tensor): Ground truth data to evaluate predictions against.
Returns:
tuple[torch.Tensor, dict[str, torch.Tensor]]; A tuple with:
- The tensor for the total loss
- A dictionary with `loss`, `contrastive_loss` and `total_loss` keys and their calculated values.
"""
loss = self.criterion(preds["prediction"], target)
total_loss = loss.clone()
additional_losses = {
"loss": loss,
}
if "old_features" in features and "global_features" in features:
contrastive_loss = self.contrastive_loss_function(
features["features"], features["global_features"].unsqueeze(0), features["old_features"]
)
total_loss += self.contrastive_weight * contrastive_loss
additional_losses["contrastive_loss"] = contrastive_loss
additional_losses["total_loss"] = total_loss
return total_loss, additional_losses
[docs]
def compute_training_loss(
self,
preds: TorchPredType,
features: TorchFeatureType,
target: TorchTargetType,
) -> TrainingLosses:
"""
Computes training loss given predictions and features of the model and ground truth data. Loss includes
base loss plus a model contrastive loss.
Args:
preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name.
All predictions included in dictionary will be used to compute metrics.
features: (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name.
target: (torch.Tensor): Ground truth data to evaluate predictions against.
Returns:
TrainingLosses: an instance of TrainingLosses containing backward loss and additional losses
indexed by name.
"""
# Check that the model is in training mode
assert self.model.training
# If there are no old local models in the list (first pass of MOON training), we just do basic loss
# calculations
if len(self.old_models_list) == 0:
total_loss, additional_losses = super().compute_loss_and_additional_losses(preds, features, target)
else:
total_loss, additional_losses = self.compute_loss_and_additional_losses(preds, features, target)
return TrainingLosses(backward=total_loss, additional_losses=additional_losses)
[docs]
def compute_evaluation_loss(
self,
preds: TorchPredType,
features: TorchFeatureType,
target: TorchTargetType,
) -> EvaluationLosses:
"""
Computes evaluation loss given predictions and features of the model and ground truth data. Loss includes
base loss plus a model contrastive loss.
Args:
preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name.
All predictions included in dictionary will be used to compute metrics.
features: (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name.
target: (torch.Tensor): Ground truth data to evaluate predictions against.
Returns:
EvaluationLosses: an instance of EvaluationLosses containing checkpoint loss and
additional losses indexed by name.
"""
# Check that the model is in evaluation mode
assert not self.model.training
# If there are no old local models in the list (first pass of MOON training), we just do basic loss
# calculations
if len(self.old_models_list) == 0:
checkpoint_loss, additional_losses = super().compute_loss_and_additional_losses(preds, features, target)
else:
_, additional_losses = self.compute_loss_and_additional_losses(preds, features, target)
# Note that the first parameter returned is the "total loss", which includes the contrastive loss
# So we use the vanilla loss stored in additional losses for checkpointing.
checkpoint_loss = additional_losses["loss"]
return EvaluationLosses(checkpoint=checkpoint_loss, additional_losses=additional_losses)