fl4health.clients.flash_client module
- class FlashClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)
Bases: BasicClient
- __init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)
This client is used to perform client-side training associated with the Flash method described in https://proceedings.mlr.press/v202/panchal23a/panchal23a.pdf.
Flash is designed to handle statistical heterogeneity and concept drift in federated learning through client-side early stopping and server-side drift-aware adaptive optimization.
- Parameters:
data_path (Path) – Path to the data used for client-side training.
metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model
device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often “cpu” or “cuda”
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module that handles both checkpointing and state saving. The module and its underlying model and state checkpointing components determine when and how checkpointing is done during client-side training. If not provided, no checkpointing (state or model) is done. Defaults to None.
reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.
progress_bar (bool, optional) – Whether to display a progress bar during client training and validation. Uses tqdm. Defaults to False.
client_name (str | None, optional) – An optional name that uniquely identifies the client. If not passed, a hash is randomly generated. The client's state file name includes this name. Defaults to None.
- process_config(config)
Processes the config sent by the server. FLASH is not defined for step-wise training, so this method adds a check that training is specified in epochs rather than steps.
- Parameters:
config (Config) – The config object from the server.
- Raises:
ValueError – Raised if the config specifies training by steps instead of epochs.
- Returns:
Returns local_epochs, local_steps, current_server_round, evaluate_after_fit and pack_losses_with_val_metrics. Ensures that only one of local_epochs and local_steps is defined in the config and sets the undefined one to None.
- Return type:
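The config check described above can be sketched as a standalone function. This is a minimal, hypothetical illustration rather than the fl4health implementation: it assumes the config is a flat dict whose keys match the returned names (local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics), and the False defaults for the two boolean flags are assumptions. process_flash_config is an illustrative name.

```python
from typing import Optional, Union

# Assumed shape of the server config: a flat mapping of scalar values.
Config = dict[str, Union[bool, bytes, float, int, str]]


def process_flash_config(
    config: Config,
) -> tuple[Optional[int], Optional[int], int, bool, bool]:
    """Hypothetical sketch of FLASH config processing (epochs only)."""
    has_epochs = "local_epochs" in config
    has_steps = "local_steps" in config
    # Exactly one of local_epochs / local_steps may be defined.
    if has_epochs == has_steps:
        raise ValueError("Config must define exactly one of local_epochs or local_steps.")
    # FLASH is not defined for step-wise training, so reject it.
    if has_steps:
        raise ValueError("FLASH does not support training by steps; set local_epochs instead.")
    local_epochs = int(config["local_epochs"])
    current_server_round = int(config["current_server_round"])
    # Assumed defaults for the optional flags (not taken from fl4health).
    evaluate_after_fit = bool(config.get("evaluate_after_fit", False))
    pack_losses_with_val_metrics = bool(config.get("pack_losses_with_val_metrics", False))
    # local_steps is the undefined field, so it is returned as None.
    return local_epochs, None, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics
```

For example, a config of {"local_epochs": 3, "current_server_round": 2} yields (3, None, 2, False, False), while any config containing local_steps raises ValueError.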
- setup_client(config)
Follows the same setup flow as BasicClient, with one additional step: checking whether the gamma parameter is present in the config.
- Parameters:
config (Config) – The config object from the server.
- Return type:
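The extra setup step amounts to reading an optional value from the config. A minimal sketch, assuming the config key is literally "gamma" and that a missing key means no early stopping; extract_gamma is a hypothetical helper, not part of fl4health.

```python
from typing import Optional, Union

# Assumed shape of the server config: a flat mapping of scalar values.
ConfigValue = Union[bool, bytes, float, int, str]


def extract_gamma(config: dict[str, ConfigValue]) -> Optional[float]:
    """Hypothetical helper: return gamma if the server sent one, else None."""
    # The key name "gamma" is an assumption for illustration.
    if "gamma" not in config:
        # No gamma: training later runs every epoch, as in BasicClient.
        return None
    return float(config["gamma"])
```

Returning None rather than a default threshold keeps the "no gamma" case distinct, which is what lets training fall back to plain BasicClient behaviour.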
- train_by_epochs(epochs, current_round=None)
Implements a custom train_by_epochs for this client to support the client-side FLASH adaptations. If gamma is None, this method behaves exactly as in BasicClient. Otherwise, it runs the requested epochs but may stop early, using gamma as a threshold.
- Parameters:
- Returns:
The loss and metrics dictionaries from local training. The loss dictionary holds one or more losses representing the different components of the overall loss.
- Return type:
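The early-stopping behaviour can be illustrated with a framework-free loop. This is a simplified sketch under stated assumptions, not the fl4health code: it assumes training stops as soon as the loss decreases by less than gamma between consecutive epochs, and it abstracts one epoch of training plus loss evaluation into a caller-supplied run_epoch function (a hypothetical name). The exact criterion used by FlashClient or the Flash paper (which loss is compared, whether gamma is rescaled per epoch) may differ.

```python
from typing import Callable, Optional


def train_with_early_stopping(
    epochs: int,
    run_epoch: Callable[[int], float],
    gamma: Optional[float] = None,
) -> list[float]:
    """Run up to `epochs` epochs, stopping once the loss improvement drops below gamma.

    run_epoch(epoch) is assumed to train for one epoch and return the loss
    used for the stopping test. Returns the losses of the epochs actually run.
    """
    losses: list[float] = []
    previous_loss = float("inf")
    for epoch in range(epochs):
        current_loss = run_epoch(epoch)
        losses.append(current_loss)
        # With gamma unset, every epoch runs (plain BasicClient behaviour).
        # Assumed rule: stop when the decrease in loss is smaller than gamma.
        if gamma is not None and previous_loss - current_loss < gamma:
            break
        previous_loss = current_loss
    return losses
```

With gamma=0.1 and per-epoch losses 1.0, 0.5, 0.45, 0.44, 0.43, training stops after the third epoch because the improvement from 0.5 to 0.45 is below 0.1; with gamma=None all five epochs run.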