fl4health.model_bases.perfcl_base module

class PerFclModel(local_module, global_module, model_head)

Bases: PartialLayerExchangeModel, ParallelSplitModel

__init__(local_module, global_module, model_head)

Model used by PerFCL clients to train models with the PerFCL approach. These models are of type ParallelSplitModel and have two distinct feature extractors: one is exchanged with the server and aggregated, while the other remains local. Both extractors receive the same input, and each produces latent features that are flattened and returned, alongside the predictions, under the keys ‘local_features’ and ‘global_features’.

Parameters:
  • local_module (nn.Module) – Feature extraction module that is NOT exchanged with the server.

  • global_module (nn.Module) – Feature extraction module that is exchanged with the server and aggregated with other clients' modules.

  • model_head (ParallelSplitHeadModule) – The model head that takes the output features from both the local and global modules to produce a prediction.
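The data flow described above can be sketched in plain PyTorch. This is a minimal, hypothetical stand-in, not fl4health's implementation: `ToyPerFclModel`, `ConcatHead`, the attribute names, the `"prediction"` key, flattening before the head, and joining the two feature sets by concatenation are all assumptions made for illustration. It shows the two extractors running in parallel on the same input, the head combining their outputs, and the flattened features returned alongside the predictions.

```python
import torch
import torch.nn as nn


class ConcatHead(nn.Module):
    """Hypothetical head: concatenates both flattened feature sets, then classifies."""

    def __init__(self, in_features: int, num_classes: int) -> None:
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, local_features: torch.Tensor, global_features: torch.Tensor) -> torch.Tensor:
        joined = torch.cat([local_features, global_features], dim=1)
        return self.fc(joined)


class ToyPerFclModel(nn.Module):
    """Hypothetical sketch of the PerFCL layout: two parallel extractors plus a head."""

    def __init__(self, local_module: nn.Module, global_module: nn.Module, model_head: nn.Module) -> None:
        super().__init__()
        self.local_module = local_module    # stays on the client
        self.global_module = global_module  # would be exchanged with the server
        self.model_head = model_head

    def forward(self, x: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        # Both extractors see the same input; flatten everything but the batch dimension.
        local_out = self.local_module(x).flatten(start_dim=1)
        global_out = self.global_module(x).flatten(start_dim=1)
        preds = {"prediction": self.model_head(local_out, global_out)}
        features = {"local_features": local_out, "global_features": global_out}
        return preds, features


# Usage: 1x8x8 inputs; each Conv2d(kernel_size=3) yields 4x6x6 = 144 features per sample.
local_extractor = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.ReLU())
global_extractor = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.ReLU())
model = ToyPerFclModel(local_extractor, global_extractor, ConcatHead(in_features=2 * 144, num_classes=3))
preds, features = model(torch.randn(2, 1, 8, 8))
print(preds["prediction"].shape)         # torch.Size([2, 3])
print(features["local_features"].shape)  # torch.Size([2, 144])
```

Returning the features in a separate dictionary keeps them available to a training loop that needs them for auxiliary losses, without changing what the head consumes.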

forward(input)

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

tuple[dict[str, Tensor], dict[str, Tensor]]

Note

Although the forward-pass logic is defined in this function, call the Module instance (model(input)) rather than model.forward(input) directly: calling the instance runs any registered hooks, whereas calling forward directly silently skips them.
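The difference the note describes can be checked with a forward hook. In this minimal sketch, a small `nn.Linear` stands in for any module (a PerFclModel behaves the same way, since it is an `nn.Module`); the hook runs when the instance is called but not when `forward` is invoked directly.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
calls = []  # records each time the forward hook runs


def record_call(module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
    calls.append(output.shape)


handle = model.register_forward_hook(record_call)

x = torch.randn(3, 4)
model(x)          # goes through Module.__call__, so the hook runs
model.forward(x)  # bypasses __call__, so the hook is silently skipped
print(len(calls))  # 1

handle.remove()  # detach the hook once it is no longer needed
```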

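layers_to_exchange()

Return the names of the model layers to be exchanged with the server. For PerFCL these are the layers of the global feature extractor; the local feature extractor is never exchanged.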
Return type:

list[str]
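Because only the global feature extractor is shared, one plausible way to build this list is to filter the model's `state_dict()` keys by the attribute name under which that extractor is registered. The sketch below assumes a hypothetical `TwoExtractorModel` whose extractors live in attributes named `local_module` and `global_module`; the attribute names fl4health actually uses may differ, so this illustrates the idea rather than reproducing the method.

```python
import torch.nn as nn


class TwoExtractorModel(nn.Module):
    """Hypothetical model with a local extractor, a global extractor, and a head."""

    def __init__(self) -> None:
        super().__init__()
        self.local_module = nn.Linear(8, 4)
        self.global_module = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 4))
        self.model_head = nn.Linear(8, 2)

    def layers_to_exchange(self) -> list[str]:
        # Keep only the parameters/buffers registered under the global extractor.
        return [name for name in self.state_dict() if name.startswith("global_module.")]


model = TwoExtractorModel()
print(model.layers_to_exchange())
# ['global_module.0.weight', 'global_module.0.bias', 'global_module.2.weight', 'global_module.2.bias']
```

The trailing "." in the prefix matters: without it, an attribute such as a hypothetical `global_module_extra` would also match. The ReLU at index 1 has no parameters, so it contributes no keys.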