Source code for fl4health.clients.flash_client

from collections.abc import Sequence
from logging import INFO
from pathlib import Path

import torch
from flwr.common.logger import log
from flwr.common.typing import Config, Scalar

from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
from fl4health.clients.basic_client import BasicClient
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.utils.client import check_if_batch_is_empty_and_verify_input, move_data_to_device
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import Metric


class FlashClient(BasicClient):
    def __init__(
        self,
        data_path: Path,
        metrics: Sequence[Metric],
        device: torch.device,
        loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
        checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
        reporters: Sequence[BaseReporter] | None = None,
        progress_bar: bool = False,
        client_name: str | None = None,
    ) -> None:
        """
        This client performs the client-side training associated with the Flash method described in
        https://proceedings.mlr.press/v202/panchal23a/panchal23a.pdf. Flash is designed to handle statistical
        heterogeneity and concept drift in federated learning through client-side early stopping and server-side
        drift-aware adaptive optimization.

        Args:
            data_path (Path): Path to the data to be used for client-side training.
            metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the
                client model.
            device (torch.device): Device indicator for where to send the model, batches, labels, etc. Often
                'cpu' or 'cuda'.
            loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over
                each batch. Defaults to LossMeterType.AVERAGE.
            checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to
                handle both checkpointing and state saving. The module, and its underlying model and state
                checkpointing components, will determine when and how to do checkpointing during client-side
                training. No checkpointing (state or model) is done if not provided. Defaults to None.
            reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters to which the
                client should send data. Defaults to None.
            progress_bar (bool, optional): Whether or not to display a progress bar during client training and
                validation. Uses tqdm. Defaults to False.
            client_name (str | None, optional): An optional client name that uniquely identifies a client. If
                not passed, a hash is randomly generated. Client state will use this as part of its state file
                name. Defaults to None.
        """
        super().__init__(
            data_path=data_path,
            metrics=metrics,
            device=device,
            loss_meter_type=loss_meter_type,
            checkpoint_and_state_module=checkpoint_and_state_module,
            reporters=reporters,
            progress_bar=progress_bar,
            client_name=client_name,
        )
        # gamma: threshold for early stopping based on the change in validation loss.
        self.gamma: float | None = None
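    # Note (added for exposition, not part of the library source): gamma drives the
    # client-side early-stopping rule applied in train_by_epochs below. After the
    # validation pass of local epoch e, training halts when
    #   previous_loss - current_loss < gamma / (e + 1),
    # i.e. when the per-epoch improvement in validation loss falls below a threshold
    # that decays over the course of local training.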
    def process_config(self, config: Config) -> tuple[int | None, int | None, int, bool, bool]:
        local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics = (
            super().process_config(config)
        )
        if local_steps is not None:
            raise ValueError(
                "Training by steps is not applicable for FLASH clients. "
                "Please define 'local_epochs' in your config instead."
            )
        return local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics
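    # Illustrative server config (added for exposition; key names follow the checks in
    # process_config above and setup_client below, values are placeholders):
    #
    #   config: Config = {
    #       "current_server_round": 1,
    #       "local_epochs": 5,   # required: Flash clients train by epochs, not steps
    #       "gamma": 0.01,       # optional: omit to disable early stopping
    #   }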
    def train_by_epochs(
        self, epochs: int, current_round: int | None = None
    ) -> tuple[dict[str, float], dict[str, Scalar]]:
        self.model.train()
        local_step = 0
        previous_loss = float("inf")
        report_data: dict = {"round": current_round}
        for local_epoch in range(epochs):
            self.train_metric_manager.clear()
            self.train_loss_meter.clear()
            self._log_header_str(current_round, local_epoch)
            report_data.update({"fit_epoch": local_epoch})
            for input, target in self.train_loader:
                if check_if_batch_is_empty_and_verify_input(input):
                    log(INFO, "Empty batch generated by data loader. Skipping step.")
                    continue
                input = move_data_to_device(input, self.device)
                target = move_data_to_device(target, self.device)
                losses, preds = self.train_step(input, target)
                self.train_loss_meter.update(losses)
                self.train_metric_manager.update(preds, target)
                self.update_after_step(local_step, current_round)
                report_data.update({"fit_losses": losses.as_dict(), "fit_step": self.total_steps})
                report_data.update(self.get_client_specific_reports())
                self.reports_manager.report(report_data, current_round, local_epoch, self.total_steps)
                self.total_steps += 1
                local_step += 1
            metrics = self.train_metric_manager.compute()
            loss_dict = self.train_loss_meter.compute().as_dict()
            current_loss, _ = self.validate()
            self._log_results(
                loss_dict,
                metrics,
                current_round=current_round,
                current_epoch=local_epoch,
            )
            if self.gamma is not None and previous_loss - current_loss < self.gamma / (local_epoch + 1):
                log(
                    INFO,
                    f"Early stopping at epoch {local_epoch} with loss change "
                    f"{abs(previous_loss - current_loss)} and gamma {self.gamma}",
                )
                break
            previous_loss = current_loss

        return loss_dict, metrics
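    # Worked example of the decaying threshold above (added for exposition): with
    # gamma = 0.01, the minimum validation-loss improvement required to continue is
    #   epoch 0: 0.01 / 1 = 0.0100
    #   epoch 1: 0.01 / 2 = 0.0050
    #   epoch 2: 0.01 / 3 ≈ 0.0033
    # so a client whose validation loss plateaus stops early, while one that keeps
    # improving trains for all configured local epochs.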
    def setup_client(self, config: Config) -> None:
        super().setup_client(config)
        if "gamma" in config:
            self.gamma = narrow_dict_type(config, "gamma", float)
        else:
            log(INFO, "Gamma not present in config. Early stopping is disabled.")
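

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for exposition, not part of the library source).
# It assumes BasicClient's standard abstract hooks (get_model, get_optimizer,
# get_criterion, get_data_loaders); ToyFlashClient and its random toy data are
# hypothetical placeholders showing how FlashClient is typically subclassed.
# ---------------------------------------------------------------------------
import torch.nn as nn
from torch.optim import SGD, Optimizer
from torch.utils.data import DataLoader, TensorDataset


class ToyFlashClient(FlashClient):
    def get_model(self, config: Config) -> nn.Module:
        # Tiny linear classifier; replace with a real model for actual use.
        return nn.Linear(10, 2)

    def get_optimizer(self, config: Config) -> Optimizer:
        return SGD(self.model.parameters(), lr=0.01)

    def get_criterion(self, config: Config) -> nn.Module:
        return nn.CrossEntropyLoss()

    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
        # Random toy data standing in for a real dataset loaded from data_path.
        data = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=8), DataLoader(data, batch_size=8)


# The client is then constructed and started with Flower in the usual way for
# any BasicClient subclass, e.g.:
#   client = ToyFlashClient(data_path=Path("."), metrics=[], device=torch.device("cpu"))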