fl4health.model_bases.moon_base module

class MoonModel(base_module, head_module, projection_module=None)

Bases: SequentiallySplitModel

__init__(base_module, head_module, projection_module=None)

A MOON model is a specific type of sequentially split model in which one may specify an optional projection module for feature manipulation. The model always stores the features produced by the base module, because they are used in contrastive loss function calculations. These features are also always flattened to be compatible with such losses.
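To illustrate why the stored features are flattened, the sketch below computes a model-contrastive loss of the form described in the MOON paper on per-sample feature vectors. It is plain PyTorch and is not fl4health's loss implementation; the function names `flatten_features` and `moon_contrastive_loss`, the temperature of 0.5, and flattening via `reshape(len(x), -1)` are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def flatten_features(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: collapse every non-batch dimension, (batch_size, *) -> (batch_size, d)
    return x.reshape(len(x), -1)


def moon_contrastive_loss(
    z: torch.Tensor,       # features from the local model being trained
    z_glob: torch.Tensor,  # features from the global model (positive pair)
    z_prev: torch.Tensor,  # features from the previous local model (negative pair)
    temperature: float = 0.5,  # assumed value, for illustration only
) -> torch.Tensor:
    z, z_glob, z_prev = (flatten_features(t) for t in (z, z_glob, z_prev))
    # Cosine similarity is computed per sample along the flattened feature dimension
    sim_pos = F.cosine_similarity(z, z_glob, dim=1) / temperature
    sim_neg = F.cosine_similarity(z, z_prev, dim=1) / temperature
    logits = torch.stack([sim_pos, sim_neg], dim=1)
    # Cross-entropy with target index 0 equals -log(exp(pos) / (exp(pos) + exp(neg))), averaged
    targets = torch.zeros(len(z), dtype=torch.long)
    return F.cross_entropy(logits, targets)


# Base-module outputs are often multi-dimensional, e.g. (batch, channels, height, width)
z = torch.randn(4, 8, 2, 2)
loss = moon_contrastive_loss(z, z.clone(), -z)
print(loss.item())  # small, since z matches z_glob and is opposite to z_prev
```

Without flattening, `cosine_similarity(..., dim=1)` would compare along the channel axis only and return one value per spatial location rather than one per sample.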

Parameters:
  • base_module (nn.Module) – Feature extractor component of the model

  • head_module (nn.Module) – Classification (or other type of) head used by the model

  • projection_module (nn.Module | None, optional) – An optional module for manipulating the features before they are passed to the head_module. Defaults to None.
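As a hedged illustration of how the three parameters fit together, the plain-PyTorch snippet below builds hypothetical modules for a small image classifier and checks shape compatibility in both configurations: with a projection module, the head consumes the projection's output; without one, the head must accept the base module's output directly. The layer types and sizes are arbitrary assumptions, and the modules are not passed to `MoonModel` here.

```python
import torch
from torch import nn

# Hypothetical feature extractor: (batch, 1, 28, 28) -> (batch, 8, 4, 4)
base_module = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(4),
)

# Hypothetical projection: flattens the base features and maps them to 64 dimensions
projection_module = nn.Sequential(nn.Flatten(), nn.Linear(8 * 4 * 4, 64), nn.ReLU())
head_with_projection = nn.Linear(64, 10)

# Without a projection, the head itself must handle the raw base features
head_without_projection = nn.Sequential(nn.Flatten(), nn.Linear(8 * 4 * 4, 10))

x = torch.randn(5, 1, 28, 28)
features = base_module(x)
preds_a = head_with_projection(projection_module(features))
preds_b = head_without_projection(features)
print(features.shape, preds_a.shape, preds_b.shape)
# torch.Size([5, 8, 4, 4]) torch.Size([5, 10]) torch.Size([5, 10])
```

The main design constraint is that the head's expected input size must match whatever it receives: the projection output when a projection module is supplied, the base output otherwise.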

sequential_forward(input)

Overrides the sequential forward of the SequentiallySplitModel parent to allow a projection module to be injected into the forward pass. The remaining functionality is unchanged: a forward pass is run through the sequentially split modules base_module -> projection_module (if provided) -> head_module.

Parameters:

input (torch.Tensor) – Input to the model forward pass. Expected to be of shape (batch_size, *)

Returns:

The predictions tensor and the flattened features tensor produced by the sequential forward pass.

Return type:

tuple[torch.Tensor, torch.Tensor]
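The override can be pictured with the standalone sketch below. It does not subclass fl4health's SequentiallySplitModel; the class name `MoonLikeModel` is hypothetical, and returning the tensor that is fed into the head (the projection output when one is given) in flattened form is an assumption, not a statement of exactly which features the library returns.

```python
from __future__ import annotations

import torch
from torch import nn


class MoonLikeModel(nn.Module):
    """Hypothetical stand-in mirroring the base -> (projection) -> head forward pass."""

    def __init__(
        self,
        base_module: nn.Module,
        head_module: nn.Module,
        projection_module: nn.Module | None = None,
    ) -> None:
        super().__init__()
        self.base_module = base_module
        self.head_module = head_module
        self.projection_module = projection_module

    def sequential_forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        features = self.base_module(input)
        # Inject the optional projection between the base and the head
        if self.projection_module is not None:
            features = self.projection_module(features)
        preds = self.head_module(features)
        # Flatten everything except the batch dimension so contrastive losses get (batch, d)
        return preds, features.reshape(len(features), -1)


model = MoonLikeModel(
    base_module=nn.Linear(16, 32),
    head_module=nn.Linear(8, 3),
    projection_module=nn.Linear(32, 8),
)
preds, feats = model.sequential_forward(torch.randn(4, 16))
print(preds.shape, feats.shape)  # torch.Size([4, 3]) torch.Size([4, 8])
```

Checking `is not None` rather than relying on truthiness keeps the intent explicit: the projection step is skipped only when no module was supplied.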