fl4health.losses.perfcl_loss module

class PerFclLoss(device, global_feature_loss_temperature=0.5, local_feature_loss_temperature=0.5)

Bases: Module

__init__(device, global_feature_loss_temperature=0.5, local_feature_loss_temperature=0.5)

Loss function for local model training with the PerFCL method (https://ieeexplore.ieee.org/document/10020518/). It combines two separate MOON contrastive losses: one on the global features and one on the local features.

Parameters:
  • device (torch.device) – Device onto which this loss should be transferred.

  • global_feature_loss_temperature (float, optional) – Temperature for the contrastive loss associated with the global features. Defaults to 0.5.

  • local_feature_loss_temperature (float, optional) – Temperature for the contrastive loss associated with the local features. Defaults to 0.5.

forward(local_features, old_local_features, global_features, old_global_features, initial_global_features)

Computes the PerFCL loss, implemented based on https://www.sciencedirect.com/science/article/pii/S0031320323002078.

The paper introduces two contrastive loss functions:

  • The first increases the similarity between the current global features (\(z_s\)) and the aggregated global features (\(z_g\), saved at the start of client-side training), treated as a positive pair, while decreasing the similarity between the current global features (\(z_s\)) and the old global features (\(\hat{z}_s\)) from the end of the previous client-side training, treated as a negative pair.

  • The second increases the similarity between the current local features (\(z_p\)) and the old local features (\(\hat{z}_p\)) from the end of the previous client-side training, treated as a positive pair, while decreasing the similarity between the current local features (\(z_p\)) and the aggregated global features (\(z_g\), saved at the start of client-side training), treated as a negative pair.
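Assuming the standard MOON formulation, with \(\mathrm{sim}\) denoting cosine similarity, \(\tau_g\) the global_feature_loss_temperature and \(\tau_l\) the local_feature_loss_temperature, the two per-sample losses described above can be sketched as follows (batch losses being the mean over samples). This is a reading of the bullets above, and the paper's exact notation may differ:

```latex
\ell_{\mathrm{global}} = -\log \frac{\exp\left(\mathrm{sim}(z_s, z_g)/\tau_g\right)}{\exp\left(\mathrm{sim}(z_s, z_g)/\tau_g\right) + \exp\left(\mathrm{sim}(z_s, \hat{z}_s)/\tau_g\right)}

\ell_{\mathrm{local}} = -\log \frac{\exp\left(\mathrm{sim}(z_p, \hat{z}_p)/\tau_l\right)}{\exp\left(\mathrm{sim}(z_p, \hat{z}_p)/\tau_l\right) + \exp\left(\mathrm{sim}(z_p, z_g)/\tau_l\right)}
```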

Parameters:
  • local_features (torch.Tensor) – Features produced by the local feature extractor of the model during client-side training. Denoted as \(z_p\) in the original paper. Shape: (batch_size, n_features).

  • old_local_features (torch.Tensor) – Features produced by the FINAL local feature extractor of the model from the PREVIOUS server round. Denoted as \(\hat{z}_p\) in the original paper. Shape: (batch_size, n_features).

  • global_features (torch.Tensor) – Features produced by the global feature extractor of the model during client-side training. Denoted as \(z_s\) in the original paper. Shape: (batch_size, n_features).

  • old_global_features (torch.Tensor) – Features produced by the FINAL global feature extractor of the model from the PREVIOUS server round. Denoted as \(\hat{z}_s\) in the original paper. Shape: (batch_size, n_features).

  • initial_global_features (torch.Tensor) – Features produced by the INITIAL global feature extractor of the model at the start of client-side training, whose weights are the ones AGGREGATED across clients. Denoted as \(z_g\) in the original paper. Shape: (batch_size, n_features).
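The mapping of these five inputs onto the two contrastive losses can be sketched in PyTorch. This is a minimal sketch under stated assumptions, not the fl4health implementation: it assumes per-sample cosine similarity and a two-way cross-entropy with the positive pair at index 0, averaged over the batch. The names `_moon_loss` and `perfcl_loss_sketch` are hypothetical.

```python
import torch
import torch.nn.functional as F


def _moon_loss(
    z: torch.Tensor, z_pos: torch.Tensor, z_neg: torch.Tensor, temperature: float
) -> torch.Tensor:
    """Hypothetical MOON-style contrastive loss; all inputs are (batch_size, n_features)."""
    # Per-sample cosine similarities, each of shape (batch_size,)
    pos_sim = F.cosine_similarity(z, z_pos, dim=1)
    neg_sim = F.cosine_similarity(z, z_neg, dim=1)
    # Logits of shape (batch_size, 2); the positive pair sits at index 0
    logits = torch.stack([pos_sim, neg_sim], dim=1) / temperature
    targets = torch.zeros(z.size(0), dtype=torch.long, device=z.device)
    # Cross-entropy with target 0 is -log(exp(pos/t) / (exp(pos/t) + exp(neg/t))),
    # averaged over the batch
    return F.cross_entropy(logits, targets)


def perfcl_loss_sketch(
    local_features: torch.Tensor,
    old_local_features: torch.Tensor,
    global_features: torch.Tensor,
    old_global_features: torch.Tensor,
    initial_global_features: torch.Tensor,
    global_feature_loss_temperature: float = 0.5,
    local_feature_loss_temperature: float = 0.5,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical stand-in for forward(); returns (global loss, local loss)."""
    # Global loss: (z_s, z_g) is the positive pair, (z_s, z_s_hat) the negative pair
    global_feature_loss = _moon_loss(
        global_features, initial_global_features, old_global_features,
        global_feature_loss_temperature,
    )
    # Local loss: (z_p, z_p_hat) is the positive pair, (z_p, z_g) the negative pair
    local_feature_loss = _moon_loss(
        local_features, old_local_features, initial_global_features,
        local_feature_loss_temperature,
    )
    return global_feature_loss, local_feature_loss


# Usage with random features of shape (batch_size, n_features)
torch.manual_seed(0)
batch_size, n_features = 8, 16
features = [torch.randn(batch_size, n_features) for _ in range(5)]
global_loss, local_loss = perfcl_loss_sketch(*features)
```

Phrasing each term as a cross-entropy over two logits keeps the computation numerically stable, since `F.cross_entropy` applies log-softmax internally.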

Returns:

Tuple containing the two components of the PerFCL loss, to be weighted and summed. The first tensor is the global feature loss; the second is the local feature loss.

Return type:

tuple[torch.Tensor, torch.Tensor]
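Because the two components are returned separately, the caller chooses how to weight them. A minimal sketch of that final step, assuming they are added to a task loss (for example, a cross-entropy on the model's predictions) with two scalar weights; `combine_losses` and its weight parameters are hypothetical and not part of the fl4health API:

```python
import torch


def combine_losses(
    task_loss: torch.Tensor,
    perfcl_losses: tuple[torch.Tensor, torch.Tensor],
    global_feature_loss_weight: float = 1.0,
    local_feature_loss_weight: float = 1.0,
) -> torch.Tensor:
    # Hypothetical helper: forward() returns (global feature loss, local feature loss)
    global_feature_loss, local_feature_loss = perfcl_losses
    # Weighted sum of the task loss and both contrastive components
    return (
        task_loss
        + global_feature_loss_weight * global_feature_loss
        + local_feature_loss_weight * local_feature_loss
    )


total = combine_losses(torch.tensor(1.0), (torch.tensor(2.0), torch.tensor(3.0)), 0.5, 0.1)
```

Here `total` is 1.0 + 0.5 * 2.0 + 0.1 * 3.0 = 2.3.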