import torch
import torch.nn as nn
from fl4health.model_bases.partial_layer_exchange_model import PartialLayerExchangeModel
class SequentiallySplitModel(nn.Module):
    """
    Model composed of two sequential stages: a feature extractor (``base_module``) whose output is fed to a
    head (``head_module``). The forward pass returns both the head predictions and the intermediate features.
    """

    def __init__(self, base_module: nn.Module, head_module: nn.Module, flatten_features: bool = False) -> None:
        """
        These models are split into two sequential stages. The first is base_module, used as a feature extractor.
        The second is the head_module, used as a classifier. Features are extracted from the base_module and
        returned alongside the predictions for later use, if required.

        Args:
            base_module (nn.Module): Feature extraction module
            head_module (nn.Module): Classification (or other type) of head that acts on the output from the base
                module
            flatten_features (bool, optional): Whether the feature tensor shapes are to be preserved (false) or if
                they should be flattened to be of shape (batch_size, -1). Flattening may be necessary when using
                certain loss functions, as in MOON, for example. Defaults to False.
        """
        super().__init__()
        self.base_module = base_module
        self.head_module = head_module
        self.flatten_features = flatten_features

    def _flatten_features(self, features: torch.Tensor) -> torch.Tensor:
        """
        Flatten the features tensor to be of shape (batch_size, -1). It is expected that the feature tensor is
        BATCH FIRST.

        Args:
            features (torch.Tensor): Features tensor to be flattened. It is assumed that this tensor is BATCH FIRST.

        Returns:
            torch.Tensor: Flattened feature tensor of shape (batch_size, -1)
        """
        # len(features) is the size of the leading (batch) dimension; everything else is collapsed.
        return features.reshape(len(features), -1)

    def sequential_forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run a forward pass using the sequentially split modules base_module -> head_module.

        Args:
            input (torch.Tensor): Input to the model forward pass. Expected to be of shape (batch_size, *)

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Returns the predictions and features tensor from the sequential
            forward. The features are returned exactly as produced by base_module (never flattened here).
        """
        # Invoke the sub-modules through __call__ rather than .forward(...): calling .forward directly silently
        # skips any forward/pre-forward hooks registered on base_module or head_module.
        features = self.base_module(input)
        predictions = self.head_module(features)
        return predictions, features

    def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """
        Run a forward pass using the sequentially split modules base_module -> head_module. Features from the
        base_module are returned either in their original shape or flattened to be of shape (batch_size, -1),
        depending on self.flatten_features.

        Args:
            input (torch.Tensor): Input to the model forward pass. Expected to be of shape (batch_size, *)

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: Dictionaries of predictions (under the key
            "prediction") and features (under the key "features")
        """
        predictions, features = self.sequential_forward(input)
        predictions_dict = {"prediction": predictions}
        # The head always consumes the unflattened features; flattening only affects what is returned, so the
        # features are compatible, for example, with MOON contrastive losses.
        returned_features = self._flatten_features(features) if self.flatten_features else features
        features_dict = {"features": returned_features}
        return predictions_dict, features_dict
class SequentiallySplitExchangeBaseModel(SequentiallySplitModel, PartialLayerExchangeModel):
    """
    A sequentially split model for which the layers exchanged with the server are exactly those belonging to
    the base_module (the feature extractor).
    """

    def layers_to_exchange(self) -> list[str]:
        """
        Names of the layers of the model to be exchanged with the server. For these models, only the state dict
        entries whose names begin with "base_module." are exchanged.

        Returns:
            list[str]: The names of the layers to be exchanged with the server. This is used by the
            FixedLayerExchanger class
        """
        base_prefix = "base_module."
        # Iterating a state dict yields its entry names in order.
        state_entries = self.state_dict()
        return [entry_name for entry_name in state_entries if entry_name.startswith(base_prefix)]