fl4health.mixins.personalized package¶
- class DittoPersonalizedMixin(*args, **kwargs)[source]¶
Bases: AdaptiveDriftConstrainedMixin
- __init__(*args, **kwargs)[source]¶
This mixin implements the Ditto algorithm from Ditto: Fair and Robust Federated Learning Through Personalization. This mixin inherits from the AdaptiveDriftConstrainedMixin and, like that mixin, it should be mixed with a FlexibleClient type in order to apply the Ditto personalization method to that client (see the usage sketch below).
Background Context: The idea is that we want to train personalized versions of the global model for each client. So we simultaneously train a global model that is aggregated on the server side and use those weights to also constrain the training of a local model. The constraint for this local model is identical to the FedProx loss.
- Raises:
RuntimeError – If the object does not satisfy the FlexibleClientProtocolPreSetup. This is additional validation to ensure that the mixin was applied to an appropriate base class.
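A minimal usage sketch, assuming `MyFlexibleClient` is a user-defined FlexibleClient subclass (the name is a hypothetical placeholder) and that the mixin is importable from this package:

```python
from fl4health.mixins.personalized import DittoPersonalizedMixin

# MyFlexibleClient stands in for any concrete FlexibleClient subclass you
# have already defined; the name here is a hypothetical placeholder.
class MyDittoClient(DittoPersonalizedMixin, MyFlexibleClient):
    # No extra code is required: the mixin supplies the Ditto behavior.
    pass
```

Note that the mixin comes first in the base-class list so that its method overrides take precedence in the MRO.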
- compute_evaluation_loss(preds, features, target)[source]¶
Computes evaluation loss given predictions (and potentially features) of the model and ground truth data. For Ditto, we use the vanilla loss for the local model in checkpointing. However, during validation we also compute the global model vanilla loss.
- Parameters:
preds (TorchPredType) – Prediction(s) of the model(s) indexed by name. Anything stored in preds will be used to compute metrics.
features (TorchFeatureType) – Feature(s) of the model(s) indexed by name.
target (TorchTargetType) – Ground truth data to evaluate predictions against.
- Returns:
An instance of EvaluationLosses containing checkpoint loss and additional losses indexed by name.
- Return type:
EvaluationLosses
- get_global_model(config)[source]¶
Returns the global model to be used during Ditto training and as a constraint for the local model. The global model should be the same architecture as the local model, so we reuse the get_model call. We explicitly send the model to the desired device. This is idempotent.
- Parameters:
config (Config) – The config from the server.
- Returns:
The PyTorch model serving as the global model for Ditto
- Return type:
nn.Module
- get_optimizer(config)[source]¶
Returns a dictionary with global and local optimizers with string keys “global” and “local” respectively.
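A sketch of a conforming get_optimizer override; the optimizer choice and learning rate are placeholders, while self.model and self.global_model are the documented protocol attributes:

```python
import torch

def get_optimizer(self, config):
    # One optimizer per model, keyed exactly "global" and "local" as the
    # Ditto mixin expects. SGD and lr=0.01 are illustrative choices only.
    return {
        "global": torch.optim.SGD(self.global_model.parameters(), lr=0.01),
        "local": torch.optim.SGD(self.model.parameters(), lr=0.01),
    }
```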
- get_parameters(config)[source]¶
For Ditto, we transfer the GLOBAL model weights to the server to be aggregated. The local model weights stay with the client.
- Parameters:
config (Config) – The config is sent by the FL server to allow for customization in the function if desired.
- Returns:
GLOBAL model weights to be sent to the server for aggregation.
- Return type:
NDArrays
- initialize_all_model_weights(parameters, config)[source]¶
If this is the first time we’re initializing the model weights, we initialize both the global and the local weights together.
- Parameters:
parameters (NDArrays) – Model parameters to be injected into the client model.
config (Config) – The config is sent by the FL server to allow for customization in the function if desired.
- Return type:
None
- safe_global_model()[source]¶
Convenience accessor for the global model.
- Raises:
ValueError – If the global_model attribute has not yet been set.
- Returns:
The global model, if it has been set.
- Return type:
nn.Module
- set_initial_global_tensors()[source]¶
Saves the initial GLOBAL model weights and detaches them so that we do not compute gradients with respect to these tensors. These are used to form the Ditto local update penalty term (see the sketch below).
- Return type:
None
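A conceptual sketch of the snapshot described above; drift_penalty_tensors is the documented protocol attribute, though the library's actual implementation may differ in detail:

```python
def set_initial_global_tensors(self):
    # Clone and detach the GLOBAL model's weights so that the Ditto penalty
    # term references a fixed snapshot outside the autograd graph.
    self.drift_penalty_tensors = [
        p.detach().clone() for p in self.global_model.parameters()
    ]
```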
- set_optimizer(config)[source]¶
Ditto requires an optimizer for the global model and one for the local model. This function simply ensures that the optimizers set up by the user have the proper keys and that there are two optimizers.
- Parameters:
config (Config) – The config from the server.
- Return type:
None
- set_parameters(parameters, config, fitting_round)[source]¶
Assumes that the parameters being passed contain model parameters concatenated with a penalty weight. They are unpacked for the clients to use in training (see the sketch below). The parameters being passed are to be routed to the global model. In the first fitting round, we assume that both the global and local models are being initialized and use the FullParameterExchanger() to initialize both sets of model weights to the same parameters.
- Parameters:
parameters (NDArrays) – Parameters have information about model state to be added to the relevant client model (global model for all but the first step of Ditto). These should also include a penalty weight from the server that needs to be unpacked.
config (Config) – The config is sent by the FL server to allow for customization in the function if desired.
fitting_round (bool) – Boolean that indicates whether the current federated learning round is a fitting round or an evaluation round. This is used to help determine which parameter exchange should be used for pulling parameters. If the current federated learning round is the very first fitting round, then we initialize both the global and local Ditto models with weights sent from the server.
- Return type:
None
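A conceptual sketch of the unpacking described above, assuming the payload lays out the global model's weight arrays with the scalar penalty weight appended last (the library's actual parameter packer may organize this differently):

```python
import numpy as np

def unpack_with_penalty(parameters):
    # Hypothetical helper: split the server payload into the global model's
    # weight arrays and the trailing scalar drift penalty weight.
    model_weights = parameters[:-1]
    penalty_weight = float(np.asarray(parameters[-1]).item())
    return model_weights, penalty_weight
```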
- setup_client(config)[source]¶
Set dataloaders, optimizers, parameter exchangers and other attributes derived from these. Then set initialized attribute to True. In this class, this function simply adds the additional step of setting up the global model.
- Parameters:
config (Config) – The config from the server.
- Return type:
None
- train_step(input, target)[source]¶
Mechanics of training loop follow from original Ditto implementation: https://github.com/litian96/ditto.
As in the implementation there, steps of the global and local models are done in tandem and for the same number of steps.
- Parameters:
input (TorchInputType) – Input tensor to be run through both the global and local models. Here, TorchInputType is simply an alias for the union of torch.Tensor and dict[str, torch.Tensor].
target (TorchTargetType) – Target tensor to be used to compute a loss given each model's outputs.
- Returns:
Returns relevant loss values from both the global and local model optimization steps. The prediction dictionary contains predictions indexed by “global” and “local”, corresponding to predictions from the global and local Ditto models, for metric evaluations.
- Return type:
tuple[TrainingLosses, TorchPredType]
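A conceptual sketch of the tandem update, using only attributes from the documented DittoPersonalizedProtocol. The penalty_loss_function call signature is an assumption, and the real train_step additionally assembles the TrainingLosses object and the prediction dictionary it returns:

```python
def train_step(self, input, target):
    # Global model: vanilla step on the local data.
    global_preds = self.global_model(input)
    global_loss = self.criterion(global_preds, target)

    # Local model: vanilla loss plus the drift penalty toward the snapshot of
    # the initial global weights (penalty_loss_function signature assumed).
    local_preds = self.model(input)
    penalty = self.penalty_loss_function(
        self.model, self.drift_penalty_tensors, self.drift_penalty_weight
    )
    local_loss = self.criterion(local_preds, target) + penalty

    # Step both optimizers in tandem, as in the reference implementation.
    for key, loss in (("global", global_loss), ("local", local_loss)):
        self.optimizers[key].zero_grad()
        loss.backward()
        self.optimizers[key].step()
```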
- class MrMtlPersonalizedMixin(*args, **kwargs)[source]¶
Bases: AdaptiveDriftConstrainedMixin
- __init__(*args, **kwargs)[source]¶
This mixin implements the MR-MTL algorithm from MR-MTL: On Privacy and Personalization in Cross-Silo Federated Learning. The idea is that we want to train personalized versions of the global model for each client. However, instead of using a separate solver for the global model, as in Ditto, we update the initial global model with aggregated local models on the server side and use those weights to also constrain the training of a local model. The constraint for this local model is identical to the FedProx loss. The key difference is that the local model is never replaced with aggregated weights. It is always local.
NOTE: lambda, the drift loss weight, is initially set and potentially adapted by the server, akin to the heuristic suggested in the original FedProx paper. Adaptation is optional and can be disabled in the corresponding strategy used by the server.
- compute_training_loss(preds, features, target)[source]¶
Computes training losses given predictions of the models and ground truth data. We add to the vanilla loss function by including the Mean Regularized (MR) penalty loss, which is the \(\ell^2\) inner product between the initial global model weights and the weights of the current model (see the worked form below).
- Parameters:
preds (TorchPredType) – Prediction(s) of the model(s) indexed by name. All predictions included in dictionary will be used to compute metrics.
features (TorchFeatureType) – Feature(s) of the model(s) indexed by name.
target (TorchTargetType) – Ground truth data to evaluate predictions against.
- Returns:
An instance of TrainingLosses containing backward loss and additional losses indexed by name. Additional losses include each loss component of the total loss.
- Return type:
TrainingLosses
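One common way to write the FedProx-style penalty referenced above, assuming \(\lambda\) denotes the drift penalty weight and \(W^{0}\) the initial global model weights:

```latex
\mathcal{L}_{\text{train}}(W) \;=\; \mathcal{L}_{\text{vanilla}}(W)
\;+\; \frac{\lambda}{2}\,\lVert W - W^{0} \rVert_2^2
```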
- get_global_model(config)[source]¶
Returns the global model on client setup to be used as a constraint for the local model during training.
The global model should be the same architecture as the local model, so we reuse the get_model call. We explicitly send the model to the desired device. This is idempotent.
- Parameters:
config (Config) – The config from the server.
- Returns:
The PyTorch model serving as the initial global model for MR-MTL
- Return type:
nn.Module
- get_optimizer(config)[source]¶
Implements get_optimizer as a hook to set the initial global model if it is not already set.
- set_parameters(parameters, config, fitting_round)[source]¶
The parameters being passed are to be routed to the initial global model, which is used in a penalty term when training the local model. Unlike the usual FL setup, we never pass the aggregated model to the LOCAL model. Instead, we use the aggregated model to form the MR-MTL penalty term.
NOTE: In MR-MTL, unlike Ditto, the local model weights are not synced across clients to the initial global model, even in the FIRST ROUND.
- Parameters:
parameters (NDArrays) – Parameters have information about model state to be added to the relevant client model. They will also contain a penalty weight from the server at each round (possibly adapted).
config (Config) – The config is sent by the FL server to allow for customization in the function if desired.
fitting_round (bool) – Boolean that indicates whether the current federated learning round is a fitting round or an evaluation round. Not used here.
- Return type:
None
- setup_client(config)[source]¶
Set dataloaders, optimizers, parameter exchangers and other attributes derived from these. Then set initialized attribute to True. In this class, this function simply adds the additional step of setting up the global model.
- Parameters:
config (Config) – The config from the server.
- Return type:
None
- class PersonalizedMode(value)[source]¶
Bases: Enum
An enumeration of the supported personalization modes.
- DITTO = 'ditto'¶
- MR_MTL = 'mr_mtl'¶
- make_it_personal(client_base_type, mode)[source]¶
A mixin class factory for converting basic clients into personalized versions (see the usage sketch below).
- Return type:
type
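A usage sketch, assuming `MyFlexibleClient` is a compatible client base class (a hypothetical placeholder) and that the factory and enum are importable from this package:

```python
from fl4health.mixins.personalized import PersonalizedMode, make_it_personal

# MyFlexibleClient is a hypothetical, user-defined client base class.
DittoClient = make_it_personal(MyFlexibleClient, PersonalizedMode.DITTO)
MrMtlClient = make_it_personal(MyFlexibleClient, PersonalizedMode.MR_MTL)
```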
Submodules¶
- fl4health.mixins.personalized.ditto module
DittoPersonalizedMixin
DittoPersonalizedMixin.__init__()
DittoPersonalizedMixin.compute_evaluation_loss()
DittoPersonalizedMixin.get_global_model()
DittoPersonalizedMixin.get_optimizer()
DittoPersonalizedMixin.get_parameters()
DittoPersonalizedMixin.initialize_all_model_weights()
DittoPersonalizedMixin.optimizer_keys
DittoPersonalizedMixin.safe_global_model()
DittoPersonalizedMixin.set_initial_global_tensors()
DittoPersonalizedMixin.set_optimizer()
DittoPersonalizedMixin.set_parameters()
DittoPersonalizedMixin.setup_client()
DittoPersonalizedMixin.train_step()
DittoPersonalizedMixin.update_before_train()
DittoPersonalizedMixin.val_step()
DittoPersonalizedMixin.validate()
DittoPersonalizedProtocol
DittoPersonalizedProtocol.criterion
DittoPersonalizedProtocol.device
DittoPersonalizedProtocol.drift_penalty_tensors
DittoPersonalizedProtocol.drift_penalty_weight
DittoPersonalizedProtocol.get_global_model()
DittoPersonalizedProtocol.global_model
DittoPersonalizedProtocol.initialized
DittoPersonalizedProtocol.loss_for_adaptation
DittoPersonalizedProtocol.model
DittoPersonalizedProtocol.optimizer_keys
DittoPersonalizedProtocol.optimizers
DittoPersonalizedProtocol.parameter_exchanger
DittoPersonalizedProtocol.penalty_loss_function
DittoPersonalizedProtocol.safe_global_model()
DittoPersonalizedProtocol.set_initial_global_tensors()
DittoPersonalizedProtocol.test_loader
DittoPersonalizedProtocol.train_loader
DittoPersonalizedProtocol.val_loader
- fl4health.mixins.personalized.mr_mtl module
MrMtlPersonalizedMixin
MrMtlPersonalizedMixin.__init__()
MrMtlPersonalizedMixin.compute_training_loss()
MrMtlPersonalizedMixin.get_global_model()
MrMtlPersonalizedMixin.get_optimizer()
MrMtlPersonalizedMixin.set_parameters()
MrMtlPersonalizedMixin.setup_client()
MrMtlPersonalizedMixin.update_before_train()
MrMtlPersonalizedMixin.validate()
MrMtlPersonalizedProtocol
MrMtlPersonalizedProtocol.criterion
MrMtlPersonalizedProtocol.device
MrMtlPersonalizedProtocol.drift_penalty_tensors
MrMtlPersonalizedProtocol.drift_penalty_weight
MrMtlPersonalizedProtocol.get_global_model()
MrMtlPersonalizedProtocol.initial_global_model
MrMtlPersonalizedProtocol.initial_global_tensors
MrMtlPersonalizedProtocol.initialized
MrMtlPersonalizedProtocol.loss_for_adaptation
MrMtlPersonalizedProtocol.model
MrMtlPersonalizedProtocol.optimizers
MrMtlPersonalizedProtocol.parameter_exchanger
MrMtlPersonalizedProtocol.penalty_loss_function
MrMtlPersonalizedProtocol.test_loader
MrMtlPersonalizedProtocol.train_loader
MrMtlPersonalizedProtocol.val_loader
- fl4health.mixins.personalized.utils module