fl4health.clients.flash_client module

class FlashClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)

Bases: BasicClient

__init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)

This client is used to perform client-side training associated with the Flash method described in https://proceedings.mlr.press/v202/panchal23a/panchal23a.pdf.

Flash is designed to handle statistical heterogeneity and concept drift in federated learning through client-side early stopping and server-side drift-aware adaptive optimization.

Parameters:
  • data_path (Path) – Path to the data used for client-side training.

  • metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model

  • device (torch.device) – Device to which the model, batches, labels, etc. are sent. Often “cpu” or “cuda”.

  • loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.

  • checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module that handles both checkpointing and state saving. The module and its underlying model and state checkpointing components determine when and how checkpointing is done during client-side training. If not provided, no checkpointing (state or model) is done. Defaults to None.

  • reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.

  • progress_bar (bool, optional) – Whether to display a progress bar during client training and validation. Uses tqdm. Defaults to False.

  • client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.

process_config(config)

Performs the necessary processing of the config from the server. Flash is not defined for step-wise training, so this method adds a check to ensure that training is not configured by steps.

Parameters:

config (Config) – The config object from the server.

Raises:

ValueError – Raised if the config specifies training by steps instead of epochs, which this method does not support.

Returns:

Returns the local_epochs, local_steps, current_server_round, evaluate_after_fit and pack_losses_with_val_metrics values. Ensures that only one of local_epochs and local_steps is defined in the config and sets the undefined one to None.

Return type:

tuple[int | None, int | None, int, bool, bool]

setup_client(config)

Follows the same flow as BasicClient for setting up the client, with one additional step: it checks whether the gamma parameter is present in the configuration and processes it accordingly.

Parameters:

config (Config) – The config object from the server.

Return type:

None

train_by_epochs(epochs, current_round=None)

Implements a custom train_by_epochs that applies the client-side Flash adaptations. If gamma is None, this method behaves exactly like BasicClient.train_by_epochs. Otherwise, it trains epoch by epoch and may stop early, using gamma as the early-stopping threshold.

Parameters:
  • epochs (int) – Number of epochs to train

  • current_round (int | None, optional) – Current server round being performed. Defaults to None.

Returns:

The loss and metrics dictionaries from local training. The loss dictionary holds one or more entries, each representing a different component of the loss.

Return type:

tuple[dict[str, float], dict[str, Scalar]]