fl4health.clients.basic_client module

class BasicClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)

Bases: NumPyClient

__init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)

Base FL client with functionality to train, evaluate, log, report, and checkpoint. The user is responsible for implementing the methods get_model, get_optimizer, get_data_loaders, and get_criterion. Other methods can be overridden to achieve custom functionality.

Parameters:
  • data_path (Path) – Path to the data used for client-side training.

  • metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model

  • device (torch.device) – Device indicator for where to send the model, batches, labels, etc. Often ‘cpu’ or ‘cuda’.

  • loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.

  • checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module and its underlying model and state checkpointing components determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if it is not provided. Defaults to None.

  • reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.

  • progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.

  • client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.

compute_evaluation_loss(preds, features, target)

Computes evaluation loss given predictions (and potentially features) of the model and ground truth data.

Parameters:
  • preds (TorchPredType) – Prediction(s) of the model(s) indexed by name. Anything stored in preds will be used to compute metrics.

  • features (TorchFeatureType, i.e., dict[str, Tensor]) – Feature(s) of the model(s) indexed by name.

  • target (TorchTargetType, i.e., Tensor | dict[str, Tensor]) – Ground truth data to evaluate predictions against.

Returns:

An instance of EvaluationLosses containing the checkpoint loss and additional losses indexed by name.

Return type:

EvaluationLosses

compute_loss_and_additional_losses(preds, features, target)

Computes the loss and any additional losses given predictions of the model and ground truth data.

Parameters:
  • preds (TorchPredType) – Prediction(s) of the model(s) indexed by name.

  • features (TorchFeatureType) – Feature(s) of the model(s) indexed by name.

  • target (TorchTargetType) – Ground truth data to evaluate predictions against.

Returns:

A tuple with

  • The tensor for the loss.

  • A dictionary of additional losses with their names and values, or None if there are no additional losses.

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor] | None]

compute_training_loss(preds, features, target)

Computes training loss given predictions (and potentially features) of the model and ground truth data.

Parameters:
  • preds (TorchPredType) – Prediction(s) of the model(s) indexed by name. Anything stored in preds will be used to compute metrics.

  • features (TorchFeatureType, i.e., dict[str, Tensor]) – Feature(s) of the model(s) indexed by name.

  • target (TorchTargetType, i.e., Tensor | dict[str, Tensor]) – Ground truth data to evaluate predictions against.

Returns:

An instance of TrainingLosses containing the backward loss and additional losses indexed by name.

Return type:

TrainingLosses

evaluate(parameters, config)

Evaluates the model on the validation set, and test set (if defined).

Parameters:
  • parameters (NDArrays) – The parameters of the model to be evaluated.

  • config (Config) – The config object from the server.

Returns:

A loss associated with the evaluation, the number of samples in the validation/test set, and the metric values associated with the evaluation.

Return type:

tuple[float, int, dict[str, Scalar]]

fit(parameters, config)

Processes the config, initializes the client (if it is the first round), and performs training based on the passed config. If the checkpoint_and_state_module handles state checkpointing, on initialization the client checks whether a checkpointed client state exists to load, and at the end of each round the client state is saved.

Parameters:
  • parameters (NDArrays) – The parameters of the model to be used in fit.

  • config (Config) – The config from the server.

Returns:

The parameters following the local training along with the number of samples in the local training dataset and the computed metrics throughout the fit.

Return type:

tuple[NDArrays, int, dict[str, Scalar]]

Raises:

ValueError – If local_steps or local_epochs is not specified in config.
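For illustration, the config that reaches fit is usually built on the server; Flower strategies such as FedAvg accept a callable for this through their on_fit_config_fn argument. The helper below is a hypothetical sketch (the name fit_config and the chosen values are assumptions, not part of fl4health) that produces a config containing exactly one of local_epochs and local_steps, so the ValueError above is not triggered.

```python
from typing import Any


def fit_config(server_round: int) -> dict[str, Any]:
    """Hypothetical server-side helper that builds the config passed to BasicClient.fit."""
    return {
        # Read back on the client by process_config.
        "current_server_round": server_round,
        # Exactly one of local_epochs / local_steps should be present (the value 1 is arbitrary).
        "local_epochs": 1,
        # Optional flag also read by process_config (value chosen for illustration).
        "evaluate_after_fit": False,
    }


print(fit_config(server_round=3))
```

With Flower, such a function would be handed to the strategy (for example on_fit_config_fn=fit_config), which calls it once per round with the current round number.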

get_client_specific_logs(current_round, current_epoch, logging_mode)

This function can be overridden to provide any client specific information to the basic client logging. For example, perhaps a client uses an LR scheduler and wants the LR to be logged each epoch. Called at the beginning and end of each server round or local epoch. Also called at the end of validation/testing.

Parameters:
  • current_round (int | None) – The current FL round (i.e., current server round).

  • current_epoch (int | None) – The current epoch of local training.

  • logging_mode (LoggingMode) – The logging mode (Training, Validation, or Testing).

Returns:

A tuple with

  • str | None: A string to append to the header log string that typically announces the current server round and current epoch at the beginning of each round or local epoch.

  • list[tuple[LogLevel, str]] | None: A list of tuples where the first element is a LogLevel as defined in fl4health.utils.typing and the second element is a string message. Each item in the list will be logged at the end of each server round or epoch. Elements will also be logged at the end of validation/testing.

Return type:

tuple[str | None, list[tuple[LogLevel, str]] | None]
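As a sketch of the LR-logging use case mentioned above, the block below shows the shape of an override's return value. It runs without fl4health: LoggingMode is a local stand-in for the library's enum, logging.INFO stands in for fl4health's LogLevel, and current_lr is a hypothetical attribute. A real override would live on a BasicClient subclass and read the learning rate from its own optimizer or scheduler.

```python
import logging
from enum import Enum
from typing import Optional


class LoggingMode(Enum):
    """Hypothetical stand-in for fl4health's LoggingMode."""

    TRAIN = "Training"
    VALIDATION = "Validation"
    TEST = "Testing"


class LrLoggingClient:
    """Hypothetical fragment; in practice this method would be defined on a BasicClient subclass."""

    def __init__(self, current_lr: float) -> None:
        # Hypothetical attribute holding the learning rate currently in use.
        self.current_lr = current_lr

    def get_client_specific_logs(
        self, current_round: Optional[int], current_epoch: Optional[int], logging_mode: LoggingMode
    ) -> tuple[Optional[str], list[tuple[int, str]]]:
        # Only annotate training logs; validation and testing get no extra output.
        if logging_mode is not LoggingMode.TRAIN:
            return None, []
        header = f" LR: {self.current_lr:.4g}"
        # logging.INFO stands in for fl4health's LogLevel enum (assumption).
        messages = [(logging.INFO, f"Round {current_round}, epoch {current_epoch}: lr={self.current_lr}")]
        return header, messages


header, messages = LrLoggingClient(0.01).get_client_specific_logs(1, 0, LoggingMode.TRAIN)
print(header)
print(messages)
```

Returning (None, []) for the non-training modes keeps validation and test logs unchanged.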

get_client_specific_reports()

This function can be overridden by an inheriting client to report additional client-specific information to the wandb_reporter.

Returns:

A dictionary of values to report.

Return type:

dict[str, Any]

get_criterion(config)

User-defined method that returns the PyTorch loss used to train the model.

Parameters:

config (Config) – The config from the server.

Raises:

NotImplementedError – To be defined in child class.

Return type:

_Loss

get_data_loaders(config)

User-defined method that returns a PyTorch train DataLoader and a PyTorch validation DataLoader.

Parameters:

config (Config) – The config from the server.

Returns:

A tuple of length 2: the client train and validation loaders.

Return type:

tuple[DataLoader, …]

Raises:

NotImplementedError – To be defined in child class.

get_lr_scheduler(optimizer_key, config)

Optional user-defined method that returns the learning rate scheduler to be used throughout training for the given optimizer. By default, returns None.

Parameters:
  • optimizer_key (str) – The key in the optimizer dict corresponding to the optimizer we are optionally defining a learning rate scheduler for.

  • config (Config) – The config from the server.

Returns:

The client learning rate scheduler for the given optimizer, or None.

Return type:

LRScheduler | None

get_model(config)

User-defined method that returns the PyTorch model.

Parameters:

config (Config) – The config from the server.

Returns:

The client model.

Return type:

nn.Module

Raises:

NotImplementedError – To be defined in child class.

get_optimizer(config)

Method to be defined by the user that returns the PyTorch optimizer used to train models locally. The return value can be a single torch optimizer or a dictionary mapping strings to torch optimizers. Returning multiple optimizers is useful in methods like APFL, which uses different optimizers for the local and global models.

Parameters:

config (Config) – The config sent from the server.

Returns:

An optimizer or dictionary of optimizers to train model.

Return type:

Optimizer | dict[str, Optimizer]

Raises:

NotImplementedError – To be defined in child class.

get_parameter_exchanger(config)

Returns a FullParameterExchanger by default. Subclasses that require a custom parameter exchanger can override this.

Parameters:

config (Config) – The config from the server.

Returns:

The parameter exchanger used to exchange parameters between the server and client.

Return type:

ParameterExchanger

get_parameters(config)

Determines which parameters are sent back to the server for aggregation. This uses a parameter exchanger to determine parameters sent.

Parameters:

config (Config) – The config is sent by the FL server to allow for customization in the function if desired.

Returns:

These are the parameters to be sent to the server. At minimum, they represent the relevant model parameters to be aggregated, but they can contain more information.

Return type:

NDArrays

get_properties(config)

Returns the properties of the client (train and validation dataset sample counts).

Parameters:

config (Config) – The config from the server.

Returns:

A dictionary with two entries corresponding to the sample counts in the train and validation sets.

Return type:

dict[str, Scalar]

get_test_data_loader(config)

User-defined method that returns a PyTorch test DataLoader. By default, this function returns None, assuming that there is no test dataset to be used. If the user would like to load and evaluate a test dataset, they need only override this function in their client class.

Parameters:

config (Config) – The config from the server.

Return type:

DataLoader | None

Returns:

The optional client test loader. The default implementation returns None.

initialize_all_model_weights(parameters, config)

If this is the first time we’re initializing the model weights, we use the FullParameterExchanger to initialize all model components. Subclasses that require custom model initialization can override this.

Parameters:
  • parameters (NDArrays) – Model parameters to be injected into the client model

  • config (Config) – The config is sent by the FL server to allow for customization in the function if desired.

Return type:

None

predict(input)

Computes the prediction(s), and potentially features, of the model(s) given the input.

Parameters:

input (TorchInputType) – Inputs to be fed into the model. If input is of type dict[str, torch.Tensor], it is assumed that the keys of input match the names of the keyword arguments of self.model.forward().

Returns:

A tuple in which the first element contains a dictionary of predictions indexed by name and the second element contains intermediate activations indexed by name. By passing features, we can compute losses such as the contrastive loss in MOON. All predictions included in the dictionary will by default be used to compute metrics separately.

Return type:

tuple[TorchPredType, TorchFeatureType]

Raises:
  • TypeError – Occurs when something other than a tensor or dict of tensors is passed in to the model’s forward method.

  • ValueError – Occurs when something other than a tensor or dict of tensors is returned by the model’s forward method.

process_config(config)

Ensures that the required keys are present in the config and extracts the values to be returned.

Parameters:

config (Config) – The config from the server.

Returns:

The local_epochs, local_steps, current_server_round, evaluate_after_fit, and pack_losses_with_val_metrics values. Ensures only one of local_epochs and local_steps is defined in the config and sets the other to None.

Return type:

tuple[int | None, int | None, int, bool, bool]

Raises:

ValueError – If the config contains both local_steps and local_epochs, or if local_steps, local_epochs, or current_server_round is not of the expected type (int).
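The contract above can be sketched as a standalone function. process_config_sketch is hypothetical and uses plain isinstance checks rather than whatever helper fl4health uses internally. Two behaviours are assumptions: evaluate_after_fit and pack_losses_with_val_metrics are treated as optional keys defaulting to False, and a config with neither local_epochs nor local_steps raises ValueError (as the fit documentation states).

```python
from typing import Any, Optional


def process_config_sketch(config: dict[str, Any]) -> tuple[Optional[int], Optional[int], int, bool, bool]:
    """Hypothetical stand-in for the documented contract of BasicClient.process_config."""
    has_epochs = "local_epochs" in config
    has_steps = "local_steps" in config
    if has_epochs and has_steps:
        raise ValueError("Config cannot contain both local_epochs and local_steps.")
    if not (has_epochs or has_steps):
        # Assumption: mirrors the ValueError that fit documents when neither key is present.
        raise ValueError("Config must contain either local_epochs or local_steps.")

    def as_int(key: str) -> int:
        value = config[key]  # a missing current_server_round raises KeyError here
        # bool is a subclass of int in Python, so reject it explicitly.
        if not isinstance(value, int) or isinstance(value, bool):
            raise ValueError(f"{key} must be an int, got {type(value).__name__}.")
        return value

    local_epochs = as_int("local_epochs") if has_epochs else None
    local_steps = as_int("local_steps") if has_steps else None
    current_server_round = as_int("current_server_round")
    # Assumption: both flags are optional and default to False.
    evaluate_after_fit = bool(config.get("evaluate_after_fit", False))
    pack_losses_with_val_metrics = bool(config.get("pack_losses_with_val_metrics", False))
    return local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics


print(process_config_sketch({"local_steps": 100, "current_server_round": 2}))  # (None, 100, 2, False, False)
```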

set_optimizer(config)

Method called in setup_client to set the optimizer attribute returned by the user-defined get_optimizer. In the simplest case, get_optimizer returns a single optimizer. For more advanced use cases where a dictionary mapping strings to optimizers is returned (e.g., APFL), the user must override this method.

Parameters:

config (Config) – The config from the server.

Return type:

None

set_parameters(parameters, config, fitting_round)

Sets the local model parameters transferred from the server using a parameter exchanger to coordinate how parameters are set. In the first fitting round, we assume the full model is being initialized and use the FullParameterExchanger() to set all model weights. Otherwise, we use the appropriate parameter exchanger defined by the user depending on the federated learning algorithm being used.

Parameters:
  • parameters (NDArrays) – Parameters have information about model state to be added to the relevant client model but may contain more information than that.

  • config (Config) – The config is sent by the FL server to allow for customization in the function if desired.

  • fitting_round (bool) – Boolean that indicates whether the current federated learning round is a fitting round or an evaluation round. This is used to help determine which parameter exchange should be used for pulling parameters. A full parameter exchanger is only used if the current federated learning round is the very first fitting round.

Return type:

None
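The exchanger choice described above reduces to a small decision rule. The helper below is hypothetical: it returns a label rather than an exchanger object, and it assumes the very first fitting round can be recognized as server round 1 (Flower numbers rounds from 1). How BasicClient detects that round internally is not specified here.

```python
from typing import Any


def choose_parameter_exchanger(config: dict[str, Any], fitting_round: bool) -> str:
    """Hypothetical helper mirroring the documented rule for picking an exchanger in set_parameters."""
    # Assumption: round 1 with fitting_round=True is "the very first fitting round".
    is_first_fitting_round = fitting_round and config["current_server_round"] == 1
    if is_first_fitting_round:
        # All model weights are initialized from the server's parameters.
        return "FullParameterExchanger"
    # Otherwise, the algorithm-specific exchanger (from get_parameter_exchanger) decides what is pulled.
    return "client parameter exchanger"


print(choose_parameter_exchanger({"current_server_round": 1}, fitting_round=True))
print(choose_parameter_exchanger({"current_server_round": 1}, fitting_round=False))
```

Note that an evaluation round never triggers the full exchanger, even in round 1.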

setup_client(config)

Sets the data loaders, optimizers, parameter exchanger, and other attributes derived from these, then sets the initialized attribute to True.

Parameters:

config (Config) – The config from the server.

Return type:

None

shutdown()

Shuts down the client, including the W&B reporter if one exists.

Return type:

None

train_by_epochs(epochs, current_round=None)

Train locally for the specified number of epochs.

Parameters:
  • epochs (int) – The number of epochs for local training.

  • current_round (int | None, optional) – The current FL round.

Returns:

The loss and metrics dictionary from the local training.

Loss is a dictionary of one or more losses that represent the different components of the loss.

Return type:

tuple[dict[str, float], dict[str, Scalar]]

train_by_steps(steps, current_round=None)

Train locally for the specified number of steps.

Parameters:
  • steps (int) – The number of steps to train locally.

  • current_round (int | None, optional) – The current FL round.

Returns:

The loss and metrics dictionary from the local training.

Loss is a dictionary of one or more losses that represent the different components of the loss.

Return type:

tuple[dict[str, float], dict[str, Scalar]]

train_step(input, target)

Given a single batch of input and target data, generates predictions, computes the loss, updates the parameters, and optionally updates metrics if they exist (i.e., backpropagation on a single batch of data). Assumes self.model is already in train mode.

Parameters:
  • input (TorchInputType) – The input to be fed into the model.

  • target (TorchTargetType) – The target corresponding to the input.

Returns:

The losses object from the train step along with a dictionary of any predictions produced by the model.

Return type:

tuple[TrainingLosses, TorchPredType]

transform_gradients(losses)

Hook function for model training, called only after the backward pass but before the optimizer step. Useful for transforming the gradients (such as with gradient clipping) before they are applied to the model weights.

Parameters:

losses (TrainingLosses) – The losses object from the train step.

Return type:

None
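Gradient clipping is the example named above. Inside a BasicClient subclass, an override of transform_gradients would typically call torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm) with a max_norm of your choosing. To stay runnable without PyTorch, the sketch below only illustrates the arithmetic of clipping by global L2 norm on plain Python lists. clip_by_global_norm is a hypothetical helper, and PyTorch's implementation differs in minor numerical details (for example, a small epsilon added to the norm).

```python
import math


def clip_by_global_norm(grads: list[list[float]], max_norm: float) -> float:
    """Scale all gradients in place so their combined L2 norm is at most max_norm.

    Returns the norm measured before clipping.
    """
    # Global norm over every gradient entry of every "parameter".
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for grad in grads:
            for i, g in enumerate(grad):
                grad[i] = g * scale
    return total_norm


grads = [[3.0], [4.0]]  # two "parameters" whose combined L2 norm is 5
pre_clip_norm = clip_by_global_norm(grads, max_norm=1.0)
print(pre_clip_norm, grads)
```

Scaling every gradient by the same factor preserves the update direction while bounding its magnitude, which is why global-norm clipping is usually preferred over clipping each entry independently.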

transform_target(target)

Method that users can extend to specify an arbitrary transformation to apply to the target prior to the loss being computed. Defaults to the identity transform.

Overriding this method can be useful in a variety of scenarios such as Self Supervised Learning where the target is derived from the input sample itself. For example, the FedSimClr reference implementation overrides this method to extract features from the target, which is a transformed version of the input image itself.

Parameters:

target (TorchTargetType) – The target or label used to compute the loss.

Returns:

The transformed target. The default implementation returns target unchanged.

Return type:

TorchTargetType

update_after_step(step, current_round=None)

Hook method called after a local train step on the client. step is an integer that represents the local training step that was most recently completed. For example, used by the APFL method to update the alpha value after a training step. Also used by MOON, FENDA, and Ditto to update the optimized beta value for the MK-MMD loss after n steps.

Parameters:
  • step (int) – The step number in local training that was most recently completed. Resets only at the end of the round.

  • current_round (int | None, optional) – The current FL server round.

Return type:

None

update_after_train(local_steps, loss_dict, config)

Hook method called after training with the number of local_steps performed over the FL round and the corresponding loss dictionary. For example, used by Scaffold to update the control variates after a local round of training. Also used by FedProx to update the current loss based on the loss returned during training. Also used by MOON and FENDA to save the trained modules' weights before aggregation.

Parameters:
  • local_steps (int) – The number of local training steps taken so far in the round.

  • loss_dict (dict[str, float]) – A dictionary of losses from local training.

  • config (Config) – The config from the server.

Return type:

None

update_before_epoch(epoch)

Hook method called before a local epoch on the client. Only called if the client is being trained by epochs (i.e., using the local_epochs key instead of local_steps in the server config).

Parameters:

epoch (int) – Integer representing the epoch about to begin.

Return type:

None

update_before_step(step, current_round=None)

Hook method called before a local train step.

Parameters:
  • step (int) – The local training step that was most recently completed. Resets only at the end of the round.

  • current_round (int | None, optional) – The current FL server round.

Return type:

None

update_before_train(current_server_round)

Hook method called before training with the current server round number. NOTE: This method is called immediately AFTER the aggregated parameters are received from the server. For example, used by MOON and FENDA to save global modules after aggregation.

Parameters:

current_server_round (int) – The current server round number.

Return type:

None
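To make the relative order of the training hooks documented above concrete, the sketch below traces a simplified epoch-based round. It is a hypothetical stand-in, not BasicClient's actual loop: it omits data loading, loss meters, metrics, learning rate schedulers, and checkpointing, and the exact step value handed to each hook is simplified (the counter only illustrates that steps keep counting across epochs and reset with a new round).

```python
from typing import Any, Optional


class HookTraceClient:
    """Hypothetical stand-in that records when the documented hooks fire during epoch-based training."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def update_before_train(self, current_server_round: int) -> None:
        self.calls.append(f"before_train({current_server_round})")

    def update_before_epoch(self, epoch: int) -> None:
        self.calls.append(f"before_epoch({epoch})")

    def update_before_step(self, step: int, current_round: Optional[int] = None) -> None:
        self.calls.append(f"before_step({step})")

    def train_step(self) -> None:
        # Prediction, loss, backward pass, transform_gradients, and the optimizer step would happen here.
        self.calls.append("train_step")

    def update_after_step(self, step: int, current_round: Optional[int] = None) -> None:
        self.calls.append(f"after_step({step})")

    def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: dict[str, Any]) -> None:
        self.calls.append(f"after_train({local_steps})")

    def run_round(self, epochs: int, steps_per_epoch: int, current_round: int) -> None:
        # Called after the aggregated parameters have been received from the server.
        self.update_before_train(current_round)
        step = 0  # keeps counting across epochs; a new round starts again from 0
        for epoch in range(epochs):
            self.update_before_epoch(epoch)
            for _ in range(steps_per_epoch):
                self.update_before_step(step, current_round)
                self.train_step()
                self.update_after_step(step, current_round)
                step += 1
        self.update_after_train(step, {}, {})


client = HookTraceClient()
client.run_round(epochs=2, steps_per_epoch=1, current_round=3)
print(client.calls)
```

When training by local_steps instead, update_before_epoch is not called, as documented above.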

update_lr_schedulers(step=None, epoch=None)

Updates any schedulers that exist. Can be overridden to customize update logic based on client state (e.g., self.total_steps).

Parameters:
  • step (int | None) – If using local_steps, the current step of this round. Otherwise None.

  • epoch (int | None) – If using local_epochs, the current epoch of this round. Otherwise None.

Return type:

None
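As an example of customizing the update logic, the sketch below advances the schedulers only every few steps when training by steps, and once per call when training by epochs. Everything here is hypothetical: HalvingScheduler stands in for a torch LRScheduler (only its step() method is mimicked), and the lr_schedulers dictionary keyed by optimizer name is an assumed attribute layout chosen to match get_lr_scheduler's optimizer_key.

```python
from typing import Optional


class HalvingScheduler:
    """Hypothetical stand-in for a torch LRScheduler: each step() halves the learning rate."""

    def __init__(self, lr: float) -> None:
        self.lr = lr

    def step(self) -> None:
        self.lr *= 0.5


class EveryNStepsClient:
    """Hypothetical BasicClient fragment overriding update_lr_schedulers."""

    def __init__(self, step_every: int) -> None:
        # Hypothetical attribute: one scheduler per optimizer key.
        self.lr_schedulers = {"global": HalvingScheduler(0.1)}
        self.step_every = step_every

    def update_lr_schedulers(self, step: Optional[int] = None, epoch: Optional[int] = None) -> None:
        # Step-based training: skip unless this is every `step_every`-th step.
        if step is not None and (step + 1) % self.step_every != 0:
            return
        # Reached on qualifying steps, or on every call when training by epochs (step is None).
        for scheduler in self.lr_schedulers.values():
            scheduler.step()


client = EveryNStepsClient(step_every=2)
for step in range(4):
    client.update_lr_schedulers(step=step)
print(client.lr_schedulers["global"].lr)
```

After four steps with step_every=2, the scheduler has stepped twice, so the learning rate has been halved twice.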

update_metric_manager(preds, target, metric_manager)

Updates a metric manager with the provided model predictions and targets. Can be overridden to modify the pred and target inputs to the metric manager. This is useful in cases where the preds and targets needed to compute the loss are different from what is needed to compute metrics.

Parameters:
  • preds (TorchPredType) – The output predictions from the model returned by self.predict.

  • target (TorchTargetType) – The targets generated by the dataloader to evaluate the predictions with.

  • metric_manager (MetricManager) – The metric manager to update

Return type:

None
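A common reason to override update_metric_manager is that the loss consumes raw per-class scores while a metric expects class indices. The sketch below runs without fl4health or PyTorch: RecordingMetricManager is a hypothetical stand-in exposing only an update(preds, target) method, predictions are plain nested lists rather than tensors, and the "prediction" key is an assumed name. With tensors, the conversion would typically be a torch.argmax call over the class dimension.

```python
from typing import Any


class RecordingMetricManager:
    """Hypothetical stand-in for fl4health's MetricManager that only records its inputs."""

    def __init__(self) -> None:
        self.seen: list[tuple[dict[str, Any], Any]] = []

    def update(self, preds: dict[str, Any], target: Any) -> None:
        self.seen.append((preds, target))


def argmax(row: list[float]) -> int:
    # Index of the largest score in a row.
    return max(range(len(row)), key=row.__getitem__)


class ClassIndexMetricsClient:
    """Hypothetical BasicClient fragment overriding update_metric_manager."""

    def update_metric_manager(
        self, preds: dict[str, list[list[float]]], target: list[int], metric_manager: RecordingMetricManager
    ) -> None:
        # Convert per-class scores into predicted class indices; this only changes what the metric manager sees.
        class_preds = {name: [argmax(row) for row in scores] for name, scores in preds.items()}
        metric_manager.update(class_preds, target)


manager = RecordingMetricManager()
ClassIndexMetricsClient().update_metric_manager({"prediction": [[0.1, 0.9], [0.8, 0.2]]}, [1, 0], manager)
print(manager.seen)  # [({'prediction': [1, 0]}, [1, 0])]
```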

val_step(input, target)

Given input and target, computes the loss and updates the loss and metrics. Assumes self.model is already in eval mode.

Parameters:
  • input (TorchInputType) – The input to be fed into the model.

  • target (TorchTargetType) – The target corresponding to the input.

Returns:

The losses object from the val step along with a dictionary of the predictions produced by the model.

Return type:

tuple[EvaluationLosses, TorchPredType]

validate(include_losses_in_metrics=False)

Validates the current model on the entire validation dataset and, if one has been defined, the entire test dataset.

Returns:

The validation loss and a dictionary of metrics from validation (and test, if present).

Return type:

tuple[float, dict[str, Scalar]]