fl4health.model_bases.fenda_base module

class FendaModel(local_module, global_module, model_head)

Bases: PartialLayerExchangeModel, ParallelSplitModel

__init__(local_module, global_module, model_head)

This is the base model for implementing and training FENDA-FL models. A FENDA model is a parallel split model (i.e., it has two parallel feature extractors) in which only one feature extractor, the global_module, is exchanged with the server, while the other remains local to the client.

Parameters:
  • local_module (nn.Module) – Feature extraction module that is NOT exchanged with the server

  • global_module (nn.Module) – Feature extraction module that is exchanged with the server and aggregated with other client modules

  • model_head (ParallelSplitHeadModule) – The model head that takes the output features from both the local and global modules to produce a prediction.
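The parallel split described above can be illustrated with a minimal PyTorch sketch. It assumes the head joins the two feature vectors by concatenation. The classes `ConcatHead` and `SketchFenda` are hypothetical stand-ins written for illustration only; they are not the fl4health implementation and do not reproduce the `ParallelSplitHeadModule` API.

```python
import torch
import torch.nn as nn


class ConcatHead(nn.Module):
    """Hypothetical head: concatenates local and global features, then classifies."""

    def __init__(self, feature_dim: int, num_classes: int) -> None:
        super().__init__()
        # Concatenating two feature vectors doubles the classifier's input width.
        self.classifier = nn.Linear(2 * feature_dim, num_classes)

    def forward(self, local_features: torch.Tensor, global_features: torch.Tensor) -> torch.Tensor:
        joined = torch.cat([local_features, global_features], dim=1)
        return self.classifier(joined)


class SketchFenda(nn.Module):
    """Illustrative FENDA-style model; not the fl4health implementation."""

    def __init__(self, local_module: nn.Module, global_module: nn.Module, model_head: ConcatHead) -> None:
        super().__init__()
        self.local_module = local_module  # stays on the client
        self.global_module = global_module  # exchanged with the server
        self.model_head = model_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both extractors receive the same input.
        local_features = self.local_module(x)
        global_features = self.global_module(x)
        return self.model_head(local_features, global_features)


model = SketchFenda(
    local_module=nn.Sequential(nn.Linear(10, 8), nn.ReLU()),
    global_module=nn.Sequential(nn.Linear(10, 8), nn.ReLU()),
    model_head=ConcatHead(feature_dim=8, num_classes=3),
)
logits = model(torch.randn(4, 10))
print(logits.shape)  # torch.Size([4, 3])
```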

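layers_to_exchange()

Returns the names of the layers (as state-dictionary keys) that are exchanged with the server. For a FENDA model, these are the layers of the global_module; the layers of the local_module are never sent.

Return type: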

list[str]
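A minimal sketch of how such a list could be built, assuming the layer names are state-dictionary keys and that the global extractor is registered under the hypothetical attribute name `global_module` (the attribute names used inside fl4health may differ). `TwoBranchModel` is an illustrative stand-in, not a library class.

```python
import torch.nn as nn


class TwoBranchModel(nn.Module):
    """Hypothetical stand-in with a local and a global feature extractor."""

    def __init__(self) -> None:
        super().__init__()
        self.local_module = nn.Linear(10, 8)
        self.global_module = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 8))

    def layers_to_exchange(self) -> list[str]:
        # Keep only the state_dict keys that belong to the global feature extractor.
        return [name for name in self.state_dict().keys() if name.startswith("global_module.")]


print(TwoBranchModel().layers_to_exchange())
# ['global_module.0.weight', 'global_module.0.bias', 'global_module.2.weight', 'global_module.2.bias']
```

The ReLU at index 1 has no parameters, so it contributes no keys.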

class FendaModelWithFeatureState(local_module, global_module, model_head, flatten_features=False)

Bases: FendaModel

__init__(local_module, global_module, model_head, flatten_features=False)

This is the base model for implementing and training FENDA-FL models when the latent features produced by each of the parallel feature extractors need to be extracted and stored. It is a FENDA model whose feature-space outputs are guaranteed to be stored, alongside the predictions, under the keys “local_features” and “global_features”. The user also has the option to flatten these features to shape (batch_size, total number of features).

Parameters:
  • local_module (nn.Module) – Feature extraction module that is NOT exchanged with the server

  • global_module (nn.Module) – Feature extraction module that is exchanged with the server and aggregated with other client modules

  • model_head (ParallelSplitHeadModule) – The model head that takes the output features from both the local and global modules to produce a prediction.

  • flatten_features (bool, optional) – Whether the output features should be flattened to shape (batch_size, total number of features). Defaults to False.
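To illustrate what flattening to (batch_size, total number of features) means, the sketch below collapses every non-batch dimension of a hypothetical convolutional feature map. It assumes the flattening behaves like `torch.flatten(..., start_dim=1)`; this is an assumption for illustration, not code taken from fl4health.

```python
import torch

# Hypothetical convolutional feature map: (batch_size, channels, height, width).
features = torch.randn(2, 16, 4, 4)

# Flatten every dimension except the batch dimension: 16 * 4 * 4 = 256 features.
flat = features.flatten(start_dim=1)
print(flat.shape)  # torch.Size([2, 256])
```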

forward(input)

Maps the input through the FENDA model's local and global feature extractors and then through the model head.

Parameters:

input (torch.Tensor) – Input tensor, expected to be of shape (batch_size, *).

Returns:

A tuple of two dictionaries: the predictions and the feature maps. The FENDA predictions are stored under the key “prediction”. The features from the local and global feature extraction modules are stored under the keys “local_features” and “global_features”, respectively.

Return type:

tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
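The documented output layout can be mimicked with the sketch below. `SketchFendaWithFeatures`, its convolutional extractors, and its concatenation head are hypothetical choices made for illustration; only the dictionary keys “prediction”, “local_features” and “global_features” and the tuple-of-two-dictionaries return type come from the documentation above.

```python
import torch
import torch.nn as nn


class SketchFendaWithFeatures(nn.Module):
    """Illustrative only: mimics the documented output layout, not fl4health's code."""

    def __init__(self, flatten_features: bool = False) -> None:
        super().__init__()
        # Hypothetical extractors: (batch, 1, 8, 8) inputs become (batch, 4, 8, 8) feature maps.
        self.local_module = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.global_module = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        # Hypothetical head: classifies the concatenated, flattened features into 10 classes.
        self.head = nn.Linear(2 * 4 * 8 * 8, 10)
        self.flatten_features = flatten_features

    def forward(self, x: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        local_features = self.local_module(x)
        global_features = self.global_module(x)
        # Flatten each branch before concatenation so the linear head receives a 2-D input.
        joined = torch.cat([local_features.flatten(1), global_features.flatten(1)], dim=1)
        preds = {"prediction": self.head(joined)}
        if self.flatten_features:
            local_features = local_features.flatten(1)
            global_features = global_features.flatten(1)
        features = {"local_features": local_features, "global_features": global_features}
        return preds, features


preds, features = SketchFendaWithFeatures()(torch.randn(2, 1, 8, 8))
print(preds["prediction"].shape)  # torch.Size([2, 10])
print(features["local_features"].shape)  # torch.Size([2, 4, 8, 8])
```

With flatten_features=True, the same input would yield feature tensors of shape (2, 256) while the prediction is unchanged in shape.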