Source code for fl4health.clients.constrained_fenda_client

from collections.abc import Sequence
from logging import WARNING
from pathlib import Path

import torch
import torch.nn as nn
from flwr.common.logger import log
from flwr.common.typing import Config

from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
from fl4health.clients.fenda_client import FendaClient
from fl4health.losses.fenda_loss_config import ConstrainedFendaLossContainer
from fl4health.model_bases.fenda_base import FendaModelWithFeatureState
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.utils.client import clone_and_freeze_model
from fl4health.utils.losses import EvaluationLosses, LossMeterType
from fl4health.utils.metrics import Metric
from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType


[docs] class ConstrainedFendaClient(FendaClient):
[docs] def __init__( self, data_path: Path, metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, client_name: str | None = None, loss_container: ConstrainedFendaLossContainer | None = None, ) -> None: """ This class extends the functionality of FENDA training to include various kinds of constraints applied during the client-side training of FENDA models. Args: data_path (Path): path to the data to be used to load the data for client-side training metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. loss_container (ConstrainedFendaLossContainer | None, optional): Configuration that determines which losses will be applied during FENDA training. Defaults to None. """ super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, progress_bar=progress_bar, client_name=client_name, ) if loss_container: self.loss_container = loss_container else: # If no loss configuration has been define, set everything to zero. This is equivalent to vanilla FENDA log( WARNING, "No loss container provided, defaulting to an empty container. " "This is equivalent to running a vanilla FENDA client", ) self.loss_container = ConstrainedFendaLossContainer(None, None, None) # Need to save previous local module, global module and aggregated global module at each communication round # to compute contrastive loss. self.old_local_module: nn.Module | None = None self.old_global_module: nn.Module | None = None self.initial_global_module: nn.Module | None = None
[docs] def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: assert isinstance(self.model, FendaModelWithFeatureState) return super().get_parameter_exchanger(config)
def _flatten(self, features: torch.Tensor) -> torch.Tensor: """ Flatten the provided features ASSUMING they are provided in batch-first format. Args: features (torch.Tensor): features to be flattened Returns: torch.Tensor: flattened feature vectors of shape (batch, -1) """ return features.reshape(len(features), -1) def _perfcl_keys_present(self, features: dict[str, torch.Tensor]) -> bool: target_keys = { "old_local_features", "old_global_features", "initial_global_features", } return target_keys.issubset(features.keys())
[docs] def predict(self, input: TorchInputType) -> tuple[TorchPredType, TorchFeatureType]: """ Computes the prediction(s) and features of the model(s) given the input. Args: input (TorchInputType): Inputs to be fed into the model. TorchInputType is simply an alias for the union of torch.Tensor and dict[str, torch.Tensor]. Returns: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations index by name. Specifically the features of the model, features of the global model and features of the old model are returned. All predictions included in dictionary will be used to compute metrics. """ assert isinstance(input, torch.Tensor) assert isinstance(self.model, FendaModelWithFeatureState) preds, features = self.model(input) if self.loss_container.has_contrastive_loss() or self.loss_container.has_perfcl_loss(): # If we have defined a contrastive loss function or PerFCL loss function, we attempt to save old local # features. if self.old_local_module is not None: features["old_local_features"] = self._flatten(self.old_local_module.forward(input)) if self.loss_container.has_perfcl_loss(): # If a PerFCL loss function has been defined, then we also save two additional feature components. if self.old_global_module is not None: features["old_global_features"] = self._flatten(self.old_global_module.forward(input)) if self.initial_global_module is not None: features["initial_global_features"] = self._flatten(self.initial_global_module.forward(input)) return preds, features
[docs] def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None: """ This function is called after client-side training concludes. If a contrastive or PerFCL loss function has been defined, it is used to save the local and global feature extraction weights/modules to be used in the next round of client-side training. Args: local_steps (int): Number of steps performed during training loss_dict (dict[str, float]): Losses computed during training. """ # Save the parameters of the old model assert isinstance(self.model, FendaModelWithFeatureState) if self.loss_container.has_contrastive_loss() or self.loss_container.has_perfcl_loss(): self.old_local_module = clone_and_freeze_model(self.model.first_feature_extractor) self.old_global_module = clone_and_freeze_model(self.model.second_feature_extractor) super().update_after_train(local_steps, loss_dict, config)
[docs] def update_before_train(self, current_server_round: int) -> None: """ This function is called prior to the start of client-side training, but after the server parameters have be received and injected into the model. If a PerFCL loss function has been defined, it is used to save the aggregated global feature extractor weights/module representing the initial state of this module BEFORE this iteration of client-side training but AFTER server-side aggregation. Args: current_server_round (int): Current server round being performed. """ # Save the parameters of the aggregated global model assert isinstance(self.model, FendaModelWithFeatureState) if self.loss_container.has_perfcl_loss(): self.initial_global_module = clone_and_freeze_model(self.model.second_feature_extractor) super().update_before_train(current_server_round)
[docs] def compute_loss_and_additional_losses( self, preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ Computes the loss and any additional losses given predictions of the model and ground truth data. For FENDA, the loss is the total loss and the additional losses are the loss, total loss and, based on client attributes set from server config, cosine similarity loss, contrastive loss and perfcl losses. Args: preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. features (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target (torch.Tensor): Ground truth data to evaluate predictions against. Returns: tuple[torch.Tensor, dict[str, torch.Tensor]]; A tuple with: - The tensor for the total loss - A dictionary with `loss`, `total_loss` and, based on client attributes set from server config, also `cos_sim_loss`, `contrastive_loss`, `contrastive_loss_minimize` and `contrastive_loss_minimize` keys and their respective calculated values. """ loss = self.criterion(preds["prediction"], target) total_loss = loss.clone() additional_losses = {"loss": loss} if self.loss_container.has_cosine_similarity_loss(): cosine_similarity_loss = self.loss_container.compute_cosine_similarity_loss( features["local_features"], features["global_features"] ) total_loss += cosine_similarity_loss additional_losses["cos_sim_loss"] = cosine_similarity_loss if self.loss_container.has_contrastive_loss() and "old_local_features" in features: contrastive_loss = self.loss_container.compute_contrastive_loss( features["local_features"], features["old_local_features"].unsqueeze(0), features["global_features"].unsqueeze(0), ) total_loss += contrastive_loss additional_losses["contrastive_loss"] = contrastive_loss if self.loss_container.has_perfcl_loss() and self._perfcl_keys_present(features): global_feature_contrastive_loss, local_feature_contrastive_loss = self.loss_container.compute_perfcl_loss( features["local_features"], features["old_local_features"], features["global_features"], features["old_global_features"], features["initial_global_features"], ) total_loss += global_feature_contrastive_loss + local_feature_contrastive_loss additional_losses["global_feature_contrastive_loss"] = global_feature_contrastive_loss additional_losses["local_feature_contrastive_loss"] = local_feature_contrastive_loss additional_losses["total_loss"] = total_loss return total_loss, additional_losses
[docs] def compute_evaluation_loss( self, preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType, ) -> EvaluationLosses: """ Computes evaluation loss given predictions of the model and ground truth data. Optionally computes additional loss components such as cosine_similarity_loss, contrastive_loss and perfcl_loss based on client attributes set from server config. Args: preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. All predictions included in dictionary will be used to compute metrics. features: (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target: (torch.Tensor): Ground truth data to evaluate predictions against. Returns: EvaluationLosses: an instance of EvaluationLosses containing checkpoint loss and additional losses indexed by name. Additional losses may include cosine_similarity_loss, contrastive_loss and perfcl_loss. """ _, additional_losses = self.compute_loss_and_additional_losses(preds, features, target) return EvaluationLosses(checkpoint=additional_losses["loss"], additional_losses=additional_losses)