fl4health.clients.adaptive_drift_constraint_client module¶
- class AdaptiveDriftConstraintClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)[source]¶
Bases:
BasicClient
- __init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)[source]¶
This client serves as a base for FL methods implementing an auxiliary loss penalty with a weight coefficient that might be adapted via loss trajectories on the server side. An example of such a method is FedProx, which uses an auxiliary loss penalizing weight drift with a coefficient mu. This client is a simple extension of the BasicClient that packs self.loss_for_adaptation for exchange with the server and expects to receive an updated (or constant, if non-adaptive) parameter for the loss weight. In many cases, such as FedProx, the loss_for_adaptation being packaged is the criterion loss (i.e. the loss without the penalty).
- Parameters:
data_path (Path) – Path to the data to be used for client-side training
metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model
device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often ‘cpu’ or ‘cuda’
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.
reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.
progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.
client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.
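For illustration, a minimal sketch of subclassing this client, assuming a standard fl4health install and the BasicClient hook names get_model and get_criterion (both are assumptions here; verify against your installed version):

```python
from pathlib import Path

import torch
import torch.nn as nn

# Import path taken from this module's name; verify against your install.
from fl4health.clients.adaptive_drift_constraint_client import AdaptiveDriftConstraintClient


class MyDriftConstraintClient(AdaptiveDriftConstraintClient):
    # BasicClient-style hooks; these signatures are assumptions.
    def get_model(self, config):
        return nn.Linear(10, 2).to(self.device)

    def get_criterion(self, config):
        return nn.CrossEntropyLoss()


client = MyDriftConstraintClient(
    data_path=Path("examples/data/client_0"),
    metrics=[],
    device=torch.device("cpu"),
)
```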
- compute_penalty_loss()[source]¶
Computes the drift loss for the client model and drift tensors.
- Returns:
Computed penalty loss tensor
- Return type:
torch.Tensor
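As an illustration of such a drift penalty (not necessarily this method's exact implementation), a FedProx-style sketch computes (mu / 2) * ||w - w0||^2 between the current weights and the initial (server) weights:

```python
import torch


def fedprox_style_penalty(current_weights, initial_weights, penalty_weight):
    """Squared L2 drift between current and initial (server) weights."""
    drift = torch.tensor(0.0)
    for w, w0 in zip(current_weights, initial_weights):
        drift = drift + torch.sum((w - w0) ** 2)
    # FedProx scales the squared norm by mu / 2.
    return (penalty_weight / 2.0) * drift
```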
- compute_training_loss(preds, features, target)[source]¶
Computes training loss given predictions of the model and ground truth data. The penalty loss is added to the objective.
- Parameters:
preds (TorchPredType) – Prediction(s) of the model(s) indexed by name. All predictions included in dictionary will be used to compute metrics.
features (TorchFeatureType) – Feature(s) of the model(s) indexed by name.
target (TorchTargetType) – Ground truth data to evaluate predictions against.
- Returns:
An instance of TrainingLosses containing the backward loss and additional losses indexed by name. Additional losses include the penalty loss.
- Return type:
TrainingLosses
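A hedged sketch of how this composition might look; the helper name and dictionary keys below are illustrative assumptions, not the library's API:

```python
import torch


def compose_training_loss(criterion_loss: torch.Tensor, penalty_loss: torch.Tensor):
    # The loss that is backpropagated includes the penalty term.
    backward_loss = criterion_loss + penalty_loss
    # The penalty-free criterion loss is what gets packed for server-side adaptation.
    additional_losses = {
        "penalty_loss": penalty_loss,
        "loss_for_adaptation": criterion_loss,
    }
    return backward_loss, additional_losses
```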
- get_parameter_exchanger(config)[source]¶
Setting up the parameter exchanger to include the appropriate packing functionality. By default, we assume that we're exchanging all parameters. Can be overridden for other behavior.
- Parameters:
config (Config) – The config is sent by the FL server to allow for customization in the function if desired.
- Returns:
Exchanger that can handle packing/unpacking auxiliary server information.
- Return type:
ParameterExchanger
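A hypothetical override sketch; the exchanger and packer class names and import paths below follow fl4health naming conventions but are assumptions that should be verified against your installed version:

```python
# Hypothetical import paths -- verify against your installed fl4health version.
from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint


def get_parameter_exchanger(self, config):
    # Pack all model parameters along with the scalar loss used for adaptation.
    return FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint())
```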
- get_parameters(config)[source]¶
Packs the parameters and training loss into a single NDArrays object to be sent to the server for aggregation. If the client has not been initialized, the server is requesting parameters for initialization, and only the model parameters are sent. When using the FedAvgWithAdaptiveConstraint strategy, this should not happen, as that strategy requires server-side initialization parameters. However, other strategies may handle this case.
- Parameters:
config (Config) – Configurations to allow for customization of this function's behavior
- Returns:
Parameters and training loss packed together into a list of numpy arrays to be sent to the server
- Return type:
NDArrays
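Illustrative only, not the library's exact packing scheme: one plausible way to pack model parameters together with a scalar loss into a single list of numpy arrays is to append the loss as a final entry:

```python
import numpy as np


def pack_parameters_with_loss(model_ndarrays, loss_for_adaptation):
    # Append the scalar loss as an extra array at the end of the list.
    return model_ndarrays + [np.array(loss_for_adaptation)]


packed = pack_parameters_with_loss([np.zeros((3, 3)), np.ones(3)], 0.137)
```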
- set_parameters(parameters, config, fitting_round)[source]¶
Assumes that the parameters being passed contain model parameters concatenated with a penalty weight. They are unpacked for the clients to use in training. In the first fitting round, we assume the full model is being initialized and use the FullParameterExchanger() to set all model weights.
- Parameters:
parameters (NDArrays) – Parameters have information about the model state to be added to the relevant client model and also the penalty weight to be applied during training.
config (Config) – The config is sent by the FL server to allow for customization in the function if desired.
fitting_round (bool) – Boolean that indicates whether the current federated learning round is a fitting round or an evaluation round. This is used to help determine which parameter exchange should be used for pulling parameters. A full parameter exchanger is always used if the current federated learning round is the very first fitting round.
- Return type:
None
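A hedged mirror of the unpacking described above; the helper name is hypothetical, and the layout (penalty weight as the last entry) is an assumption matching the packing sketch shown earlier:

```python
import numpy as np


def unpack_parameters_and_penalty_weight(packed_ndarrays):
    # Last entry is assumed to carry the penalty weight; the rest are model weights.
    *model_ndarrays, weight_array = packed_ndarrays
    return model_ndarrays, float(weight_array)


weights, mu = unpack_parameters_and_penalty_weight([np.zeros((3, 3)), np.array(0.1)])
```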