fl4health.clients.flexible_client module
- class FlexibleClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)
Bases: BasicClient
- __init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)
Flexible FL Client with functionality to train, evaluate, log, report and checkpoint.
FlexibleClient is similar to BasicClient but provides added flexibility through the ability to inject models and optimizers into the methods responsible for making predictions and performing both train and validation steps.
This added flexibility allows FlexibleClient to be adapted automatically by our personalized methods in fl4health.mixins.personalized.
As with BasicClient, users are responsible for implementing methods:
get_model
get_optimizer
get_data_loaders
get_criterion
However, unlike with BasicClient, users who want to specialize the logic for making predictions and performing train and validation steps should instead override:
predict_with_model
_train_step_with_model_and_optimizer (and its delegated helpers)
_val_step_with_model
Other methods can be overridden to achieve custom functionality.
- Parameters:
data_path (Path) – path to the data to be used to load the data for client-side training
metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model
device (torch.device) – Device indicator for where to send the model, batches, labels, etc. Often “cpu” or “cuda”.
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components, will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.
reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.
progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.
client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.
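The injection pattern described for this class can be illustrated with a minimal, framework-free sketch. It is not FL4Health's implementation: ToyModel, ToySGD, FlexibleClientSketch and TwoModelClientSketch are hypothetical stand-ins, with a scalar model and a hand-written SGD update replacing torch.nn.Module and torch.optim.Optimizer. The point it shows is that, because the step logic receives the model and optimizer as arguments, a personalized variant can reuse it for a second model without rewriting it.

```python
from dataclasses import dataclass


@dataclass
class ToyModel:
    """Hypothetical stand-in for torch.nn.Module: a scalar model y_hat = w * x."""

    w: float

    def __call__(self, x: float) -> float:
        return self.w * x


@dataclass
class ToySGD:
    """Hypothetical stand-in for a torch optimizer bound to a single ToyModel."""

    model: ToyModel
    lr: float = 0.1

    def step(self, grad: float) -> None:
        # Unlike torch.optim, the gradient is passed in explicitly here.
        self.model.w -= self.lr * grad


class FlexibleClientSketch:
    """Illustrative only: the step logic receives the model and optimizer as arguments."""

    def __init__(self, model: ToyModel, optimizer: ToySGD) -> None:
        self.model = model
        self.optimizer = optimizer

    def predict_with_model(self, model: ToyModel, x: float) -> float:
        return model(x)

    def _train_step_with_model_and_optimizer(
        self, model: ToyModel, optimizer: ToySGD, x: float, y: float
    ) -> float:
        pred = self.predict_with_model(model, x)
        loss = (pred - y) ** 2
        grad = 2 * (pred - y) * x  # d(loss)/dw for the squared error of y_hat = w * x
        optimizer.step(grad)
        return loss

    def train_step(self, x: float, y: float) -> float:
        # Default behaviour: inject the client's own model and optimizer.
        return self._train_step_with_model_and_optimizer(self.model, self.optimizer, x, y)


class TwoModelClientSketch(FlexibleClientSketch):
    """Hypothetical personalized variant that also trains a local model."""

    def __init__(
        self, model: ToyModel, optimizer: ToySGD, local_model: ToyModel, local_optimizer: ToySGD
    ) -> None:
        super().__init__(model, optimizer)
        self.local_model = local_model
        self.local_optimizer = local_optimizer

    def train_step(self, x: float, y: float) -> float:
        global_loss = super().train_step(x, y)
        # Reuse the unchanged step logic for the second model/optimizer pair.
        self._train_step_with_model_and_optimizer(self.local_model, self.local_optimizer, x, y)
        return global_loss
```

Because the shared step never reads self.model or self.optimizer directly, the variant only decides which model/optimizer pairs to pass in; that separation is the flexibility the class description refers to.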
- predict_with_model(model, input)
Helper predict method that allows for injection of model.
NOTE: Subclasses should override this method if they need to specialize the prediction logic of the client.
- Parameters:
model (torch.nn.Module) – the model with which to make predictions
input (TorchInputType) – Inputs to be fed into the model. If input is of type dict[str, torch.Tensor], it is assumed that the keys of input match the names of the keyword arguments of model.forward().
- Returns:
A tuple in which the first element contains a dictionary of predictions indexed by name and the second element contains intermediate activations (features) indexed by name. Returning features makes it possible to compute losses such as the contrastive loss in MOON. By default, every prediction included in the dictionary is used to compute metrics separately.
- Return type:
tuple[TorchPredType, TorchFeatureType]
- Raises:
TypeError – Occurs when something other than a tensor or dict of tensors is passed into the model’s forward method.
ValueError – Occurs when something other than a tensor or dict of tensors is returned by the model’s forward method.
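The input and output handling documented for predict_with_model can be sketched as follows. This is an illustrative reading of the documented behaviour, not the library's source: a plain Python list stands in for torch.Tensor, and the key "prediction" used to wrap a bare tensor output is an assumption. Features are returned as an empty dict here; a subclass that needs intermediate activations (for example, for a MOON-style contrastive loss) would populate the second element instead.

```python
from typing import Any, Callable

# Hypothetical stand-in: a plain list plays the role of torch.Tensor in this sketch.
Tensor = list


def predict_with_model_sketch(
    model: Callable[..., Any], input: Any
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Return (predictions, features) following the documented dispatch rules."""
    # A single tensor is passed positionally; a dict is unpacked as keyword
    # arguments, so its keys must match the parameter names of the forward call.
    if isinstance(input, Tensor):
        output = model(input)
    elif isinstance(input, dict):
        output = model(**input)
    else:
        raise TypeError("Model input must be a tensor or a dict of tensors.")

    # A bare tensor output is wrapped under an assumed default key; a dict of
    # predictions is passed through unchanged. No features are produced here.
    if isinstance(output, Tensor):
        return {"prediction": output}, {}
    if isinstance(output, dict):
        return output, {}
    raise ValueError("Model output must be a tensor or a dict of tensors.")
```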
- train_step(input, target)
Given a single batch of input and target data, generate predictions, compute loss, update parameters and optionally update metrics if they exist (i.e. backprop on a single batch of data). Assumes self.model is in train mode already.
- Parameters:
input (TorchInputType) – The input to be fed into the model.
target (TorchTargetType) – The target corresponding to the input.
- Returns:
The losses object from the train step along with a dictionary of any predictions produced by the model.
- Return type:
tuple[TrainingLosses, TorchPredType]
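The order of operations in a single train step, including where the transform_gradients hook fires, can be traced with the stub class below. It is a hypothetical sketch: each method only records its name, whereas a real client would call torch operations such as optimizer.zero_grad(), a backward pass on the loss and optimizer.step(). Clearing gradients first is standard PyTorch practice and is assumed here rather than stated on this page; placing the metric update last follows the order given in the train_step description.

```python
from typing import Any


class TrainStepOrderSketch:
    """Hypothetical stub recording the order of operations in one train step."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def _record(self, name: str) -> None:
        self.calls.append(name)

    def transform_gradients(self, losses: Any) -> None:
        # Hook: gradients exist at this point, but the weights are not yet updated.
        self._record("transform_gradients")

    def train_step(self, input: Any, target: Any) -> list[str]:
        self.calls.clear()
        self._record("zero_grad")          # clear gradients from the previous batch
        self._record("predict")            # generate predictions for this batch
        self._record("compute_loss")       # compare predictions with the target
        self._record("backward")           # populate parameter gradients
        self.transform_gradients(losses=None)
        self._record("optimizer_step")     # apply the (possibly transformed) gradients
        self._record("update_metrics")     # optional, only if metrics exist
        return self.calls
```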
- transform_gradients(losses)
Hook function for model training that is called after the backward pass but before the optimizer step. Useful for transforming the gradients (such as with gradient clipping) before they are applied to the model weights.
- Parameters:
losses (TrainingLosses) – The losses object from the train step
- Return type:
None
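Gradient clipping, the example use of this hook, rescales gradients when their combined L2 norm exceeds a threshold. The function below is a dependency-free sketch of that rule on plain floats; the name clip_by_total_norm and the max_norm parameter are hypothetical. In a PyTorch subclass, an override of transform_gradients would more likely call torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm) directly, whose implementation additionally adds a small epsilon to the norm.

```python
import math


def clip_by_total_norm(grads: list[float], max_norm: float) -> list[float]:
    """Scale all gradients by one factor so their total L2 norm is at most max_norm."""
    if max_norm <= 0:
        raise ValueError("max_norm must be positive")
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)  # already within the limit: leave unchanged
    scale = max_norm / total_norm
    # A single shared factor preserves the direction of the gradient vector.
    return [g * scale for g in grads]
```

Scaling every gradient by the same factor, rather than clipping each value independently, keeps the update direction intact and only shrinks its length.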
- val_step(input, target)
Given input and target, compute the loss and update the loss meter and metrics. Assumes self.model is in eval mode already.
- Parameters:
input (TorchInputType) – The input to be fed into the model.
target (TorchTargetType) – The target corresponding to the input.
- Returns:
The losses object from the val step along with a dictionary of the predictions produced by the model.
- Return type:
tuple[EvaluationLosses, TorchPredType]
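The hypothetical class below illustrates what a validation step does and does not do: it predicts, computes a batch loss and updates running loss and metric state, but never computes gradients or changes the parameter. A scalar model y = w * x, mean squared error as the loss, mean absolute error as the metric and the "prediction" key are all illustrative choices, not FL4Health's. In a real PyTorch client, evaluation would typically also run under torch.no_grad(), with the model already in eval mode as the description above assumes.

```python
class ValStepSketch:
    """Hypothetical validation step: updates loss and metric state, never the weight."""

    def __init__(self, w: float) -> None:
        self.w = w                  # model parameter, frozen during validation
        self.loss_total = 0.0       # running sum of per-batch losses (loss meter state)
        self.num_batches = 0
        self.abs_err_total = 0.0    # running metric state for mean absolute error
        self.num_samples = 0

    def val_step(
        self, inputs: list[float], targets: list[float]
    ) -> tuple[float, dict[str, list[float]]]:
        if not inputs or len(inputs) != len(targets):
            raise ValueError("inputs and targets must be non-empty and equal in length")
        preds = [self.w * x for x in inputs]
        errors = [p - t for p, t in zip(preds, targets)]
        loss = sum(e * e for e in errors) / len(errors)  # batch mean squared error
        # Update the loss meter and metric; no gradient or optimizer step here.
        self.loss_total += loss
        self.num_batches += 1
        self.abs_err_total += sum(abs(e) for e in errors)
        self.num_samples += len(errors)
        return loss, {"prediction": preds}

    def average_loss(self) -> float:
        return self.loss_total / self.num_batches

    def mean_abs_error(self) -> float:
        return self.abs_err_total / self.num_samples
```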