fl4health.model_bases.perfcl_base module
- class PerFclModel(local_module, global_module, model_head)
Bases: PartialLayerExchangeModel, ParallelSplitModel
- __init__(local_module, global_module, model_head)
Model used by PerFCL clients to train models with the PerFCL approach. These models are of type ParallelSplitModel and have two distinct feature extractors. One feature extractor is exchanged with the server and aggregated, while the other remains local. Each extractor produces latent features, which are flattened and stored under the keys “local_features” and “global_features” alongside the predictions.
- Parameters:
local_module (nn.Module) – Feature extraction module that is NOT exchanged with the server.
global_module (nn.Module) – Feature extraction module that is exchanged with the server and aggregated with the corresponding modules of other clients.
model_head (ParallelSplitHeadModule) – The model head that takes the output features from both the local and global modules to produce a prediction.
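The parallel-split data flow described above can be illustrated with plain Python callables standing in for `nn.Module` instances. This is a minimal, stdlib-only sketch under stated assumptions, not the fl4health implementation. `ToyPerFclModel`, `flatten_per_sample`, `select_exchanged_layers`, the toy modules, and the `global_module.` parameter-name prefix are all hypothetical. Nested lists play the role of tensors, with the first axis as the batch dimension. The sketch selects only the global extractor's parameters for exchange and does not model what happens to the head.

```python
from typing import Any, Callable

# Hypothetical stand-in: a "tensor" is a nested list whose first axis is the batch.
Batch = list[Any]


def flatten_per_sample(batch: Batch) -> list[list[float]]:
    """Flatten each sample to 1-D while keeping the batch axis
    (analogous to torch.flatten(x, start_dim=1))."""

    def _flat(x: Any) -> list[float]:
        if isinstance(x, list):
            out: list[float] = []
            for item in x:
                out.extend(_flat(item))
            return out
        return [x]

    return [_flat(sample) for sample in batch]


class ToyPerFclModel:
    """Hypothetical sketch of the PerFCL parallel-split data flow (not fl4health code)."""

    def __init__(self, local_module: Callable, global_module: Callable, model_head: Callable) -> None:
        self.local_module = local_module    # stays on the client
        self.global_module = global_module  # exchanged with the server and aggregated
        self.model_head = model_head        # consumes both feature sets

    def forward(self, batch: Batch) -> tuple[dict, dict]:
        # Both extractors receive the same input in parallel.
        local_out = self.local_module(batch)
        global_out = self.global_module(batch)
        predictions = {"prediction": self.model_head(local_out, global_out)}
        # Latent features are flattened per sample and returned alongside the predictions.
        features = {
            "local_features": flatten_per_sample(local_out),
            "global_features": flatten_per_sample(global_out),
        }
        return predictions, features


def select_exchanged_layers(parameter_names: list[str]) -> list[str]:
    """Keep only the global extractor's parameters (assumed 'global_module.' prefix)."""
    return [name for name in parameter_names if name.startswith("global_module.")]


# Toy modules for a batch of 2x2 samples.
def double(batch: Batch) -> Batch:
    return [[[2 * v for v in row] for row in sample] for sample in batch]


def add_one(batch: Batch) -> Batch:
    return [[[v + 1 for v in row] for row in sample] for sample in batch]


def sum_head(local_out: Batch, global_out: Batch) -> list[float]:
    # Toy head: sums the flattened local and global features of each sample.
    pairs = zip(flatten_per_sample(local_out), flatten_per_sample(global_out))
    return [sum(loc) + sum(glob) for loc, glob in pairs]


model = ToyPerFclModel(local_module=double, global_module=add_one, model_head=sum_head)
predictions, features = model.forward([[[1, 2], [3, 4]]])
print(predictions)  # {'prediction': [34]}
print(features["local_features"], features["global_features"])  # [[2, 4, 6, 8]] [[2, 3, 4, 5]]
print(select_exchanged_layers(["local_module.fc.weight", "global_module.fc.weight", "global_module.fc.bias"]))
# ['global_module.fc.weight', 'global_module.fc.bias']
```

In PyTorch, the keys returned by `state_dict()` for a registered submodule are prefixed with its attribute name. Filtering on a prefix is therefore one simple way to separate exchanged parameters from local ones. The actual attribute names used by the library may differ from the `global_module.` prefix assumed here.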
- forward(input)
Maps the input through the local and global feature extractors of the PerFCL model and then through the classification head.
- Parameters:
input (torch.Tensor) – Input tensor, expected to be of shape (batch_size, *).
- Returns:
Tuple of predictions and feature maps. The PerFCL prediction is stored under the key “prediction”. The features from the local and global feature extraction modules are stored under the keys “local_features” and “global_features”, respectively.
- Return type: