fl4health.clients.fedrep_client module

class FedRepClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)

Bases: BasicClient

__init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)

Client implementing the training of FedRep (https://arxiv.org/abs/2303.05206).

Similar to FedPer, FedRep trains a global feature extractor shared by all clients through FedAvg and a private classifier that is unique to each client. However, FedRep breaks up the client-side training of these components into two phases. First, the local classifier is trained with the feature extractor frozen. Next, the classifier is frozen and the feature extractor is trained.
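The two-phase schedule can be illustrated with a toy, self-contained sketch. It is not the fl4health implementation: a single scalar weight `rep` stands in for the feature extractor, a scalar `head` stands in for the classifier, the prediction is head × rep × x, and plain gradient descent minimizes mean squared error. The helper names `mse_grads` and `fedrep_local_update` are hypothetical.

```python
# Hypothetical toy illustration of FedRep's two-phase local update.
# prediction = head * (rep * x): `rep` stands in for the feature extractor,
# `head` for the private classifier. Loss is mean squared error.


def mse_grads(rep: float, head: float, data: list[tuple[float, float]]) -> tuple[float, float, float]:
    """Return (loss, dloss/drep, dloss/dhead) for the toy model."""
    n = len(data)
    loss = g_rep = g_head = 0.0
    for x, y in data:
        err = head * rep * x - y
        loss += err * err / n
        g_rep += 2 * err * head * x / n
        g_head += 2 * err * rep * x / n
    return loss, g_rep, g_head


def fedrep_local_update(
    rep: float,
    head: float,
    data: list[tuple[float, float]],
    head_steps: int,
    rep_steps: int,
    lr: float = 0.01,
) -> tuple[float, float]:
    # Phase 1: train the head; `rep` is frozen (never updated in this loop).
    for _ in range(head_steps):
        _, _, g_head = mse_grads(rep, head, data)
        head -= lr * g_head
    # Phase 2: train the representation; `head` is frozen.
    for _ in range(rep_steps):
        _, g_rep, _ = mse_grads(rep, head, data)
        rep -= lr * g_rep
    return rep, head


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # targets follow y = 2x
loss_before = mse_grads(1.0, 1.0, data)[0]
new_rep, new_head = fedrep_local_update(1.0, 1.0, data, head_steps=20, rep_steps=20)
loss_after = mse_grads(new_rep, new_head, data)[0]
```

In the federated setting, only the representation (`rep` here) would be averaged across clients with FedAvg, while `head` stays private to the client.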

Parameters:
  • data_path (Path) – Path to the data to be loaded for client-side training.

  • metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model

  • device (torch.device) – Device indicator for where to send the model, batches, labels, etc. Often “cpu” or “cuda”.

  • loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.

  • checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module and its underlying model and state checkpointing components determine when and how checkpointing is done during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.

  • reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.

  • progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.

  • client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.

fit(parameters, config)

Processes the config, initializes the client (if it is the first round) and performs training based on the passed config. For FedRep, this coordinates calling the right training function depending on whether epochs or steps are specified. We need to override the functionality of the BasicClient to allow for two distinct training passes of the model, as required by FedRep.

Parameters:
  • parameters (NDArrays) – The parameters of the model to be used in fit.

  • config (Config) – The config from the server.

Returns:

The parameters following the local training along with the number of samples in the local training dataset and the computed metrics throughout the fit.

Return type:

tuple[NDArrays, int, dict[str, Scalar]]

Raises:

ValueError – If the steps or epochs for the representation and head module training processes are not correctly specified.
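The branching that fit performs can be sketched as below, assuming the processed config yields optional epoch and step counts for the head and representation modules. The function `dispatch_fedrep_training` is hypothetical, and its two callbacks stand in for train_fedrep_by_epochs and train_fedrep_by_steps.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def dispatch_fedrep_training(
    head_epochs: Optional[int],
    rep_epochs: Optional[int],
    head_steps: Optional[int],
    rep_steps: Optional[int],
    train_by_epochs: Callable[[int, int], T],
    train_by_steps: Callable[[int, int], T],
) -> T:
    """Hypothetical sketch: pick the epoch-based or step-based FedRep trainer."""
    if head_epochs is not None and rep_epochs is not None:
        return train_by_epochs(head_epochs, rep_epochs)
    if head_steps is not None and rep_steps is not None:
        return train_by_steps(head_steps, rep_steps)
    # Neither pair is fully specified for BOTH modules.
    raise ValueError("Epochs or steps must be given for both the head and representation modules.")
```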

get_optimizer(config)

Returns a dictionary with global and local optimizers with string keys “representation” and “head,” respectively.

Return type:

dict[str, Optimizer]

get_parameter_exchanger(config)

Returns a full parameter exchanger. Subclasses that require a custom parameter exchanger can override this method.

Parameters:

config (Config) – The config from the server.

Returns:

Used to exchange parameters between server and client.

Return type:

ParameterExchanger

process_fed_rep_config(config)

Ensures that the required keys are present in the config and extracts the values to be returned. We override this functionality from the BasicClient because FedRep has slightly different requirements: a number of epochs or steps must be specified for BOTH the head module AND the representation module.

Parameters:

config (Config) – The config from the server.

Returns:

Returns the head and representation epochs, the head and representation steps, current_server_round and evaluate_after_fit. Ensures that either epochs or steps (but not both) are defined for the head and representation modules, and sets the values that are not defined to None.

Return type:

tuple[int | None, int | None, int | None, int | None, int, bool]

Raises:

ValueError – If the config contains both local_steps and local_epochs, or if local_steps, local_epochs or current_server_round is not of the expected type (int).
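A minimal sketch of this kind of validation follows. The config key names (`local_head_epochs`, `local_rep_epochs`, `local_head_steps`, `local_rep_steps`, `evaluate_after_fit` defaulting to False) and the order of the returned tuple are assumptions for illustration, chosen to match the documented return type; `process_fed_rep_config_sketch` is a hypothetical name, not the fl4health method.

```python
from typing import Optional, Union

# Mirrors Flower's Config type: a dict of string keys to scalar values.
Scalar = Union[bool, bytes, float, int, str]
Config = dict[str, Scalar]


def _get_int(config: Config, key: str) -> int:
    value = config[key]
    # bool is a subclass of int in Python, so reject it explicitly.
    if not isinstance(value, int) or isinstance(value, bool):
        raise ValueError(f"{key} must be an int, got {type(value).__name__}")
    return value


def process_fed_rep_config_sketch(
    config: Config,
) -> tuple[Optional[int], Optional[int], Optional[int], Optional[int], int, bool]:
    """Return (head_epochs, rep_epochs, head_steps, rep_steps, current_server_round, evaluate_after_fit)."""
    current_server_round = _get_int(config, "current_server_round")
    # Hypothetical key names; each pair must cover BOTH modules.
    epochs_given = "local_head_epochs" in config and "local_rep_epochs" in config
    steps_given = "local_head_steps" in config and "local_rep_steps" in config
    if epochs_given == steps_given:
        # Either both pairs were given or neither was.
        raise ValueError("Specify epochs OR steps for both the head and representation modules.")
    head_epochs = rep_epochs = head_steps = rep_steps = None
    if epochs_given:
        head_epochs = _get_int(config, "local_head_epochs")
        rep_epochs = _get_int(config, "local_rep_epochs")
    else:
        head_steps = _get_int(config, "local_head_steps")
        rep_steps = _get_int(config, "local_rep_steps")
    evaluate_after_fit = config.get("evaluate_after_fit", False)
    if not isinstance(evaluate_after_fit, bool):
        raise ValueError("evaluate_after_fit must be a bool")
    return head_epochs, rep_epochs, head_steps, rep_steps, current_server_round, evaluate_after_fit
```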

set_optimizer(config)

FedRep requires one optimizer for the representation module and one for the model head. This method ensures that the optimizers set up by the user have the proper keys (“representation” and “head”) and that exactly two optimizers are provided.

Parameters:

config (Config) – The config from the server.

Return type:

None
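The key check described above can be sketched in a few lines. This is a hypothetical stand-in (`check_fedrep_optimizers`), not the fl4health code; raising ValueError is an assumption, and plain objects replace torch optimizers so the sketch runs without PyTorch.

```python
from collections.abc import Mapping
from typing import Any

# Keys named in the get_optimizer documentation.
REQUIRED_KEYS = frozenset({"representation", "head"})


def check_fedrep_optimizers(optimizers: Any) -> None:
    """Hypothetical check: exactly two optimizers under "representation" and "head"."""
    if not isinstance(optimizers, Mapping):
        raise ValueError("FedRep expects a dictionary of optimizers.")
    keys = set(optimizers.keys())
    # Set equality enforces both the key names and the count of two.
    if keys != REQUIRED_KEYS:
        raise ValueError(
            f"Expected optimizer keys {sorted(REQUIRED_KEYS)}, got {sorted(map(str, keys))}"
        )
```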

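train_fedrep_by_epochs(head_epochs, rep_epochs, current_round=None)

Train the head module and then the representation module locally for the specified numbers of epochs.

Parameters:
  • head_epochs (int) – The number of epochs for local training of the head module.

  • rep_epochs (int) – The number of epochs for local training of the representation module.

  • current_round (int | None) – The current FL round.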

Returns:

The loss and metrics dictionaries from the local training. Loss is a dictionary of one or more losses that represent the different components of the loss.

Return type:

tuple[dict[str, float], dict[str, Scalar]]
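The epoch-based schedule can be sketched as follows: full passes over the data in HEAD mode, then full passes in REPRESENTATION mode. The function `train_by_epochs_sketch` and its `train_step` callback are hypothetical; the enum is redefined locally with the same members and values as FedRepTrainMode so the sketch is self-contained.

```python
from enum import Enum
from typing import Callable, Iterable, TypeVar

B = TypeVar("B")


class FedRepTrainMode(Enum):
    # Same members and values as the FedRepTrainMode enum in this module.
    HEAD = "head"
    REPRESENTATION = "representation"


def train_by_epochs_sketch(
    batches: Iterable[B],  # must be re-iterable, like a list or a DataLoader
    head_epochs: int,
    rep_epochs: int,
    train_step: Callable[[FedRepTrainMode, B], None],
) -> None:
    """Hypothetical: full passes in HEAD mode first, then in REPRESENTATION mode."""
    schedule = ((FedRepTrainMode.HEAD, head_epochs), (FedRepTrainMode.REPRESENTATION, rep_epochs))
    for mode, epochs in schedule:
        for _ in range(epochs):
            for batch in batches:  # one epoch = one full pass over the data
                train_step(mode, batch)
```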

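train_fedrep_by_steps(head_steps, rep_steps, current_round=None)

Train the head module and then the representation module locally for the specified numbers of steps.

Parameters:
  • head_steps (int) – The number of steps to train the head module locally.

  • rep_steps (int) – The number of steps to train the representation module locally.

  • current_round (int | None) – The current FL round.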

Returns:

The loss and metrics dictionaries from the local training. Loss is a dictionary of one or more losses that represent the different components of the loss.

Return type:

tuple[dict[str, float], dict[str, Scalar]]
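With steps, each phase processes a fixed number of batches regardless of dataset size, so the loader may need to be restarted mid-phase. The sketch below shows one way to do that; the name `train_by_steps_sketch` is hypothetical, and restarting the loader at the start of each phase is an assumption rather than documented fl4health behaviour.

```python
from enum import Enum
from typing import Callable, Iterable, TypeVar

B = TypeVar("B")


class FedRepTrainMode(Enum):
    HEAD = "head"
    REPRESENTATION = "representation"


def train_by_steps_sketch(
    data_loader: Iterable[B],
    head_steps: int,
    rep_steps: int,
    train_step: Callable[[FedRepTrainMode, B], None],
) -> None:
    """Hypothetical: a fixed number of batches per phase, cycling the loader as needed."""
    schedule = ((FedRepTrainMode.HEAD, head_steps), (FedRepTrainMode.REPRESENTATION, rep_steps))
    for mode, steps in schedule:
        iterator = iter(data_loader)  # assumption: each phase starts a fresh pass
        for _ in range(steps):
            try:
                batch = next(iterator)
            except StopIteration:
                # Loader exhausted mid-phase: start a new pass over the data.
                iterator = iter(data_loader)
                batch = next(iterator)
            train_step(mode, batch)
```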

train_step(input, target)

The mechanics of the training loop follow the FedRep paper: https://arxiv.org/pdf/2102.07078.pdf. To reuse the train_step functionality, we switch between the appropriate optimizers depending on the client's training mode (HEAD vs. REPRESENTATION).

Parameters:
  • input (TorchInputType) – input tensor to be run through the model. Here, TorchInputType is simply an alias for the union of torch.Tensor and dict[str, torch.Tensor].

  • target (torch.Tensor) – target tensor to be used to compute a loss given the model’s outputs.

Returns:

The losses object from the train step along with a dictionary of any predictions produced by the model.

Return type:

tuple[TrainingLosses, dict[str, torch.Tensor]]
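The optimizer switch can be sketched as below. The FedRepTrainMode values (“head”, “representation”) coincide with the optimizer keys returned by get_optimizer, so this sketch uses the mode's value to select the optimizer; whether fl4health does exactly this is an assumption. `CountingOptimizer` is a hypothetical stand-in for torch.optim.Optimizer that only counts calls, and `train_step_sketch` is likewise hypothetical.

```python
from enum import Enum
from typing import Callable


class FedRepTrainMode(Enum):
    HEAD = "head"
    REPRESENTATION = "representation"


class CountingOptimizer:
    """Hypothetical stand-in for torch.optim.Optimizer that only counts calls."""

    def __init__(self) -> None:
        self.zero_grad_calls = 0
        self.step_calls = 0

    def zero_grad(self) -> None:
        self.zero_grad_calls += 1

    def step(self) -> None:
        self.step_calls += 1


def train_step_sketch(
    optimizers: dict[str, CountingOptimizer],
    mode: FedRepTrainMode,
    forward_backward: Callable[[], float],
) -> float:
    # The enum values match the optimizer keys, so the mode selects the optimizer.
    optimizer = optimizers[mode.value]
    optimizer.zero_grad()
    loss = forward_backward()  # in PyTorch: compute the loss, then call loss.backward()
    # With real optimizers, each one holds only its own module's parameters,
    # so stepping it leaves the other (frozen) module unchanged.
    optimizer.step()
    return loss
```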

class FedRepTrainMode(value)

Bases: Enum

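Training mode of a FedRep client: HEAD while the head (classifier) is trained with the representation frozen, and REPRESENTATION while the representation (feature extractor) is trained with the head frozen.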

HEAD = 'head'
REPRESENTATION = 'representation'