fl4health.model_bases.fenda_base module¶
- class FendaModel(local_module, global_module, model_head)[source]¶
Bases: PartialLayerExchangeModel, ParallelSplitModel
- __init__(local_module, global_module, model_head)[source]¶
This is the base model to be used when implementing FENDA-FL models and training. A FENDA model is essentially a parallel split model (i.e. it has two parallel feature extractors), where only one feature extractor is exchanged with the server (the global_module) while the other remains local to the client itself.
- Parameters:
local_module (nn.Module) – Feature extraction module that is NOT exchanged with the server
global_module (nn.Module) – Feature extraction module that is exchanged with the server and aggregated with other client modules
model_head (ParallelSplitHeadModule) – The model head that takes the output features from both the local and global modules to produce a prediction.
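To make the parallel-split structure concrete, the snippet below is a minimal, plain-PyTorch stand-in for what a FendaModel composes; it is not the library implementation, and the ToyFendaStyleModel name, layer sizes, and concatenation join are illustrative assumptions. With the actual class, the two extractors are passed as local_module and global_module, and the join/prediction logic lives in the ParallelSplitHeadModule passed as model_head.

import torch
import torch.nn as nn

class ToyFendaStyleModel(nn.Module):
    """Conceptual stand-in: two parallel extractors whose features are joined by a head."""

    def __init__(self) -> None:
        super().__init__()
        # In FENDA-FL, this extractor stays on the client.
        self.local_module = nn.Sequential(nn.Linear(32, 16), nn.ReLU())
        # In FENDA-FL, this extractor is exchanged with the server and aggregated across clients.
        self.global_module = nn.Sequential(nn.Linear(32, 16), nn.ReLU())
        # The head consumes the joined (here: concatenated) features to produce predictions.
        self.head = nn.Linear(32, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        local_features = self.local_module(x)
        global_features = self.global_module(x)
        return self.head(torch.cat([local_features, global_features], dim=1))

preds = ToyFendaStyleModel()(torch.randn(8, 32))  # shape: (8, 2)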
- class FendaModelWithFeatureState(local_module, global_module, model_head, flatten_features=False)[source]¶
Bases: FendaModel
- __init__(local_module, global_module, model_head, flatten_features=False)[source]¶
This is the base model to be used when implementing FENDA-FL models and training in settings where extraction and storage of the latent features produced by each of the parallel feature extractors is required. It is a FendaModel, but the feature-space outputs are guaranteed to be stored alongside the predictions under the keys ‘local_features’ and ‘global_features’. The user also has the option to “flatten” these features to have shape (batch_size, num_features).
- Parameters:
local_module (nn.Module) – Feature extraction module that is NOT exchanged with the server
global_module (nn.Module) – Feature extraction module that is exchanged with the server and aggregated with other client modules
model_head (ParallelSplitHeadModule) – The model head that takes the output features from both the local and global modules to produce a prediction.
flatten_features (bool, optional) – Whether the output features should be flattened to have shape (batch_size, num_features). Defaults to False.
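As a point of reference for the flatten_features option, the short sketch below shows the tensor reshape implied by flattening to shape (batch_size, num_features); the feature-map shapes are illustrative assumptions, and how the library performs the flatten internally is not shown here.

import torch

# flatten_features=False leaves the feature maps as produced, e.g. (batch, channels, h, w).
features = torch.randn(8, 4, 5, 5)

# flatten_features=True corresponds to collapsing everything after the batch dimension,
# i.e. shape (batch_size, num_features) = (8, 100).
flattened = torch.flatten(features, start_dim=1)
print(flattened.shape)  # torch.Size([8, 100])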
- forward(input)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
tuple[dict[str, Tensor], dict[str, Tensor]]
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
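The sketch below shows one way to unpack an output with the documented tuple-of-dicts structure. The unpack_fenda_output helper and the fake_output stand-in (including its “prediction” key and tensor shapes) are hypothetical; only the return type and the ‘local_features’/‘global_features’ keys come from the documentation above.

import torch
from torch import Tensor

def unpack_fenda_output(
    output: tuple[dict[str, Tensor], dict[str, Tensor]],
) -> tuple[dict[str, Tensor], Tensor, Tensor]:
    # Hypothetical helper: split a forward result with the documented structure into the
    # prediction dictionary and the two stored feature tensors.
    preds, features = output
    return preds, features["local_features"], features["global_features"]

# Stand-in output so the helper can be exercised directly; the "prediction" key and the
# tensor shapes are illustrative, only the feature keys are documented.
fake_output = (
    {"prediction": torch.randn(8, 2)},
    {"local_features": torch.randn(8, 16), "global_features": torch.randn(8, 16)},
)
preds, local_feats, global_feats = unpack_fenda_output(fake_output)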