fl4health.model_bases.apfl_base module

class ApflModule(model, adaptive_alpha=True, alpha=0.5, alpha_lr=0.01)[source]

Bases: PartialLayerExchangeModel

__init__(model, adaptive_alpha=True, alpha=0.5, alpha_lr=0.01)[source]

Defines a model compatible with the APFL approach.

Parameters:
  • model (nn.Module) – The underlying model architecture to be optimized. A twin of this model will be created to initialize a local and global version of this architecture.

  • adaptive_alpha (bool, optional) – Whether or not the mixing parameter \(\alpha\) will be adapted during training. Predictions of the local and global models are combined using \(\alpha\) to provide a final prediction. Defaults to True.

  • alpha (float, optional) – The initial value for the mixing parameter \(\alpha\). Defaults to 0.5.

  • alpha_lr (float, optional) – The learning rate applied when adapting \(\alpha\) during training. If adaptive_alpha is False, this parameter has no effect. Defaults to 0.01.

forward(input)[source]

Forward function for the full APFL model. This includes mixing of the global and local model predictions using \(\alpha\). The predictions are combined as

\[\alpha \cdot \text{local_logits} + (1.0 - \alpha) \cdot \text{global_logits}\]

Parameters:

input (torch.Tensor) – Input tensor to be run through both the local and global models

Returns:

Dictionary holding the final prediction after mixing the predictions produced by the local and global models. The mixed prediction is stored under the key “personal”, while the local and global model predictions are stored under the keys “local” and “global”, respectively.

Return type:

dict[str, torch.Tensor]
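
As a rough illustration of the mixing rule above, the following minimal sketch uses plain Python lists in place of torch.Tensor objects. The helper name apfl_mix_sketch is hypothetical and is not part of fl4health; only the formula and the dictionary keys are taken from the description above.

```python
def apfl_mix_sketch(local_logits, global_logits, alpha):
    """Combine local and global predictions elementwise using alpha.

    Hypothetical helper for illustration only: lists of floats stand in
    for the tensors that the real forward pass would produce.
    """
    personal = [
        alpha * local + (1.0 - alpha) * global_
        for local, global_ in zip(local_logits, global_logits)
    ]
    # Same keys as described for the return value of forward.
    return {"personal": personal, "global": global_logits, "local": local_logits}


predictions = apfl_mix_sketch([2.0, 0.0], [0.0, 4.0], alpha=0.25)
print(predictions["personal"])
```

With \(\alpha = 1\) the mixed prediction equals the local prediction, and with \(\alpha = 0\) it equals the global prediction.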

global_forward(input)[source]

Forward function that runs the input tensor through the GLOBAL model only

Parameters:

input (torch.Tensor) – Input tensor to be run through the global model

Returns:

Output from the global model only.

Return type:

torch.Tensor

layers_to_exchange()[source]

Specifies the model layers to be exchanged with the server. These are a fixed set of layers exchanged every round. For APFL, these are any layers associated with the global_model. That is, none of the parameters of the local model are aggregated on the server side, nor is \(\alpha\).

Returns:

Names of layers associated with the global model. These correspond to the layer names in the state dictionary of this entire module.

Return type:

list[str]
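
One way to picture this selection is as a filter over state dictionary keys. The sketch below assumes that keys belonging to the global model share a "global_model." prefix; that prefix, the example key names, and the helper select_global_layers_sketch are all hypothetical and may not match the actual implementation.

```python
def select_global_layers_sketch(state_dict_keys, prefix="global_model."):
    """Return the key names that belong to the global model.

    Hypothetical helper: the prefix is an assumption about how the
    module names its submodules, not a documented guarantee.
    """
    return [name for name in state_dict_keys if name.startswith(prefix)]


# Made-up key names standing in for the keys of the module's state dictionary.
example_keys = [
    "local_model.fc.weight",
    "local_model.fc.bias",
    "global_model.fc.weight",
    "global_model.fc.bias",
]
exchanged = select_global_layers_sketch(example_keys)
print(exchanged)
```

Under these assumptions, the local model's keys never appear in the exchanged list, which mirrors the statement that local parameters are not aggregated on the server side.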

local_forward(input)[source]

Forward function that runs the input tensor through the LOCAL model only

Parameters:

input (torch.Tensor) – Input tensor to be run through the local model

Returns:

Output from the local model only.

Return type:

torch.Tensor

update_alpha()[source]

Updates to the mixture parameter \(\alpha\) follow the original implementation:

https://github.com/MLOPTPSU/FedTorch/blob/ab8068dbc96804a5c1a8b898fd115175cfebfe75/fedtorch/comms/utils/flow_utils.py#L240

Return type:

None
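
For intuition only, the sketch below shows one plausible shape for such an update: a gradient step on \(\alpha\) with step size alpha_lr, followed by clipping to the interval [0, 1]. The gradient estimate used here (the inner product of the local-minus-global parameter difference with the \(\alpha\)-mixed gradients) is an assumption loosely based on the APFL formulation, and the function name and arguments are hypothetical. Details of the linked implementation may differ, so it remains the authoritative reference.

```python
def apfl_alpha_step_sketch(alpha, alpha_lr, local_params, global_params,
                           local_grads, global_grads):
    """Take one assumed gradient step on alpha and clip it to [0, 1].

    Hypothetical helper: flat lists of floats stand in for model
    parameters and their gradients.
    """
    grad_alpha = 0.0
    for w_local, w_global, g_local, g_global in zip(
        local_params, global_params, local_grads, global_grads
    ):
        difference = w_local - w_global
        mixed_grad = alpha * g_local + (1.0 - alpha) * g_global
        grad_alpha += difference * mixed_grad
    new_alpha = alpha - alpha_lr * grad_alpha
    # Keep alpha a valid mixing weight.
    return min(max(new_alpha, 0.0), 1.0)


new_alpha = apfl_alpha_step_sketch(
    alpha=0.5,
    alpha_lr=0.125,
    local_params=[1.0, 2.0],
    global_params=[0.0, 0.0],
    local_grads=[1.0, 1.0],
    global_grads=[1.0, 1.0],
)
print(new_alpha)
```

The clipping step matters because an unclipped \(\alpha\) outside [0, 1] would no longer give a convex combination of the local and global predictions.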