fl4health.model_bases.fenda_base module
- class FendaModel(local_module, global_module, model_head)
Bases: PartialLayerExchangeModel, ParallelSplitModel
- __init__(local_module, global_module, model_head)
This is the base model to be used when implementing FENDA-FL models and training. A FENDA model is essentially a parallel split model (i.e., it has two parallel feature extractors), where only one feature extractor (the global_module) is exchanged with the server while the other remains local to the client itself.
- Parameters:
local_module (nn.Module) – Feature extraction module that is NOT exchanged with the server
global_module (nn.Module) – Feature extraction module that is exchanged with the server and aggregated with other client modules
model_head (ParallelSplitHeadModule) – The model head that takes the output features from both the local and global modules to produce a prediction.
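The parallel split structure described above can be sketched in plain PyTorch. This is a minimal illustration only, not the fl4health implementation: the class names `SketchFenda` and `ConcatHead` are hypothetical, the head is assumed to concatenate the two feature sets, and selecting exchanged parameters by the `global_module.` name prefix is an assumption made purely to show which weights would leave the client.

```python
import torch
import torch.nn as nn


class ConcatHead(nn.Module):
    """Hypothetical head: concatenates both feature sets, then classifies."""

    def __init__(self, in_features: int, n_classes: int) -> None:
        super().__init__()
        self.classifier = nn.Linear(in_features, n_classes)

    def forward(self, local_features: torch.Tensor, global_features: torch.Tensor) -> torch.Tensor:
        joined = torch.cat([local_features, global_features], dim=1)
        return self.classifier(joined)


class SketchFenda(nn.Module):
    """Plain-PyTorch stand-in for the parallel split idea (not the fl4health class)."""

    def __init__(self, local_module: nn.Module, global_module: nn.Module, model_head: nn.Module) -> None:
        super().__init__()
        self.local_module = local_module    # stays on the client
        self.global_module = global_module  # exchanged with the server
        self.model_head = model_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both extractors see the same input; the head combines their outputs.
        return self.model_head(self.local_module(x), self.global_module(x))


model = SketchFenda(nn.Linear(8, 4), nn.Linear(8, 4), ConcatHead(8, 2))
predictions = model(torch.randn(3, 8))

# Assumption for illustration: only the global extractor's parameters are shared.
exchanged = [name for name in model.state_dict() if name.startswith("global_module.")]
```

In this sketch the local extractor and the head never appear in `exchanged`, which mirrors the documented behaviour that only the global_module is aggregated across clients.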
- class FendaModelWithFeatureState(local_module, global_module, model_head, flatten_features=False)
Bases: FendaModel
- __init__(local_module, global_module, model_head, flatten_features=False)
This is the base model to be used when implementing FENDA-FL models and training when extraction and storage of the latent features produced by each of the parallel feature extractors is required/desired. This is a FENDA model, but the feature space outputs are guaranteed to be stored with the keys “local_features” and “global_features” along with the predictions. The user also has the option to “flatten” these features to be of shape batch_size x all features.
- Parameters:
local_module (nn.Module) – Feature extraction module that is NOT exchanged with the server
global_module (nn.Module) – Feature extraction module that is exchanged with the server and aggregated with other client modules
model_head (ParallelSplitHeadModule) – The model head that takes the output features from both the local and global modules to produce a prediction.
flatten_features (bool, optional) – Whether the output features should be flattened to have shape batch_size x all features. Defaults to False.
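The effect of the flattening option and the documented key names can be sketched as follows. The helper `package_outputs` is hypothetical and is not part of fl4health; it only assumes that flattening means reshaping each feature tensor to (batch_size, -1) and that predictions and features are returned as two dictionaries.

```python
import torch


def package_outputs(
    prediction: torch.Tensor,
    local_features: torch.Tensor,
    global_features: torch.Tensor,
    flatten_features: bool = False,
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """Hypothetical helper mirroring the documented output layout."""
    if flatten_features:
        # Collapse every non-batch dimension into a single feature axis.
        local_features = local_features.reshape(local_features.shape[0], -1)
        global_features = global_features.reshape(global_features.shape[0], -1)
    predictions = {"prediction": prediction}
    features = {"local_features": local_features, "global_features": global_features}
    return predictions, features


# Dummy tensors standing in for a batch of 4 with convolutional feature maps.
preds, feats = package_outputs(
    torch.zeros(4, 2),
    torch.zeros(4, 3, 5, 5),
    torch.zeros(4, 6, 5, 5),
    flatten_features=True,
)
```

With flatten_features left at False, the feature tensors in this sketch keep their original shapes, so downstream code that needs 2D features would have to reshape them itself.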
- forward(input)
Maps the input through the FENDA model's local and global feature extractors and the classification head.
- Parameters:
input (torch.Tensor) – input is expected to be of shape (batch_size, *)
- Returns:
Tuple of predictions and feature maps. FENDA predictions are stored under the key “prediction”. The features from the local and global feature extraction modules are stored under the keys “local_features” and “global_features”, respectively.
- Return type: