Source code for fl4health.model_bases.apfl_base

import copy

import torch
import torch.nn as nn

from fl4health.model_bases.partial_layer_exchange_model import PartialLayerExchangeModel


class ApflModule(PartialLayerExchangeModel):
    def __init__(
        self,
        model: nn.Module,
        adaptive_alpha: bool = True,
        alpha: float = 0.5,
        alpha_lr: float = 0.01,
    ) -> None:
        """
        Defines a model compatible with the APFL approach.

        Args:
            model (nn.Module): The underlying model architecture to be optimized. A twin of this model will be
                created to initialize a local and global version of this architecture.
            adaptive_alpha (bool, optional): Whether or not the mixing parameter :math:`\\alpha` will be adapted
                during training. Predictions of the local and global models are combined using :math:`\\alpha` to
                provide a final prediction. Defaults to True.
            alpha (float, optional): The initial value for the mixing parameter :math:`\\alpha`. Defaults to 0.5.
            alpha_lr (float, optional): The learning rate to be applied when adapting :math:`\\alpha` during
                training. If ``adaptive_alpha`` is False, this parameter has no effect. Defaults to 0.01.
        """
        super().__init__()
        self.local_model: nn.Module = model
        self.global_model: nn.Module = copy.deepcopy(model)
        self.adaptive_alpha = adaptive_alpha
        self.alpha = alpha
        self.alpha_lr = alpha_lr
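    # A minimal construction sketch, assuming a hypothetical two-layer MLP (the model below is
    # illustrative only and not part of this module):
    #
    #   base = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
    #   apfl = ApflModule(base, adaptive_alpha=True, alpha=0.5, alpha_lr=0.01)
    #   # apfl.local_model and apfl.global_model now hold independent twins of base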
    def global_forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Forward function that runs the input tensor through the **GLOBAL** model only.

        Args:
            input (torch.Tensor): Tensor to be run through the global model.

        Returns:
            torch.Tensor: Output from the global model only.
        """
        return self.global_model(input)
    def local_forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Forward function that runs the input tensor through the **LOCAL** model only.

        Args:
            input (torch.Tensor): Tensor to be run through the local model.

        Returns:
            torch.Tensor: Output from the local model only.
        """
        return self.local_model(input)
    def forward(self, input: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Forward function for the full APFL model. This includes mixing of the global and local model predictions
        using :math:`\\alpha`. The predictions are combined as

        .. math::
            \\alpha \\cdot \\text{local_logits} + (1.0 - \\alpha) \\cdot \\text{global_logits}

        Args:
            input (torch.Tensor): Input tensor to be run through both the local and global models.

        Returns:
            dict[str, torch.Tensor]: Final prediction after mixing the predictions produced by the local and
                global models. This dictionary stores the mixed prediction under the key "personal," while the
                local and global model predictions are stored under the keys "local" and "global," respectively.
        """
        # Return a dictionary because APFL produces multiple distinct prediction types
        global_logits = self.global_forward(input)
        local_logits = self.local_forward(input)
        personal_logits = self.alpha * local_logits + (1.0 - self.alpha) * global_logits
        preds = {"personal": personal_logits, "global": global_logits, "local": local_logits}
        return preds
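    # A worked instance of the mixing above, assuming alpha = 0.25 and hypothetical scalar
    # logits (the numbers are illustrative, not from the source):
    #
    #   local_logits = 2.0, global_logits = 4.0
    #   personal_logits = 0.25 * 2.0 + 0.75 * 4.0 = 3.5
    #
    # so a smaller alpha weights the global model more heavily in the personal prediction.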
    def update_alpha(self) -> None:
        """
        Updates to the mixture parameter :math:`\\alpha` follow the original implementation:
        https://github.com/MLOPTPSU/FedTorch/blob/ab8068dbc96804a5c1a8b898fd115175cfebfe75/fedtorch/comms/utils/flow_utils.py#L240
        """  # noqa
        # Need to filter out frozen parameters, as they have no grad object
        local_parameters = [
            local_params for local_params in self.local_model.parameters() if local_params.requires_grad
        ]
        global_parameters = [
            global_params for global_params in self.global_model.parameters() if global_params.requires_grad
        ]

        # Accumulate the gradient of alpha across layers
        grad_alpha: float = 0.0
        for local_p, global_p in zip(local_parameters, global_parameters):
            local_grad = local_p.grad
            global_grad = global_p.grad
            assert local_grad is not None and global_grad is not None
            dif = local_p - global_p
            grad = self.alpha * local_grad + (1.0 - self.alpha) * global_grad
            grad_alpha += torch.mul(dif, grad).sum().detach().cpu().numpy().item()

        # This update constant of 0.02 is not referenced in the paper, but is present in the official
        # implementation and others. Its function is unclear: it simply adds a term proportional to alpha
        # to the gradient. It is left in for consistency with the official implementation.
        grad_alpha += 0.02 * self.alpha
        alpha = self.alpha - self.alpha_lr * grad_alpha
        # Clip alpha to lie in [0, 1]
        alpha = max(min(alpha, 1), 0)
        self.alpha = alpha
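    # Sketch of where the alpha gradient above comes from (a restatement of the loop, not text
    # from the source): with mixed parameters v_bar = alpha * v_local + (1 - alpha) * w_global,
    # the chain rule gives
    #
    #   dL/d(alpha) = sum over layers of <v_local - w_global, dL/d(v_bar)>
    #
    # The loop uses dif = local_p - global_p for v_local - w_global and approximates
    # dL/d(v_bar) with grad = alpha * local_grad + (1 - alpha) * global_grad.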
    def layers_to_exchange(self) -> list[str]:
        """
        Specifies the model layers to be exchanged with the server. These are a fixed set of layers exchanged
        every round. For APFL, these are any layers associated with the ``global_model``. That is, none of the
        parameters of the local model are aggregated on the server side, nor is :math:`\\alpha`.

        Returns:
            list[str]: Names of layers associated with the global model. These correspond to the layer names in
                the state dictionary of this entire module.
        """
        layers_to_exchange: list[str] = [
            layer for layer in self.state_dict().keys() if layer.startswith("global_model.")
        ]
        return layers_to_exchange
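
if __name__ == "__main__":
    # Minimal smoke test of the module above. The toy model, shapes, and loss below are
    # assumptions made for illustration; they are not part of fl4health.
    toy_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    apfl = ApflModule(toy_model, adaptive_alpha=True, alpha=0.5, alpha_lr=0.01)

    inputs = torch.randn(3, 4)
    targets = torch.randint(0, 2, (3,))

    preds = apfl(inputs)  # dict with keys "personal", "global", and "local"
    loss = nn.CrossEntropyLoss()(preds["personal"], targets)
    loss.backward()  # populates gradients on both the local and global models

    apfl.update_alpha()  # adapt the mixing parameter using the fresh gradients
    print(f"alpha after one update: {apfl.alpha:.4f}")
    print(f"layers exchanged with the server: {apfl.layers_to_exchange()}")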