fl4health.model_bases.sequential_split_models module
- class SequentiallySplitExchangeBaseModel(base_module, head_module, flatten_features=False)
Bases: SequentiallySplitModel, PartialLayerExchangeModel
This model is a specific type of sequentially split model in which the layers to be exchanged are exactly those belonging to base_module; the layers of head_module are not exchanged.
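The exchange rule can be sketched in plain PyTorch. The class below is a hypothetical stand-in, not the fl4health implementation: the method name `layers_to_exchange` and the prefix match on `state_dict` keys are assumptions about how "the layers belonging to base_module" could be selected.

```python
import torch.nn as nn


class SequentiallySplitExchangeBaseSketch(nn.Module):
    """Hypothetical stand-in for SequentiallySplitExchangeBaseModel (not the fl4health class)."""

    def __init__(self, base_module: nn.Module, head_module: nn.Module) -> None:
        super().__init__()
        self.base_module = base_module
        self.head_module = head_module

    def layers_to_exchange(self) -> list[str]:
        # Assumption: only state_dict entries registered under base_module are exchanged;
        # every entry under head_module stays local.
        return [name for name in self.state_dict() if name.startswith("base_module.")]


model = SequentiallySplitExchangeBaseSketch(
    base_module=nn.Sequential(nn.Linear(8, 4), nn.ReLU()),
    head_module=nn.Linear(4, 2),
)
print(model.layers_to_exchange())
# ['base_module.0.weight', 'base_module.0.bias']
```

The ReLU contributes no entries because it has no parameters, and the head's `weight` and `bias` are excluded by the prefix check.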
- class SequentiallySplitModel(base_module, head_module, flatten_features=False)
Bases: Module
- __init__(base_module, head_module, flatten_features=False)
These models are split into two sequential stages. The first, base_module, acts as a feature extractor. The second, head_module, acts as a classifier. Features extracted by base_module are stored for later use, if required.
- Parameters:
base_module (nn.Module) – Feature extraction module
head_module (nn.Module) – Classification head (or another type of head) that acts on the output of the base module
flatten_features (bool, optional) – Whether the feature tensors keep their original shapes (False) or are flattened to shape (batch_size, -1) (True). Flattening may be necessary for certain loss functions, for example in MOON. Defaults to False.
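To illustrate why flattening can matter, the snippet below assumes a convolutional base_module whose features have shape (batch_size, channels, height, width) and compares two such feature batches with cosine similarity, a common ingredient of MOON-style contrastive losses. The tensors and shapes are made up for illustration; this is not fl4health's loss code.

```python
import torch
import torch.nn.functional as F

# Hypothetical conv feature maps: (batch_size, channels, height, width)
batch_size = 4
local_features = torch.randn(batch_size, 8, 5, 5)
global_features = torch.randn(batch_size, 8, 5, 5)

# Without flattening, dim=1 is the channel axis, so we get one score per spatial location
per_location = F.cosine_similarity(local_features, global_features, dim=1)

# Flatten everything except the batch dimension, i.e. shape (batch_size, -1)
local_flat = local_features.reshape(batch_size, -1)
global_flat = global_features.reshape(batch_size, -1)

# With flattened features, dim=1 covers the whole feature vector: one score per sample
similarity = F.cosine_similarity(local_flat, global_flat, dim=1)
print(per_location.shape, local_flat.shape, similarity.shape)
# torch.Size([4, 5, 5]) torch.Size([4, 200]) torch.Size([4])
```

A per-sample similarity is what a contrastive term over a batch typically expects, which is one reason a loss may require flattened features.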
- forward(input)
Run a forward pass through the sequentially split modules, base_module -> head_module. Features from base_module are stored either in their original shapes or flattened to shape (batch_size, -1), depending on self.flatten_features.
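A minimal sketch of this forward pass, written as a hypothetical stand-in class rather than the fl4health source. It assumes that head_module receives the unflattened features, that flatten_features affects only the stored copy, and that the result is returned as a (predictions, features) pair; the actual return type is not specified in this section.

```python
import torch
import torch.nn as nn


class SequentiallySplitSketch(nn.Module):
    """Hypothetical stand-in for SequentiallySplitModel; not the fl4health source."""

    def __init__(self, base_module: nn.Module, head_module: nn.Module, flatten_features: bool = False) -> None:
        super().__init__()
        self.base_module = base_module
        self.head_module = head_module
        self.flatten_features = flatten_features

    def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        features = self.base_module(input)
        predictions = self.head_module(features)
        # Keep the original shape, or flatten to (batch_size, -1) when requested
        stored = features.reshape(features.shape[0], -1) if self.flatten_features else features
        return predictions, stored


def make_model(flatten: bool) -> SequentiallySplitSketch:
    base = nn.Sequential(nn.Conv2d(1, 3, kernel_size=3), nn.ReLU())  # (B, 1, 8, 8) -> (B, 3, 6, 6)
    head = nn.Sequential(nn.Flatten(), nn.Linear(3 * 6 * 6, 10))  # (B, 3, 6, 6) -> (B, 10)
    return SequentiallySplitSketch(base, head, flatten_features=flatten)


x = torch.randn(2, 1, 8, 8)
preds, flat_feats = make_model(flatten=True)(x)
_, raw_feats = make_model(flatten=False)(x)
print(preds.shape, flat_feats.shape, raw_feats.shape)
# torch.Size([2, 10]) torch.Size([2, 108]) torch.Size([2, 3, 6, 6])
```

Keeping the head's input unflattened means the same head works regardless of flatten_features; only consumers of the stored features see the different shape.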