fl4health.model_bases.apfl_base module¶
- class ApflModule(model, adaptive_alpha=True, alpha=0.5, alpha_lr=0.01)[source]¶
Bases: PartialLayerExchangeModel
- __init__(model, adaptive_alpha=True, alpha=0.5, alpha_lr=0.01)[source]¶
Defines a model compatible with the APFL approach.
- Parameters:
model (nn.Module) – The underlying model architecture to be optimized. A twin of this model will be created to initialize a local and global version of this architecture.
adaptive_alpha (bool, optional) – Whether or not the mixing parameter \(\alpha\) will be adapted during training. Predictions of the local and global models are combined using \(\alpha\) to provide a final prediction. Defaults to True.
alpha (float, optional) – The initial value for the mixing parameter \(\alpha\). Defaults to 0.5.
alpha_lr (float, optional) – The learning rate applied when adapting \(\alpha\) during training. If adaptive_alpha is False, this parameter has no effect. Defaults to 0.01.
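The parameter descriptions above do not spell out how \(\alpha\) is adapted. As a rough illustration only, the sketch below assumes a plain gradient step scaled by alpha_lr, followed by clipping to [0, 1] so the mixture stays a convex combination. The function name step_alpha and the grad_alpha argument are hypothetical and not part of the documented API; how the gradient with respect to \(\alpha\) is obtained is left out.

```python
def step_alpha(alpha, grad_alpha, alpha_lr=0.01, adaptive_alpha=True):
    """Hypothetical helper: one gradient step on the mixing parameter.

    Assumptions (not confirmed by the documentation above):
    - the update is plain gradient descent with step size alpha_lr
    - the result is clipped to the interval [0, 1]
    """
    if not adaptive_alpha:
        # With adaptation disabled, alpha_lr has no effect and alpha is fixed.
        return alpha
    updated = alpha - alpha_lr * grad_alpha
    # Keep alpha a valid mixing weight.
    return min(max(updated, 0.0), 1.0)


# A large positive gradient pushes alpha down until it hits the lower bound.
print(step_alpha(0.5, grad_alpha=1000.0))
# With adaptive_alpha=False the initial value is returned unchanged.
print(step_alpha(0.5, grad_alpha=1000.0, adaptive_alpha=False))
```

Clipping is a design choice in this sketch: without it, a large step could move \(\alpha\) outside [0, 1] and the combined prediction would no longer be a weighted average of the two models.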
- forward(input)[source]¶
Forward function for the full APFL model. This includes mixing of the global and local model predictions using \(\alpha\). The predictions are combined as
\[\alpha \cdot \text{local_logits} + (1.0 - \alpha) \cdot \text{global_logits}\]
- Parameters:
input (torch.Tensor) – Input tensor to be run through both the local and global models
- Returns:
Final prediction after mixing the predictions produced by the local and global models. The returned dictionary stores the mixed prediction under the key “personal,” while the global and local model predictions are stored under the keys “global” and “local,” respectively.
- Return type:
dict[str, torch.Tensor]
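The mixing step described for forward can be sketched in plain Python, with lists of floats standing in for tensors. This is a minimal illustration of the formula and the dictionary layout, not the library's implementation; the function name mix_predictions is hypothetical.

```python
def mix_predictions(local_logits, global_logits, alpha):
    """Hypothetical helper mirroring the documented mixing formula.

    personal = alpha * local_logits + (1.0 - alpha) * global_logits
    """
    if len(local_logits) != len(global_logits):
        raise ValueError("local and global logits must have the same length")
    personal = [
        alpha * loc + (1.0 - alpha) * glob
        for loc, glob in zip(local_logits, global_logits)
    ]
    # Same three keys as described for the return value of forward.
    return {
        "personal": personal,
        "global": list(global_logits),
        "local": list(local_logits),
    }


preds = mix_predictions([2.0, 0.0], [0.0, 4.0], alpha=0.5)
print(preds["personal"])  # [1.0, 2.0]
```

With \(\alpha = 1\) the personal prediction equals the local model's output, and with \(\alpha = 0\) it equals the global model's output.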
- global_forward(input)[source]¶
Forward function that runs the input tensor through the GLOBAL model only
- Parameters:
input (torch.Tensor) – tensor to be run through the global model
- Returns:
output from the global model only.
- Return type:
torch.Tensor
- layers_to_exchange()[source]¶
Specifies the model layers to be exchanged with the server. These are a fixed set of layers exchanged every round. For APFL, these are any layers associated with the global_model. That is, none of the parameters of the local model are aggregated on the server side, nor is \(\alpha\).
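One way to picture this selection is as a filter over parameter names. The sketch below assumes that names belonging to the global model share the prefix "global_model." (as PyTorch state-dict keys for a submodule attribute of that name would); the helper select_global_layers and the example layer names are hypothetical.

```python
def select_global_layers(layer_names, prefix="global_model."):
    """Hypothetical helper: keep only names that belong to the global model."""
    return [name for name in layer_names if name.startswith(prefix)]


# Illustrative names only; a real model's keys depend on its architecture.
layer_names = [
    "global_model.fc.weight",
    "global_model.fc.bias",
    "local_model.fc.weight",
    "local_model.fc.bias",
]

print(select_global_layers(layer_names))
# ['global_model.fc.weight', 'global_model.fc.bias']
```

Because the filter depends only on names, the same set of layers is selected every round, which matches the fixed exchange described above.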
- local_forward(input)[source]¶
Forward function that runs the input tensor through the LOCAL model only
- Parameters:
input (torch.Tensor) – tensor to be run through the local model
- Returns:
output from the local model only.
- Return type:
torch.Tensor