fl4health.model_bases.sequential_split_models module

class SequentiallySplitExchangeBaseModel(base_module, head_module, flatten_features=False)

Bases: SequentiallySplitModel, PartialLayerExchangeModel

This model is a specific type of sequentially split model in which the layers exchanged with the server are exactly those belonging to the base_module.

layers_to_exchange()

Names of the layers of the model to be exchanged with the server. For these models, only layers belonging to the base_module are exchanged.

Returns:

The names of the layers to be exchanged with the server. This list is used by the FixedLayerExchanger class.

Return type:

list[str]
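A minimal sketch of how such a name filter can work, assuming the base module is registered as an attribute named base_module, so that PyTorch prefixes its state_dict keys with "base_module.". SketchExchangeModel is a hypothetical stand-in rather than the fl4health class, and selecting names by prefix is an assumption made for illustration.

```python
import torch.nn as nn


class SketchExchangeModel(nn.Module):
    """Hypothetical stand-in for SequentiallySplitExchangeBaseModel (not the fl4health class)."""

    def __init__(self, base_module: nn.Module, head_module: nn.Module) -> None:
        super().__init__()
        # The attribute names determine the state_dict key prefixes
        # ("base_module." and "head_module.").
        self.base_module = base_module
        self.head_module = head_module

    def layers_to_exchange(self) -> list[str]:
        # Keep only the parameter/buffer names that live under base_module.
        return [name for name in self.state_dict().keys() if name.startswith("base_module.")]


model = SketchExchangeModel(
    base_module=nn.Sequential(nn.Linear(8, 4), nn.ReLU()),
    head_module=nn.Linear(4, 2),
)
print(model.layers_to_exchange())  # ['base_module.0.weight', 'base_module.0.bias']
```

The head's weights ("head_module.weight", "head_module.bias") never appear in the list, so they stay local to each client.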

class SequentiallySplitModel(base_module, head_module, flatten_features=False)

Bases: Module

__init__(base_module, head_module, flatten_features=False)

These models are split into two sequential stages. The first is the base_module, used as a feature extractor. The second is the head_module, used as a classifier. Features extracted by the base_module are stored for later use, if required.

Parameters:
  • base_module (nn.Module) – Feature extraction module

  • head_module (nn.Module) – Classification head (or another type of head) that acts on the output of the base module

  • flatten_features (bool, optional) – Whether the feature tensor shapes are preserved (False) or flattened to shape (batch_size, -1) (True). Flattening may be necessary for certain loss functions, such as the one used in MOON. Defaults to False.
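To make the effect of flatten_features concrete, the sketch below applies the (batch_size, -1) reshape described above to a convolutional feature map. store_features is a hypothetical helper written for illustration; it is not part of fl4health.

```python
import torch


def store_features(features: torch.Tensor, flatten_features: bool) -> torch.Tensor:
    # Hypothetical helper: keep the original shape, or collapse every
    # dimension except the batch dimension into one.
    if flatten_features:
        return features.reshape(len(features), -1)
    return features


conv_features = torch.randn(16, 32, 7, 7)  # (batch_size, channels, height, width)
print(store_features(conv_features, flatten_features=False).shape)  # torch.Size([16, 32, 7, 7])
print(store_features(conv_features, flatten_features=True).shape)   # torch.Size([16, 1568])
```

Here 32 * 7 * 7 = 1568, so each sample's features become a single vector, which is the form vector-based losses (for example, cosine-similarity terms) expect.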

forward(input)

Run a forward pass using the sequentially split modules base_module -> head_module. Features from the base_module are stored either in their original shapes or flattened to shape (batch_size, -1), depending on self.flatten_features.

Parameters:

input (torch.Tensor) – Input to the model forward pass. Expected to be of shape (batch_size, *)

Returns:

A pair of dictionaries: the first holds the predictions and the second holds the features.

Return type:

tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
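A minimal re-creation of the documented forward behaviour, assuming the two stages are called in order and the features are optionally flattened afterwards. SketchSplitModel is hypothetical, and the dictionary keys "prediction" and "features" are assumptions made for this sketch, not confirmed fl4health keys.

```python
import torch
import torch.nn as nn


class SketchSplitModel(nn.Module):
    """Hypothetical re-creation of the documented behaviour; not the fl4health source."""

    def __init__(
        self, base_module: nn.Module, head_module: nn.Module, flatten_features: bool = False
    ) -> None:
        super().__init__()
        self.base_module = base_module
        self.head_module = head_module
        self.flatten_features = flatten_features

    def sequential_forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        features = self.base_module(input)  # stage 1: feature extraction
        predictions = self.head_module(features)  # stage 2: head on the raw features
        return predictions, features

    def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        predictions, features = self.sequential_forward(input)
        if self.flatten_features:
            # Only the stored copy is flattened; the head already saw the original shape.
            features = features.reshape(len(features), -1)
        # The keys below are assumptions for illustration.
        return {"prediction": predictions}, {"features": features}


model = SketchSplitModel(
    base_module=nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.ReLU()),  # (2, 1, 8, 8) -> (2, 4, 6, 6)
    head_module=nn.Sequential(nn.Flatten(), nn.Linear(4 * 6 * 6, 10)),
    flatten_features=True,
)
preds, feats = model(torch.randn(2, 1, 8, 8))
print(preds["prediction"].shape, feats["features"].shape)  # torch.Size([2, 10]) torch.Size([2, 144])
```

Returning dictionaries rather than bare tensors lets a training loop look up predictions and features by name, which is convenient when several named outputs are involved.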

sequential_forward(input)

Run a forward pass using the sequentially split modules base_module -> head_module.

Parameters:

input (torch.Tensor) – Input to the model forward pass. Expected to be of shape (batch_size, *)

Returns:

The predictions and features tensors produced by the sequential forward pass.

Return type:

tuple[torch.Tensor, torch.Tensor]
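The two-stage pass can be sketched as a plain function, assuming base_module and head_module are ordinary nn.Module instances. sequential_forward below is a hypothetical free-function version written for illustration; its predictions match composing the same modules with nn.Sequential, while also exposing the intermediate features.

```python
import torch
import torch.nn as nn


def sequential_forward(
    base_module: nn.Module, head_module: nn.Module, input: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # Stage 1: the base module extracts features from the raw input.
    features = base_module(input)
    # Stage 2: the head module maps those features to predictions.
    predictions = head_module(features)
    return predictions, features


torch.manual_seed(0)
base = nn.Sequential(nn.Linear(10, 6), nn.ReLU())
head = nn.Linear(6, 3)
x = torch.randn(4, 10)  # (batch_size, in_features)

predictions, features = sequential_forward(base, head, x)
print(predictions.shape, features.shape)  # torch.Size([4, 3]) torch.Size([4, 6])
```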