fl4health.model_bases.fenda_base module

class FendaModel(local_module, global_module, model_head)

Bases: PartialLayerExchangeModel, ParallelSplitModel

__init__(local_module, global_module, model_head)

This is the base model to be used when implementing and training FENDA-FL models. A FENDA model is essentially a parallel split model (i.e., it has two parallel feature extractors) in which only one feature extractor, the global_module, is exchanged with the server, while the other remains local to the client.

Parameters:
  • local_module (nn.Module) – Feature extraction module that is NOT exchanged with the server

  • global_module (nn.Module) – Feature extraction module that is exchanged with the server and aggregated with other client modules

  • model_head (ParallelSplitHeadModule) – The model head that takes the output features from both the local and global modules to produce a prediction.
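To make the parallel split concrete, here is a framework-free sketch that uses plain Python lists in place of tensors; it does not use the fl4health or PyTorch API. Both extractors receive the same input, and the head combines their outputs. The function fenda_forward, the toy extractors, and the concatenate-then-sum head are hypothetical stand-ins for real nn.Module instances and a ParallelSplitHeadModule.

```python
from typing import Callable

Vector = list[float]


def fenda_forward(
    x: Vector,
    local_module: Callable[[Vector], Vector],
    global_module: Callable[[Vector], Vector],
    model_head: Callable[[Vector, Vector], float],
) -> float:
    # Both feature extractors run in parallel on the same input.
    local_features = local_module(x)
    global_features = global_module(x)
    # The head receives both feature sets and produces the prediction.
    return model_head(local_features, global_features)


# Hypothetical toy components, for illustration only.
def toy_local(x: Vector) -> Vector:
    return [2.0 * v for v in x]


def toy_global(x: Vector) -> Vector:
    return [v + 1.0 for v in x]


def concat_sum_head(local_features: Vector, global_features: Vector) -> float:
    # Concatenate the two feature vectors, then reduce them to a single score.
    return sum(local_features + global_features)


prediction = fenda_forward([1.0, 2.0], toy_local, toy_global, concat_sum_head)
print(prediction)  # 11.0
```

In the real model, only the parameters of global_module would be sent to the server for aggregation; those of local_module stay on the client.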

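layers_to_exchange()

Returns the names of the layers that are exchanged with the server, i.e., those belonging to global_module.
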
Return type:

list[str]
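Since only the global module is exchanged with the server, layers_to_exchange can be thought of as selecting the global module's state-dict keys. A minimal sketch of that selection, assuming the global module's parameters share a common key prefix; the helper select_exchange_layers, the prefix "global_module.", and the example keys are all hypothetical, and the attribute names registered inside fl4health may differ.

```python
def select_exchange_layers(state_dict_keys: list[str], global_prefix: str = "global_module.") -> list[str]:
    # Keep only the parameter names that belong to the global (exchanged) module,
    # preserving their original order.
    return [key for key in state_dict_keys if key.startswith(global_prefix)]


# Hypothetical state-dict keys for a FENDA-style model.
keys = [
    "local_module.fc.weight",
    "local_module.fc.bias",
    "global_module.fc.weight",
    "global_module.fc.bias",
    "model_head.fc.weight",
]
print(select_exchange_layers(keys))  # ['global_module.fc.weight', 'global_module.fc.bias']
```

Including the trailing dot in the prefix avoids accidentally matching a differently named module whose name merely starts with the same characters.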

class FendaModelWithFeatureState(local_module, global_module, model_head, flatten_features=False)

Bases: FendaModel

__init__(local_module, global_module, model_head, flatten_features=False)

This is the base model to be used when implementing and training FENDA-FL models that require extracting and storing the latent features produced by each of the parallel feature extractors. It is a FENDA model, but its feature-space outputs are guaranteed to be stored, along with the predictions, under the keys ‘local_features’ and ‘global_features’. The user also has the option to “flatten” these features to shape batch_size x all features.

Parameters:
  • local_module (nn.Module) – Feature extraction module that is NOT exchanged with the server

  • global_module (nn.Module) – Feature extraction module that is exchanged with the server and aggregated with other client modules

  • model_head (ParallelSplitHeadModule) – The model head that takes the output features from both the local and global modules to produce a prediction.

  • flatten_features (bool, optional) – Whether the output features should be flattened to have shape batch_size x all features. Defaults to False.
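Flattening keeps the batch dimension and collapses every remaining dimension into one. The arithmetic is sketched below on shapes alone (plain tuples, no tensors); the helper flattened_shape and the example shape, e.g. a convolutional extractor emitting 16 channels of 4 x 4 feature maps for a batch of 8, are hypothetical.

```python
from math import prod


def flattened_shape(feature_shape: tuple[int, ...]) -> tuple[int, int]:
    # Keep the leading batch dimension; multiply out all remaining dimensions.
    batch_size, *feature_dims = feature_shape
    return (batch_size, prod(feature_dims))


print(flattened_shape((8, 16, 4, 4)))  # (8, 256)
print(flattened_shape((8, 64)))  # (8, 64): already-flat features are unchanged
```

In PyTorch, the same reshaping of a feature tensor is commonly written as torch.flatten(features, start_dim=1) or features.reshape(features.shape[0], -1).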

forward(input)

Define the computation performed at every call.

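Should be overridden by all subclasses.

Return type:

tuple[dict[str, Tensor], dict[str, Tensor]]

The first dictionary holds the predictions; the second holds the latent features under the keys ‘local_features’ and ‘global_features’.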

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
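The forward pass pairs the predictions with the stored features. Below is a framework-free sketch of that (predictions, features) structure, using nested lists in place of tensors. The ‘local_features’ and ‘global_features’ keys come from the description above; the prediction key "prediction", the function forward_with_features, the toy extractors, and the head are hypothetical, as is the choice to feed the head unflattened features.

```python
from typing import Callable

Sample = list[list[float]]  # one sample's 2-D feature map
Batch = list[Sample]
Module = Callable[[Batch], Batch]


def flatten_batch(batch: Batch) -> list[list[float]]:
    # Collapse each sample's 2-D feature map into one row: batch_size x all features.
    return [[value for row in sample for value in row] for sample in batch]


def forward_with_features(
    x: Batch,
    local_module: Module,
    global_module: Module,
    model_head: Callable[[Batch, Batch], list[float]],
    flatten_features: bool = False,
) -> tuple[dict, dict]:
    local_features = local_module(x)
    global_features = global_module(x)
    # In this sketch the head consumes the unflattened features;
    # flattening only affects the features that are returned.
    predictions = {"prediction": model_head(local_features, global_features)}
    if flatten_features:
        local_features = flatten_batch(local_features)
        global_features = flatten_batch(global_features)
    features = {"local_features": local_features, "global_features": global_features}
    return predictions, features


# Hypothetical toy components, for illustration only.
def identity(batch: Batch) -> Batch:
    return batch


def doubled(batch: Batch) -> Batch:
    return [[[2.0 * v for v in row] for row in sample] for sample in batch]


def sum_head(local_batch: Batch, global_batch: Batch) -> list[float]:
    # One score per sample: the sum of all local and global feature values.
    return [sum(map(sum, loc)) + sum(map(sum, glob)) for loc, glob in zip(local_batch, global_batch)]


x = [[[1.0, 2.0], [3.0, 4.0]]]  # a batch containing one 2 x 2 sample
preds, features = forward_with_features(x, identity, doubled, sum_head, flatten_features=True)
print(preds)  # {'prediction': [30.0]}
print(features["global_features"])  # [[2.0, 4.0, 6.0, 8.0]]
```

Returning the features in a separate dictionary lets a training loop compute feature-based quantities without changing how the predictions are consumed.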