fl4health.losses.perfcl_loss module
- class PerFclLoss(device, global_feature_loss_temperature=0.5, local_feature_loss_temperature=0.5)
Bases: torch.nn.Module
- forward(local_features, old_local_features, global_features, old_global_features, initial_global_features)
PerFCL loss, implemented following https://www.sciencedirect.com/science/article/pii/S0031320323002078. The paper introduces two contrastive loss functions:
1. The first aims to increase the similarity between the current global features ($z_s$) and the aggregated global features ($z_g$, saved at the start of client-side training), treated as a positive pair, while decreasing the similarity between the current global features ($z_s$) and the old global features ($\hat{z}_s$) from the end of the previous client-side training, treated as a negative pair.
2. The second aims to increase the similarity between the current local features ($z_p$) and the old local features ($\hat{z}_p$) from the end of the previous client-side training, treated as a positive pair, while decreasing the similarity between the current local features ($z_p$) and the aggregated global features ($z_g$, saved at the start of client-side training), treated as a negative pair.
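Written out as formulas, the two terms take roughly the following form. This is a sketch, not taken from the fl4health source: it assumes the MOON-style contrastive form with cosine similarity $\operatorname{sim}$, a mean over a batch of $B$ samples, and temperatures $\tau_g$ = global_feature_loss_temperature and $\tau_l$ = local_feature_loss_temperature from the constructor.

```latex
\mathcal{L}_{\mathrm{global}} = -\frac{1}{B}\sum_{i=1}^{B}
  \log\frac{\exp\left(\operatorname{sim}(z_s^{(i)}, z_g^{(i)})/\tau_g\right)}
           {\exp\left(\operatorname{sim}(z_s^{(i)}, z_g^{(i)})/\tau_g\right)
            + \exp\left(\operatorname{sim}(z_s^{(i)}, \hat{z}_s^{(i)})/\tau_g\right)}

\mathcal{L}_{\mathrm{local}} = -\frac{1}{B}\sum_{i=1}^{B}
  \log\frac{\exp\left(\operatorname{sim}(z_p^{(i)}, \hat{z}_p^{(i)})/\tau_l\right)}
           {\exp\left(\operatorname{sim}(z_p^{(i)}, \hat{z}_p^{(i)})/\tau_l\right)
            + \exp\left(\operatorname{sim}(z_p^{(i)}, z_g^{(i)})/\tau_l\right)}

\operatorname{sim}(a, b) = \frac{a^\top b}{\lVert a \rVert \, \lVert b \rVert}
```

Each term is a two-way softmax cross-entropy in which the positive pair plays the role of the correct class, so minimizing it raises the positive similarity relative to the negative one.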
- Parameters:
local_features (torch.Tensor) – Features produced by the local feature extractor of the model during client-side training. Denoted as $z_p$ in the original paper. Shape (batch_size, n_features)
old_local_features (torch.Tensor) – Features produced by the FINAL local feature extractor of the model from the PREVIOUS server round. Denoted as $\hat{z}_p$ in the original paper. Shape (batch_size, n_features)
global_features (torch.Tensor) – Features produced by the global feature extractor of the model during client-side training. Denoted as $z_s$ in the original paper. Shape (batch_size, n_features)
old_global_features (torch.Tensor) – Features produced by the FINAL global feature extractor of the model from the PREVIOUS server round. Denoted as $\hat{z}_s$ in the original paper. Shape (batch_size, n_features)
initial_global_features (torch.Tensor) – Features produced by the INITIAL global feature extractor of the model at the start of client-side training. This feature extractor holds the weights AGGREGATED across clients. Denoted as $z_g$ in the original paper. Shape (batch_size, n_features)
- Returns:
- Tuple containing the two components of the PerFCL loss, to be weighted and summed. The first tensor is the global feature loss; the second is the local feature loss.
- Return type:
tuple[torch.Tensor, torch.Tensor]
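To make the roles of the five feature arguments concrete, here is a dependency-free sketch of how the two returned components could be computed and then combined. It is an illustration under stated assumptions, not the fl4health implementation: it works on plain Python lists of feature rows instead of torch.Tensor, assumes the MOON-style cosine-similarity contrastive form with a batch mean, and the helpers (cosine_similarity, contrastive_loss, perfcl_losses, total_objective) as well as the weights global_weight and local_weight are hypothetical. Only the argument and temperature names mirror the signatures documented above.

```python
import math
from typing import Sequence

Vector = Sequence[float]  # one feature row, i.e. n_features floats


def cosine_similarity(a: Vector, b: Vector) -> float:
    """Cosine similarity of two feature vectors; the epsilon guards against zero norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, 1e-8)


def contrastive_loss(
    anchors: Sequence[Vector],
    positives: Sequence[Vector],
    negatives: Sequence[Vector],
    temperature: float,
) -> float:
    """Batch mean of -log(exp(s_pos) / (exp(s_pos) + exp(s_neg))), where s = cosine / temperature."""
    total = 0.0
    for anchor, positive, negative in zip(anchors, positives, negatives):
        s_pos = cosine_similarity(anchor, positive) / temperature
        s_neg = cosine_similarity(anchor, negative) / temperature
        # Two-logit log-sum-exp, shifted by the max logit for numerical stability.
        m = max(s_pos, s_neg)
        log_denominator = m + math.log(math.exp(s_pos - m) + math.exp(s_neg - m))
        total += log_denominator - s_pos
    return total / len(anchors)


def perfcl_losses(
    local_features: Sequence[Vector],
    old_local_features: Sequence[Vector],
    global_features: Sequence[Vector],
    old_global_features: Sequence[Vector],
    initial_global_features: Sequence[Vector],
    global_feature_loss_temperature: float = 0.5,
    local_feature_loss_temperature: float = 0.5,
) -> tuple[float, float]:
    # Global term: pull z_s toward z_g (positive), push it away from hat{z}_s (negative).
    global_feature_loss = contrastive_loss(
        global_features, initial_global_features, old_global_features,
        global_feature_loss_temperature,
    )
    # Local term: pull z_p toward hat{z}_p (positive), push it away from z_g (negative).
    local_feature_loss = contrastive_loss(
        local_features, old_local_features, initial_global_features,
        local_feature_loss_temperature,
    )
    # Same order as the documented return value: (global, local).
    return global_feature_loss, local_feature_loss


def total_objective(
    task_loss: float,
    terms: tuple[float, float],
    global_weight: float,
    local_weight: float,
) -> float:
    # Hypothetical weighting: the docs only say the two terms are "weighted and summed";
    # adding them to a task loss (e.g. cross-entropy) is an assumption of this sketch.
    global_feature_loss, local_feature_loss = terms
    return task_loss + global_weight * global_feature_loss + local_weight * local_feature_loss
```

For example, when the current global features equal the aggregated ones ($z_s = z_g$) and are orthogonal to the old ones, the global term drops to $\log(1 + e^{-2}) \approx 0.127$ at the default temperature of 0.5, while a local feature that matches its negative $z_g$ and is orthogonal to its positive yields the much larger $\log(1 + e^{2}) \approx 2.127$.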