fl4health.clients.ensemble_client module
- class EnsembleClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)
Bases: BasicClient
- __init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)
This client enables the training of ensemble models in a federated manner.
- Parameters:
data_path (Path) – Path to the data used for client-side training.
metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model
device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often ‘cpu’ or ‘cuda’
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components, will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.
reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.
progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.
client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.
- compute_evaluation_loss(preds, features, target)
Computes the evaluation loss given the predictions (and potentially features) of the model and the ground truth data. Because the ensemble client has more than one model, there are multiple backward losses.
- Parameters:
preds (dict[str, torch.Tensor]) – Prediction(s) of the model(s) indexed by name. Anything stored in preds will be used to compute metrics.
features (dict[str, torch.Tensor]) – Feature(s) of the model(s) indexed by name.
target (torch.Tensor | dict[str, torch.Tensor]) – Ground truth data to evaluate predictions against.
- Returns:
An instance of EvaluationLosses containing the checkpoint loss and additional losses indexed by name.
- Return type:
EvaluationLosses
- compute_training_loss(preds, features, target)
Computes the training loss given the predictions (and potentially features) of the model and the ground truth data. Because the ensemble client has more than one model, there are multiple backward losses.
- Parameters:
preds (dict[str, torch.Tensor]) – Prediction(s) of the model(s) indexed by name. Anything stored in preds will be used to compute metrics.
features (dict[str, torch.Tensor]) – Feature(s) of the model(s) indexed by name.
target (torch.Tensor | dict[str, torch.Tensor]) – Ground truth data to evaluate predictions against.
- Returns:
An instance of TrainingLosses containing the backward loss and additional losses indexed by name.
- Return type:
TrainingLosses
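The idea of keeping one loss per model, indexed by name, can be sketched in plain Python. This is only an illustration under stated assumptions, not FL4Health code: lists of floats stand in for tensors, `mean_squared_error` and `compute_losses_by_name` are hypothetical helpers, and the keys `model_0` and `model_1` are invented.

```python
def mean_squared_error(pred: list[float], target: list[float]) -> float:
    """Average squared difference between two equal-length sequences."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def compute_losses_by_name(
    preds: dict[str, list[float]], target: list[float]
) -> dict[str, float]:
    """Return one loss per entry of preds, keyed by the same name."""
    return {name: mean_squared_error(pred, target) for name, pred in preds.items()}


# Hypothetical predictions from two ensemble members (floats stand in for tensors).
preds = {"model_0": [1.0, 2.0], "model_1": [0.0, 2.0]}
losses = compute_losses_by_name(preds, target=[1.0, 2.0])
```

In a real client each value would be a tensor-valued loss produced by the client's criterion, so that a backward pass can be run on every entry.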
- get_optimizer(config)
Method to be defined by the user that returns a dictionary of optimizers, with keys corresponding to the keys of the models in EnsembleModel that each optimizer applies to.
- Parameters:
config (Config) – The config sent from the server.
- Returns:
An optimizer or dictionary of optimizers to train the model.
- Return type:
- Raises:
NotImplementedError – To be defined in child class.
- set_optimizer(config)
Method called in the setup_client method to set the optimizer attribute returned by the user-defined get_optimizer. Ensures that the return value of get_optimizer is a dictionary, since that is required for the ensemble client.
- Parameters:
config (Config) – The config from the server.
- Return type:
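A check of the kind described for set_optimizer might look like the following minimal sketch. The helper name `ensure_optimizer_dict` and the choice of TypeError are assumptions made for illustration; the library's actual check and error type may differ.

```python
from typing import Any


def ensure_optimizer_dict(optimizers: Any) -> dict[str, Any]:
    """Return optimizers unchanged if it is a dict; otherwise raise TypeError."""
    if not isinstance(optimizers, dict):
        # Hypothetical error type and message, chosen for this sketch only.
        raise TypeError(
            "Expected a dict of optimizers keyed by model name, "
            f"got {type(optimizers).__name__}"
        )
    return optimizers
```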
- setup_client(config)
Set dataloaders, optimizers, parameter exchangers and other attributes derived from these. Then set initialized attribute to True. Also perform some checks to ensure the keys of the optimizer dictionary are consistent with the model keys.
- Parameters:
config (Config) – The config from the server.
- Return type:
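The consistency check between optimizer keys and model keys can be sketched as a set comparison. The function name `check_optimizer_keys`, the ValueError, and the example key names are hypothetical; this is not the library's implementation.

```python
from collections.abc import Iterable


def check_optimizer_keys(model_keys: Iterable[str], optimizer_keys: Iterable[str]) -> None:
    """Raise ValueError unless both collections contain exactly the same keys."""
    model_set, optimizer_set = set(model_keys), set(optimizer_keys)
    if model_set != optimizer_set:
        missing = sorted(model_set - optimizer_set)  # models without an optimizer
        extra = sorted(optimizer_set - model_set)  # optimizers without a model
        raise ValueError(
            f"Optimizer keys do not match model keys (missing: {missing}, extra: {extra})"
        )


# Matching keys pass silently; the key names here are invented.
check_optimizer_keys(["model_0", "model_1"], ["model_1", "model_0"])
```

Comparing sets rather than lists makes the check independent of key order, which is usually what is wanted when both sides are dictionaries.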
- train_step(input, target)
Given a single batch of input and target data, generate predictions (from both the individual models and the ensemble), compute the loss, update parameters, and optionally update metrics if they exist (i.e. backpropagation on a single batch of data). Assumes self.model is already in train mode. Differs from the parent method in that there are multiple losses to run backward passes on and multiple optimizers that update parameters at each train step.
- Parameters:
input (TorchInputType) – The input to be fed into the model. TorchInputType is simply an alias for the union of torch.Tensor and dict[str, torch.Tensor].
target (torch.Tensor) – The target corresponding to the input.
- Returns:
The losses object from the train step along with a dictionary of any predictions produced by the model.
- Return type:
tuple[TrainingLosses, dict[str, torch.Tensor]]
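The control flow implied by "multiple losses" and "multiple optimizers" can be sketched with stand-in objects that only record the calls made on them. Everything here is hypothetical (`RecordingOptimizer`, `RecordingLoss`, `ensemble_train_step`, the key names), and the ordering shown is one plausible arrangement rather than a statement of how FL4Health implements train_step.

```python
class RecordingOptimizer:
    """Stand-in for an optimizer; records calls instead of updating parameters."""

    def __init__(self, name: str, log: list[str]) -> None:
        self.name = name
        self.log = log

    def zero_grad(self) -> None:
        self.log.append(f"{self.name}.zero_grad")

    def step(self) -> None:
        self.log.append(f"{self.name}.step")


class RecordingLoss:
    """Stand-in for a tensor-valued loss; records the backward call."""

    def __init__(self, name: str, log: list[str]) -> None:
        self.name = name
        self.log = log

    def backward(self) -> None:
        self.log.append(f"{self.name}.backward")


def ensemble_train_step(
    losses: dict[str, RecordingLoss], optimizers: dict[str, RecordingOptimizer]
) -> None:
    """Clear gradients, backpropagate every loss, then step every optimizer."""
    for optimizer in optimizers.values():
        optimizer.zero_grad()
    for loss in losses.values():
        loss.backward()
    for optimizer in optimizers.values():
        optimizer.step()


log: list[str] = []
optimizers = {
    "model_0": RecordingOptimizer("opt_0", log),
    "model_1": RecordingOptimizer("opt_1", log),
}
losses = {
    "model_0": RecordingLoss("loss_0", log),
    "model_1": RecordingLoss("loss_1", log),
}
ensemble_train_step(losses, optimizers)
```

After the call, `log` holds the zero_grad calls first, then the backward calls, then the step calls, which makes the per-model fan-out of a single train step easy to inspect.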