fl4health.clients.flash_client module

class FlashClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)

Bases: BasicClient

__init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)

This client performs the client-side training for the Flash method described in https://proceedings.mlr.press/v202/panchal23a/panchal23a.pdf. Flash is designed to handle statistical heterogeneity and concept drift in federated learning through client-side early stopping and server-side, drift-aware adaptive optimization.

Parameters:
  • data_path (Path) – Path to the data used for client-side training.

  • metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model.

  • device (torch.device) – Device to which the model, batches, labels, etc. are sent. Often ‘cpu’ or ‘cuda’.

  • loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.

  • checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module that handles both checkpointing and state saving. The module and its underlying model and state checkpointing components determine when and how checkpointing is done during client-side training. If it is not provided, no checkpointing (state or model) is done. Defaults to None.

  • reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.

  • progress_bar (bool, optional) – Whether to display a progress bar during client training and validation. Uses tqdm. Defaults to False.

  • client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.
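The loss_meter_type parameter controls how per-batch losses are aggregated. As a rough illustration of what an averaging meter does with loss dictionaries that hold several named components (the kind of dictionary train_by_epochs returns), here is a minimal sketch. AverageLossMeterSketch, its update/compute/clear methods, and the component names are hypothetical; this is not the fl4health LossMeter API.

```python
from collections import defaultdict


class AverageLossMeterSketch:
    """Hypothetical averaging loss meter; not the fl4health implementation."""

    def __init__(self) -> None:
        # Running sum and batch count for each named loss component.
        self._sums: defaultdict[str, float] = defaultdict(float)
        self._counts: defaultdict[str, int] = defaultdict(int)

    def update(self, batch_losses: dict[str, float]) -> None:
        """Record the loss components produced by one batch."""
        for name, value in batch_losses.items():
            self._sums[name] += value
            self._counts[name] += 1

    def compute(self) -> dict[str, float]:
        """Return the mean of each component over the batches that reported it."""
        return {name: total / self._counts[name] for name, total in self._sums.items()}

    def clear(self) -> None:
        """Reset the meter, e.g. at the start of each epoch."""
        self._sums.clear()
        self._counts.clear()


# Hypothetical component names, used only for illustration.
meter = AverageLossMeterSketch()
meter.update({"backward": 1.0, "cross_entropy": 2.0})
meter.update({"backward": 0.5, "cross_entropy": 1.0})
print(meter.compute())  # {'backward': 0.75, 'cross_entropy': 1.5}
```

Counting batches per component keeps each average correct even if a component is only reported for some batches.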

process_config(config)

Ensures that the required keys are present in the config and extracts the values to be returned.

Parameters:

config (Config) – The config from the server.

Returns:

Returns the local_epochs, local_steps, current_server_round, evaluate_after_fit and pack_losses_with_val_metrics values. Ensures that only one of local_epochs and local_steps is defined in the config and sets the undefined one to None.

Return type:

tuple[int | None, int | None, int, bool, bool]

Raises:

ValueError – If the config contains both local_steps and local_epochs, or if local_steps, local_epochs or current_server_round is not of the expected type (int).

setup_client(config)

Sets up the dataloaders, optimizers, parameter exchangers, and other attributes derived from these, then sets the initialized attribute to True.

Parameters:

config (Config) – The config from the server.

Return type:

None
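Because setup_client ends by setting initialized to True, it behaves as a one-time initialization step. The sketch below assumes a pattern in which fit calls setup_client only while initialized is False, so later rounds reuse the objects built in the first round. ClientSketch and the batch_size config key are hypothetical; the attribute assignment stands in for building dataloaders, optimizers, and parameter exchangers.

```python
from typing import Any


class ClientSketch:
    """Hypothetical client illustrating lazy, one-time setup driven by the server config."""

    def __init__(self) -> None:
        self.initialized = False
        self.setup_calls = 0
        self.batch_size: int | None = None

    def setup_client(self, config: dict[str, Any]) -> None:
        # Stand-in for building dataloaders, optimizers, and parameter exchangers.
        self.batch_size = int(config["batch_size"])
        self.setup_calls += 1
        # Mark the client as ready so later rounds skip this step.
        self.initialized = True

    def fit(self, config: dict[str, Any]) -> None:
        # Setup runs only the first time the server asks the client to train.
        if not self.initialized:
            self.setup_client(config)
        # Local training would happen here.


client = ClientSketch()
client.fit({"batch_size": 32})
client.fit({"batch_size": 64})  # setup is skipped, so batch_size stays 32
print(client.setup_calls)  # 1
```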

train_by_epochs(epochs, current_round=None)

Train locally for the specified number of epochs.

Parameters:
  • epochs (int) – The number of epochs for local training.

  • current_round (int | None, optional) – The current FL round. Defaults to None.

Returns:

The loss and metrics dictionaries from local training.

The loss dictionary holds one or more losses representing the different components of the overall loss.

Return type:

tuple[dict[str, float], dict[str, Scalar]]