fl4health.model_bases.sequential_split_models module
- class SequentiallySplitExchangeBaseModel(base_module, head_module, flatten_features=False)
Bases: SequentiallySplitModel, PartialLayerExchangeModel
This model is a specific type of sequentially split model in which the layers to be exchanged are those belonging to the base_module.
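As a rough illustration of what exchanging only the base_module layers could look like, the sketch below selects state_dict entries by the `base_module.` prefix. It uses plain PyTorch only. The `TwoStage` class and the `base_layer_names` helper are hypothetical names introduced for this example and are not part of fl4health, whose actual selection logic may differ.

```python
import torch.nn as nn


class TwoStage(nn.Module):
    """Hypothetical stand-in that reuses the two attribute names from the docs."""

    def __init__(self) -> None:
        super().__init__()
        self.base_module = nn.Linear(4, 3)
        self.head_module = nn.Linear(3, 2)


def base_layer_names(model: nn.Module) -> list[str]:
    # Assumption: layers are identified by their state_dict keys, and only
    # keys under the "base_module." prefix are selected for exchange.
    return [name for name in model.state_dict() if name.startswith("base_module.")]


names = base_layer_names(TwoStage())
print(names)  # ['base_module.weight', 'base_module.bias']
```

Under this assumption the head_module parameters never appear in the exchanged set, so each client would keep its own head.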
- class SequentiallySplitModel(base_module, head_module, flatten_features=False)
Bases: Module
- __init__(base_module, head_module, flatten_features=False)
These models are split into two sequential stages. The first is base_module, used as a feature extractor. The second is the head_module, used as a classifier. Features are extracted from the base_module and stored for later use, if required.
- Parameters:
base_module (nn.Module) – Feature extraction module.
head_module (nn.Module) – Classification (or other type of) head that acts on the output from the base module.
flatten_features (bool, optional) – Whether the feature tensor shapes are preserved (False) or flattened to shape (batch_size, -1) (True). Flattening may be necessary when using certain loss functions, as in MOON, for example. Defaults to False.
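To make the effect of flatten_features concrete, the sketch below compares the two shapes for a made-up feature map. The tensor dimensions are illustrative assumptions, and the reshape call is one plausible way to obtain shape (batch_size, -1), not necessarily the call the library uses.

```python
import torch

# Hypothetical feature map, e.g. the output of a convolutional base module:
# a batch of 8 samples, 16 channels, and a 5x5 spatial grid.
features = torch.zeros(8, 16, 5, 5)

# flatten_features=False: the shape is kept as produced by the base module.
preserved = features

# flatten_features=True: everything after the batch dimension is collapsed.
flattened = features.reshape(len(features), -1)

print(preserved.shape)  # torch.Size([8, 16, 5, 5])
print(flattened.shape)  # torch.Size([8, 400])
```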
- features_forward(input)
Run a forward pass using the base_module only, returning the features extracted from it.
- Parameters:
input (torch.Tensor) – Input to the model forward pass. Expected to be of shape (batch_size, *).
- Returns:
The potentially flattened features tensor from the base module.
- Return type:
torch.Tensor
- forward(input)
Run a forward pass using the sequentially split modules base_module -> head_module. Features from the base_module are stored either in their original shapes or flattened to be of shape (batch_size, -1), depending on self.flatten_features.
- Parameters:
input (torch.Tensor) – Input to the model forward pass. Expected to be of shape (batch_size, *).
- Returns:
The prediction dictionary and a features dictionary representing the output of the base_module, either in the standard tensor shape or flattened to be compatible, for example, with MOON contrastive losses.
- Return type:
- sequential_forward(input)
Run a forward pass using the sequentially split modules base_module -> head_module.
- Parameters:
input (torch.Tensor) – Input to the model forward pass. Expected to be of shape (batch_size, *).
- Returns:
The predictions and features tensors from the sequential forward.
- Return type:
tuple[torch.Tensor, torch.Tensor]
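The relationship between sequential_forward and forward described above can be sketched with a small stand-in module. This is a minimal illustration under stated assumptions, not the fl4health implementation: `SplitSketch` is a hypothetical class, the dictionary keys `"prediction"` and `"features"` are assumed for the example, and the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn


class SplitSketch(nn.Module):
    """Hypothetical stand-in for a sequentially split model."""

    def __init__(self, base_module: nn.Module, head_module: nn.Module, flatten_features: bool = False) -> None:
        super().__init__()
        self.base_module = base_module
        self.head_module = head_module
        self.flatten_features = flatten_features

    def sequential_forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Stage 1: feature extraction. Stage 2: the head acts on those features.
        features = self.base_module(input)
        predictions = self.head_module(features)
        return predictions, features

    def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        predictions, features = self.sequential_forward(input)
        if self.flatten_features:
            # Collapse everything after the batch dimension.
            features = features.reshape(len(features), -1)
        # Assumption: these key names are illustrative only.
        return {"prediction": predictions}, {"features": features}


base = nn.Sequential(nn.Conv2d(1, 2, kernel_size=3), nn.ReLU())
head = nn.Sequential(nn.Flatten(), nn.Linear(2 * 6 * 6, 3))
model = SplitSketch(base, head, flatten_features=True)

preds, feats = model(torch.zeros(4, 1, 8, 8))
print(preds["prediction"].shape)  # torch.Size([4, 3])
print(feats["features"].shape)    # torch.Size([4, 72])
```

In this sketch the head always receives the unflattened features, and flattening applies only to the copy returned in the features dictionary, so a loss that needs (batch_size, -1) inputs can consume it directly.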