fl4health.model_bases.sequential_split_models module
- class SequentiallySplitExchangeBaseModel(base_module, head_module, flatten_features=False)
Bases: SequentiallySplitModel, PartialLayerExchangeModel
This model is a specific type of sequentially split model in which the layers to be exchanged are those belonging to the base_module.
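In this model, the layers to be exchanged are the ones belonging to base_module. One way to picture that selection is to filter a model's state_dict by the "base_module." prefix that PyTorch gives to every parameter registered under that attribute. The following is a minimal sketch under that assumption: ToySplitModel and base_module_layer_names are hypothetical names, not part of fl4health, and the library may select the layers differently.

```python
from torch import nn


class ToySplitModel(nn.Module):
    """Hypothetical stand-in that only mirrors the base_module/head_module attribute names."""

    def __init__(self) -> None:
        super().__init__()
        self.base_module = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
        self.head_module = nn.Linear(4, 2)


def base_module_layer_names(model: nn.Module) -> list[str]:
    # Hypothetical helper: keep only the state_dict entries registered under `base_module`.
    return [name for name in model.state_dict() if name.startswith("base_module.")]


model = ToySplitModel()
names = base_module_layer_names(model)
print(names)  # ['base_module.0.weight', 'base_module.0.bias']

# Partial state dict holding only the base_module weights (the exchanged subset in this sketch).
partial_state = {name: tensor for name, tensor in model.state_dict().items() if name in names}
```

The ReLU contributes no entries because it has no parameters, so only the Linear layer's weight and bias are selected, while head_module.weight and head_module.bias stay local.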
- class SequentiallySplitModel(base_module, head_module, flatten_features=False)
Bases: Module
- __init__(base_module, head_module, flatten_features=False)
These models are split into two sequential stages. The first, base_module, is used as a feature extractor. The second, head_module, is used as a classifier. Features extracted by the base_module are stored for later use, if required.
- Parameters:
base_module (nn.Module) – Feature extraction module.
head_module (nn.Module) – Classification (or other type of) head that acts on the output of the base_module.
flatten_features (bool, optional) – Whether the feature tensor shapes are to be preserved (False) or flattened to shape (batch_size, -1). Flattening may be necessary when using certain loss functions, for example the contrastive losses used in MOON. Defaults to False.
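To make the effect of flatten_features concrete, the sketch below assumes a hypothetical convolutional base_module whose output has shape (batch_size, channels, height, width). Keeping that shape corresponds to flatten_features=False; reshaping to (batch_size, -1) corresponds to True. The cosine-similarity step only illustrates why a MOON-style contrastive loss, which compares representations with cosine similarity, wants one feature vector per sample; other_base is an assumed second model, not an fl4health object.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Hypothetical convolutional feature extractor producing (batch_size, channels, height, width).
base_module = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1)
x = torch.randn(5, 3, 8, 8)

features = base_module(x)  # preserved shape (flatten_features=False): (5, 4, 8, 8)
flat_features = features.reshape(features.shape[0], -1)  # flattened (flatten_features=True): (5, 256)

# Compare against features from a second, hypothetical model, one similarity value per sample.
other_base = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1)
other_flat = other_base(x).reshape(x.shape[0], -1)
similarity = F.cosine_similarity(flat_features, other_flat, dim=1)  # shape (5,)
```

With padding=1 and a 3x3 kernel the spatial size stays 8x8, so each flattened vector has 4 * 8 * 8 = 256 entries.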
- forward(input)
Run a forward pass through the sequentially split modules, base_module -> head_module. Features from the base_module are stored either in their original shape or flattened to shape (batch_size, -1), depending on self.flatten_features.
- Parameters:
input (torch.Tensor) – Input to the model forward pass. Expected to be of shape (batch_size, *).
- Returns:
A prediction dictionary and a features dictionary. The features dictionary holds the output of the base_module, either in its standard tensor shape or flattened, so that it is compatible with, for example, MOON contrastive losses.
- Return type:
tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
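The following is a minimal sketch of the behaviour described for forward, assuming a model built from arbitrary base_module and head_module instances. SplitModelSketch is a hypothetical class written for illustration, not fl4health's source, and the dictionary keys "prediction" and "features" are assumptions made for this sketch.

```python
import torch
from torch import nn


class SplitModelSketch(nn.Module):
    """Hypothetical re-creation of the documented forward behaviour, not fl4health's source."""

    def __init__(self, base_module: nn.Module, head_module: nn.Module, flatten_features: bool = False) -> None:
        super().__init__()
        self.base_module = base_module
        self.head_module = head_module
        self.flatten_features = flatten_features

    def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        features = self.base_module(input)
        # In this sketch the head consumes the unmodified base output; only the stored copy is flattened.
        predictions = self.head_module(features)
        if self.flatten_features:
            features = features.reshape(features.shape[0], -1)
        # The dictionary keys are assumptions made for this sketch.
        return {"prediction": predictions}, {"features": features}


base = nn.Sequential(nn.Conv2d(1, 2, kernel_size=3), nn.ReLU())
head = nn.Sequential(nn.Flatten(), nn.Linear(2 * 4 * 4, 3))
x = torch.randn(4, 1, 6, 6)

preds, feats = SplitModelSketch(base, head, flatten_features=True)(x)
print(preds["prediction"].shape, feats["features"].shape)  # torch.Size([4, 3]) torch.Size([4, 32])
```

Keeping predictions and features in separate dictionaries lets a loss that needs only predictions ignore the features, while a contrastive loss can read the features without re-running the base.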
- sequential_forward(input)
Run a forward pass through the sequentially split modules, base_module -> head_module.
- Parameters:
input (torch.Tensor) – Input to the model forward pass. Expected to be of shape (batch_size, *).
- Returns:
The predictions and features tensors from the sequential forward pass.
- Return type:
tuple[torch.Tensor, torch.Tensor]
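sequential_forward returns the two raw tensors rather than dictionaries. The sketch below, written as a hypothetical free function rather than the library's method, shows the ordering (predictions first, features second) and that a loss computed on the predictions back-propagates into the base module. The layer sizes and targets are illustrative assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn


def sequential_forward(
    base_module: nn.Module, head_module: nn.Module, input: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical free-function version of the documented method: returns (predictions, features)."""
    features = base_module(input)
    predictions = head_module(features)
    return predictions, features


torch.manual_seed(0)
base = nn.Linear(10, 6)
head = nn.Linear(6, 3)
x = torch.randn(4, 10)
targets = torch.tensor([0, 2, 1, 0])

predictions, features = sequential_forward(base, head, x)
# In this sketch the features are returned as-is (no dictionary, no flattening) and stay in the autograd graph.
loss = F.cross_entropy(predictions, targets)
loss.backward()
print(predictions.shape, features.shape)  # torch.Size([4, 3]) torch.Size([4, 6])
```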