fl4health.mixins package
- class AdaptiveDriftConstrainedMixin(*args, **kwargs)
Bases: object
- __init__(*args, **kwargs)
Adaptive Drift Constrained Mixin
To be used with fl4health.BaseClient to add the ability to compute losses with an adaptive drift constraint.
- Raises:
RuntimeError – when the inheriting class does not satisfy BasicClientProtocolPreSetup.
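The constructor check described above can be sketched with `typing.Protocol` and `runtime_checkable`. This is a minimal illustration under stated assumptions, not the fl4health implementation: `PreSetupProtocol`, `DriftMixin`, `ToyClient`, `DriftClient`, and `BrokenClient` are hypothetical names, and the protocol models only two of the members listed for `BasicClientProtocolPreSetup` (`initialized` and `setup_client`).

```python
from typing import Protocol, runtime_checkable


# Hypothetical stand-in for BasicClientProtocolPreSetup: only two of its
# listed members are modelled here.
@runtime_checkable
class PreSetupProtocol(Protocol):
    initialized: bool

    def setup_client(self, config: dict) -> None: ...


class DriftMixin:
    """Illustrative mixin that validates its host class at construction time."""

    def __init__(self, *args, **kwargs) -> None:
        # Cooperative init: forward arguments to the next class in the MRO.
        super().__init__(*args, **kwargs)
        if not isinstance(self, PreSetupProtocol):
            raise RuntimeError("This object needs to satisfy `PreSetupProtocol`.")


class ToyClient:
    """Hypothetical client exposing the members the protocol asks for."""

    def __init__(self) -> None:
        self.initialized = False

    def setup_client(self, config: dict) -> None:
        self.initialized = True


class DriftClient(DriftMixin, ToyClient):
    """Mixin listed first so its __init__ runs, then defers to ToyClient."""


class BrokenClient(DriftMixin):
    """Lacks `initialized` and `setup_client`, so construction fails."""
```

The mixin is listed before the client class so that its `__init__` runs first in the method resolution order and reaches the client's constructor through `super()`. A `runtime_checkable` protocol only checks that the members exist, not their signatures.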
- compute_penalty_loss()
Computes the drift loss for the client model and drift tensors.
- Returns:
Computed penalty loss tensor
- Return type:
torch.Tensor
- compute_training_loss(preds, features, target)
Computes the training loss given predictions of the model and ground truth data. The penalty loss is added to the objective.
- Parameters:
preds (TorchPredType) – Prediction(s) of the model(s) indexed by name. All predictions included in the dictionary will be used to compute metrics.
features (dict[str, Tensor]) – (TorchFeatureType) Feature(s) of the model(s) indexed by name.
target (Tensor | dict[str, Tensor]) – (TorchTargetType) Ground truth data to evaluate predictions against.
- Returns:
An instance of TrainingLosses containing the backward loss and additional losses indexed by name. The additional losses include the penalty loss.
- Return type:
TrainingLosses
- get_parameter_exchanger(config)
Sets up the parameter exchanger to include the appropriate packing functionality. By default, all parameters are assumed to be exchanged. This can be overridden for other behavior.
- Parameters:
config (Config) – The config is sent by the FL server to allow for customization in the function if desired.
- Returns:
Exchanger that can handle packing/unpacking auxiliary server information.
- Return type:
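One way to picture the packing functionality is a packer that appends a single auxiliary value (for example, the penalty weight sent by the server, or the training loss sent back by the client) to the list of model arrays and strips it off again on receipt. `ToyParameterPacker` and its methods are hypothetical names for this sketch, not fl4health classes, and nested lists stand in for NDArrays.

```python
class ToyParameterPacker:
    """Hypothetical packer: model arrays plus one trailing auxiliary value."""

    def pack_parameters(
        self, model_weights: list[list[float]], extra_value: float
    ) -> list[list[float]]:
        # Copy each layer and append the auxiliary value as a one-element array.
        return [list(layer) for layer in model_weights] + [[extra_value]]

    def unpack_parameters(
        self, packed: list[list[float]]
    ) -> tuple[list[list[float]], float]:
        # The last array must hold exactly one value: the auxiliary information.
        if not packed or len(packed[-1]) != 1:
            raise ValueError("Expected a trailing one-element array.")
        return [list(layer) for layer in packed[:-1]], packed[-1][0]


packer = ToyParameterPacker()
payload = packer.pack_parameters([[0.5, -1.0], [2.0]], extra_value=0.01)
weights, penalty_weight = packer.unpack_parameters(payload)
```

Appending the value at the end keeps the indices of the model layers unchanged, so the receiver can split the payload without extra metadata.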
- get_parameters(config)
Packs the parameters and training loss into a single NDArrays to be sent to the server for aggregation. If the client has not been initialized, the server is requesting parameters for initialization, so only the model parameters are sent. When using the FedAvgWithAdaptiveConstraint strategy, this should not happen, as that strategy requires server-side initialization parameters. However, other strategies may handle this case.
- Parameters:
config (Config) – Configurations to allow for customization of this function's behavior.
- Returns:
Parameters and training loss packed together into a list of numpy arrays to be sent to the server
- Return type:
NDArrays
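The branching in `get_parameters` can be sketched as a standalone function. It assumes, hypothetically, that the training loss is appended as a trailing one-element array after the model arrays, presumably so the server can adapt the penalty weight; nested lists stand in for NDArrays, and the signature is illustrative rather than the library's.

```python
def get_parameters(
    model_weights: list[list[float]],
    initialized: bool,
    training_loss: float,
) -> list[list[float]]:
    """Sketch: return the model arrays, plus the training loss once initialized."""
    model_arrays = [list(layer) for layer in model_weights]
    if not initialized:
        # The server is asking for initialization parameters: send the model only.
        return model_arrays
    # Normal case: append the loss (assumed to let the server adapt the penalty weight).
    return model_arrays + [[training_loss]]
```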
- set_parameters(parameters, config, fitting_round)
Assumes that the parameters being passed contain model parameters concatenated with a penalty weight. They are unpacked for the client to use in training. In the first fitting round, the full model is assumed to be initialized, and the FullParameterExchanger() is used to set all model weights.
- Parameters:
parameters (NDArrays) – Parameters have information about model state to be added to the relevant client model and also the penalty weight to be applied during training.
config (Config) – The config is sent by the FL server to allow for customization in the function if desired.
fitting_round (bool) – Boolean that indicates whether the current federated learning round is a fitting round or an evaluation round. This is used to help determine which parameter exchange should be used for pulling parameters. A full parameter exchanger is always used if the current federated learning round is the very first fitting round.
- Return type:
None
Submodules
- fl4health.mixins.adaptive_drift_constrained module
AdaptiveDriftConstrainedMixin
AdaptiveDriftConstrainedMixin.__init__()
AdaptiveDriftConstrainedMixin.compute_penalty_loss()
AdaptiveDriftConstrainedMixin.compute_training_loss()
AdaptiveDriftConstrainedMixin.get_parameter_exchanger()
AdaptiveDriftConstrainedMixin.get_parameters()
AdaptiveDriftConstrainedMixin.set_parameters()
AdaptiveDriftConstrainedMixin.update_after_train()
AdaptiveDriftConstrainedProtocol
AdaptiveDriftConstrainedProtocol.compute_penalty_loss()
AdaptiveDriftConstrainedProtocol.drift_penalty_tensors
AdaptiveDriftConstrainedProtocol.drift_penalty_weight
AdaptiveDriftConstrainedProtocol.loss_for_adaptation
AdaptiveDriftConstrainedProtocol.parameter_exchanger
AdaptiveDriftConstrainedProtocol.penalty_loss_function
apply_adaptive_drift_to_client()
- fl4health.mixins.core_protocols module
BasicClientProtocol
BasicClientProtocolPreSetup
BasicClientProtocolPreSetup.compute_loss_and_additional_losses()
BasicClientProtocolPreSetup.device
BasicClientProtocolPreSetup.get_criterion()
BasicClientProtocolPreSetup.get_data_loaders()
BasicClientProtocolPreSetup.get_model()
BasicClientProtocolPreSetup.get_optimizer()
BasicClientProtocolPreSetup.initialized
BasicClientProtocolPreSetup.setup_client()
NumPyClientMinimalProtocol