fl4health.model_bases.gpfl_base module
- class CoV(feature_dim)
Bases: Module
- __init__(feature_dim)
Taken from the official implementation at: https://github.com/TsingZ0/GPFL/blob/main/system/flcore/servers/servergp.py
CoV (Conditional Valve) module as described in the GPFL paper. This module consists of two parts:
1) First, it uses the provided context tensor to compute two vectors, $\gamma$ and $\beta$, using the conditional_gamma and conditional_beta sub-modules, respectively. In the paper: $\boldsymbol{\gamma}_i, \boldsymbol{\beta}_i = \text{CoV}(\mathbf{f}_i, \cdot, V)$
2) Then, it applies an affine transformation followed by a ReLU activation to the feature tensor, based on the computed $\gamma$ and $\beta$ vectors. Affine transformation in the paper: $(\boldsymbol{\gamma}_i + \mathbf{1}) \odot \mathbf{f}_i + \boldsymbol{\beta}_i$
The parameters of the sub-modules (the conditional_gamma and conditional_beta modules) are the main components of this module and are optimized during training.
- Parameters:
feature_dim (int) – The dimension of the feature tensor.
- forward(feature_tensor, context)
Uses the context tensor to compute the gamma and beta vectors, then applies a conditional affine transformation to the feature tensor based on those vectors.
- Parameters:
feature_tensor (torch.Tensor) – Output of the base feature extractor.
context (torch.Tensor) – The conditional tensor that could be global or personalized.
- Returns:
The transformed feature tensor after applying the conditional affine transformation.
- Return type:
torch.Tensor
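The two-step CoV computation can be sketched in plain Python to make the math concrete. This is a minimal, framework-free illustration, not the fl4health implementation: the helper names (`linear`, `conditional_affine`, `cov_forward`) are hypothetical, and the conditional_gamma and conditional_beta sub-modules are stood in for by single linear layers, which is an assumption (the real sub-modules may be deeper networks).

```python
def relu(x):
    # ReLU activation on a scalar
    return max(0.0, x)


def linear(weight, bias, x):
    # y = W x + b, with W given as a list of rows
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def conditional_affine(feature, gamma, beta):
    # Element-wise (gamma + 1) * f + beta, followed by ReLU
    return [relu((g + 1.0) * f + b) for f, g, b in zip(feature, gamma, beta)]


def cov_forward(feature, context, gamma_layer, beta_layer):
    # gamma_layer / beta_layer are hypothetical (weight, bias) pairs standing in
    # for the learned conditional_gamma / conditional_beta sub-modules
    gamma = linear(gamma_layer[0], gamma_layer[1], context)
    beta = linear(beta_layer[0], beta_layer[1], context)
    return conditional_affine(feature, gamma, beta)


# With all-zero layers, gamma = beta = 0, so CoV reduces to ReLU(f)
dim = 3
zero_layer = ([[0.0] * dim for _ in range(dim)], [0.0] * dim)
out = cov_forward([1.0, -2.0, 3.0], [0.5, 0.5, 0.5], zero_layer, zero_layer)
print(out)  # [1.0, 0.0, 3.0]
```

The `+ 1` in the affine term means that a gamma of zero leaves the features unscaled instead of zeroing them out, so the module starts close to a plain ReLU on the base features.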
- class Gce(feature_dim, num_classes)
Bases: Module
- __init__(feature_dim, num_classes)
Taken from the official implementation at: https://github.com/TsingZ0/GPFL/blob/main/system/flcore/servers/servergp.py
GCE module as described in the GPFL paper. This module serves as a lookup table of global class embeddings. The embedding matrix (the lookup table) has size (num_classes, feature_dim). The goal is to learn and store representative class embeddings.
- forward(feature_tensor, label)
Performs a forward pass through the GCE module. It computes the cosine similarity between the feature tensors and the class embeddings, and then computes the log softmax loss based on the provided labels.
- Parameters:
feature_tensor (torch.Tensor) – The global features computed by the CoV module.
label (torch.Tensor) – The true label for the input data, which is used to compute the loss.
- Returns:
Log softmax loss.
- Return type:
torch.Tensor
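The GCE forward pass can be sketched in plain Python as follows. This is a hedged, framework-free illustration: it assumes the loss is the batch mean of the negative log-softmax of the true class's cosine similarity (that is, a softmax cross-entropy over cosine scores), and the helper names (`normalize`, `log_softmax`, `gce_loss`) are hypothetical rather than fl4health's API.

```python
import math


def normalize(v, eps=1e-12):
    # Scale v to unit L2 norm (eps guards against division by zero)
    norm = max(math.sqrt(sum(a * a for a in v)), eps)
    return [a / norm for a in v]


def log_softmax(scores):
    # Numerically stable log-softmax over a list of scores
    m = max(scores)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum_exp for s in scores]


def gce_loss(features, labels, class_embeddings):
    # class_embeddings is the (num_classes, feature_dim) lookup table
    unit_embeddings = [normalize(e) for e in class_embeddings]
    total = 0.0
    for feature, label in zip(features, labels):
        unit_feature = normalize(feature)
        # Cosine similarity of this feature with every class embedding
        cosine = [sum(a * b for a, b in zip(unit_feature, e)) for e in unit_embeddings]
        # Negative log-probability of the true class
        total -= log_softmax(cosine)[label]
    # Average over the batch
    return total / len(features)


embeddings = [[1.0, 0.0], [0.0, 1.0]]  # num_classes = 2, feature_dim = 2
print(gce_loss([[2.0, 0.0]], [0], embeddings))  # log(1 + e^-1), about 0.3133
```

Because both the features and the embeddings are normalized, only their angles matter: a feature pointing at its own class embedding gets a low loss regardless of its magnitude.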
- class GpflBaseAndHeadModules(base_module, head_module, flatten_features)
Bases: SequentiallySplitExchangeBaseModel
- __init__(base_module, head_module, flatten_features)
This class holds the main components used for prediction in the GPFL model. It is mainly used so that a single optimizer can be defined for both the base and head modules.
- Parameters:
base_module (nn.Module) – Base feature extractor module that generates a feature tensor from the input.
head_module (nn.Module) – Head module that takes a personalized feature tensor and produces the final predictions.
flatten_features (bool) – Whether the base_module's output features should be flattened or not.
- forward(input)
A wrapper around the default sequential forward pass of the GPFL model base to restrict its usage.
- Parameters:
input (torch.Tensor) – Input to the model forward pass.
- Returns:
Returns the prediction dictionary and the features dictionary.
- Return type:
tuple[torch.Tensor, torch.Tensor]
- class GpflModel(base_module, head_module, feature_dim, num_classes, flatten_features=False)
Bases: PartialLayerExchangeModel
- __init__(base_module, head_module, feature_dim, num_classes, flatten_features=False)
GPFL model base as described in the paper "GPFL: Simultaneously Learning Global and Personalized Feature Information for Personalized Federated Learning" (https://arxiv.org/abs/2308.10279). This base module consists of three main sub-modules: the main_module, which consists of a feature extractor and a head module; the GCE (Global Category Embedding) module; and the CoV (Conditional Valve) module.
- Parameters:
base_module (nn.Module) – Base feature extractor module that generates a feature tensor from the input.
head_module (nn.Module) – Head module that takes a personalized feature tensor and produces the final predictions.
feature_dim (int) – The output dimension of the base feature extractor. This is also the input dimension of the head and CoV modules.
num_classes (int) – The number of classes, used to construct the GCE module.
flatten_features (bool, optional) – Whether the base_module's output features should be flattened or not. Defaults to False.
- forward(input, global_conditional_input, personalized_conditional_input)
There are two types of forward passes in this model base. The first is performed during training:
1) The input is passed through the base feature extractor.
2) The CoV module maps the extracted features into two feature tensors corresponding to local and global features. The CoV module requires the global_conditional_input and personalized_conditional_input tensors, which are used to condition its output. These tensors are computed by the clients at the beginning of each round.
3) The local_features are fed into the head_module to produce class predictions.
4) The global_conditional_input is used to compute the global_features, which are used in the loss calculations and are returned only during training.
The second type of forward pass happens during evaluation:
1) The input is passed through the base feature extractor.
2) The CoV module generates the local_features.
3) These local features are passed through the head module to produce the final predictions.
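The two forward passes can be sketched with plain callables standing in for the real modules. This is a hedged illustration, not fl4health's code: the function `gpfl_forward`, the explicit `training` flag (the real method takes no mode argument), the `"prediction"` dictionary key, and the toy modules are all hypothetical; only the `local_features` and `global_features` names come from the description above. Flattening is shown for a single sample, whereas a batched implementation would flatten every dimension except the batch dimension.

```python
def flatten(nested):
    # Recursively flatten one sample's nested feature list into a 1-D list
    flat = []
    for item in nested:
        if isinstance(item, list):
            flat.extend(flatten(item))
        else:
            flat.append(item)
    return flat


def gpfl_forward(x, base, cov, head, global_ctx, personal_ctx, training, flatten_features=False):
    # base, cov and head are callables standing in for the real modules;
    # `training` is a hypothetical flag for train vs. evaluation mode
    features = base(x)
    if flatten_features:
        features = flatten(features)
    # The CoV module, conditioned on the personalized input, yields local features
    local_features = cov(features, personal_ctx)
    predictions = {"prediction": head(local_features)}
    intermediate = {"local_features": local_features}
    if training:
        # Global features are only needed for the training losses
        intermediate["global_features"] = cov(features, global_ctx)
    return predictions, intermediate


def toy_base(x):
    # Maps each input value v to a 2-element feature [v, 2v]
    return [[v, 2 * v] for v in x]


def toy_cov(features, ctx):
    # Adds the conditioning value to every feature
    return [v + ctx for v in features]


preds, feats = gpfl_forward([1, 3], toy_base, toy_cov, sum, 10, 0, training=True, flatten_features=True)
print(preds, feats)
# {'prediction': 12} {'local_features': [1, 2, 3, 6], 'global_features': [11, 12, 13, 16]}
```

The same CoV callable is reused with two different conditioning inputs, which mirrors the description of one CoV module producing both the local and the global feature tensors.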
- Parameters:
input (torch.Tensor) – Input tensor to be fed into the feature extractor.
global_conditional_input (torch.Tensor) – The conditional input tensor used by the CoV module to generate the global features.
personalized_conditional_input (torch.Tensor) – The conditional input tensor used by the CoV module to generate the local features.
- Returns:
A tuple in which the first element contains a dictionary of predictions and the second element contains intermediate features indexed by name.
- Return type: