fl4health.losses.fenda_loss_config module

class ConstrainedFendaLossContainer(perfcl_loss_config, cosine_similarity_loss_config, contrastive_loss_config)

Bases: object

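__init__(perfcl_loss_config, cosine_similarity_loss_config, contrastive_loss_config)

Container to gather all of the possible loss functions used in constrained FENDA model optimization.

Parameters:
  • perfcl_loss_config (PerFclLossContainer | None) – Configuration for the PerFCL loss, or None if the PerFCL loss is not used.

  • cosine_similarity_loss_config (CosineSimilarityLossContainer | None) – Configuration for the cosine similarity loss, or None if it is not used.

  • contrastive_loss_config (MoonContrastiveLossContainer | None) – Configuration for the MOON contrastive loss, or None if it is not used.
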
compute_contrastive_loss(features, positive_pairs, negative_pairs)

Compute the contrastive loss, if one is configured, using the stored configuration.

Parameters:
  • features (torch.Tensor) – Features from the model.

  • positive_pairs (torch.Tensor) – Positive pair features to compare to.

  • negative_pairs (torch.Tensor) – Negative pair features to compare to.

Returns:

Contrastive loss computed from the provided features and pairs.

Return type:

torch.Tensor
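The sketch below illustrates one way a MOON-style contrastive loss can be computed from features, positive_pairs, and negative_pairs. It is not the fl4health implementation: the function name moon_contrastive_loss is hypothetical, and it assumes features and positive_pairs have shape (batch, dim) while negative_pairs stacks one or more negatives as (n_neg, batch, dim). Similarities are cosine similarities divided by a temperature, and the loss is the cross-entropy of selecting the positive.

```python
import torch
import torch.nn.functional as F


def moon_contrastive_loss(
    features: torch.Tensor,        # (batch, dim), assumed shape
    positive_pairs: torch.Tensor,  # (batch, dim), one positive per sample (assumption)
    negative_pairs: torch.Tensor,  # (n_neg, batch, dim), one or more negatives (assumption)
    temperature: float = 0.5,
) -> torch.Tensor:
    # Cosine similarity to the positive, shape (batch, 1)
    pos_sim = F.cosine_similarity(features, positive_pairs, dim=-1).unsqueeze(1)
    # Cosine similarity to each negative: (n_neg, batch), transposed to (batch, n_neg)
    expanded = features.unsqueeze(0).expand_as(negative_pairs)
    neg_sim = F.cosine_similarity(expanded, negative_pairs, dim=-1).T
    # Scale by temperature; the positive occupies column 0 of the logits
    logits = torch.cat([pos_sim, neg_sim], dim=1) / temperature
    targets = torch.zeros(features.size(0), dtype=torch.long, device=features.device)
    # Cross-entropy = -log(exp(pos / t) / sum of exp(sim / t) over positive and negatives)
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
feats = torch.randn(4, 8)
# Features identical to the positive and opposite to the negative
good = moon_contrastive_loss(feats, feats.clone(), (-feats).unsqueeze(0))
# Roles swapped: features oppose the positive and match the negative
bad = moon_contrastive_loss(feats, -feats, feats.unsqueeze(0))
print(good.item(), bad.item())
```

When the features match the positive and oppose the negative, the loss is close to zero; swapping the roles makes it large.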

compute_cosine_similarity_loss(first_features, second_features)

Compute the cosine similarity loss, if one is configured, using the stored configuration.

Parameters:
  • first_features (torch.Tensor) – First set of features in the cosine comparison.

  • second_features (torch.Tensor) – Second set of features in the cosine comparison.

Returns:

Cosine similarity loss between the provided features.

Return type:

torch.Tensor
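As a minimal sketch, assume the cosine similarity loss is the mean absolute cosine similarity between paired rows of the two feature tensors (an assumption; the formula is not stated on this page). Under that assumption the loss is 0 when the two feature sets are orthogonal and 1 when they are parallel or anti-parallel. The function name cosine_similarity_loss is hypothetical.

```python
import torch
import torch.nn.functional as F


def cosine_similarity_loss(first_features: torch.Tensor, second_features: torch.Tensor) -> torch.Tensor:
    # Row-wise cosine similarity between (batch, dim) tensors, shape (batch,)
    cos_sim = F.cosine_similarity(first_features, second_features, dim=1)
    # Assumed loss form: absolute value so that both alignment and anti-alignment are penalized
    return cos_sim.abs().mean()


orthogonal = cosine_similarity_loss(torch.tensor([[1.0, 0.0]]), torch.tensor([[0.0, 1.0]]))
anti_parallel = cosine_similarity_loss(torch.tensor([[1.0, 0.0]]), torch.tensor([[-2.0, 0.0]]))
print(orthogonal.item(), anti_parallel.item())  # 0.0 1.0
```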

compute_perfcl_loss(local_features, old_local_features, global_features, old_global_features, initial_global_features)

Compute the PerFCL loss, if one is configured, using the stored configuration.

Parameters:
  • local_features (torch.Tensor) – See the PerFCL loss documentation.

  • old_local_features (torch.Tensor) – See the PerFCL loss documentation.

  • global_features (torch.Tensor) – See the PerFCL loss documentation.

  • old_global_features (torch.Tensor) – See the PerFCL loss documentation.

  • initial_global_features (torch.Tensor) – See the PerFCL loss documentation.

Returns:

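Tuple containing the two PerFCL loss components (a global feature contrastive loss and a local feature contrastive loss) computed from the input values.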

Return type:

tuple[torch.Tensor, torch.Tensor]
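The sketch below shows how two contrastive terms could be formed from the five inputs and returned as a tuple. The pairing is an assumption drawn from the PerFCL idea, not confirmed by this page: the global term pulls global_features toward initial_global_features and away from old_global_features, while the local term pulls local_features toward old_local_features and away from global_features. The helper names two_way_contrastive and perfcl_losses_sketch are hypothetical, and all inputs are assumed to be (batch, dim) tensors.

```python
import torch
import torch.nn.functional as F


def two_way_contrastive(
    anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor, temperature: float
) -> torch.Tensor:
    # One positive and one negative per sample, all shaped (batch, dim)
    pos = F.cosine_similarity(anchor, positive, dim=1)
    neg = F.cosine_similarity(anchor, negative, dim=1)
    logits = torch.stack([pos, neg], dim=1) / temperature
    # The positive is class 0 for every sample
    targets = torch.zeros(anchor.size(0), dtype=torch.long, device=anchor.device)
    return F.cross_entropy(logits, targets)


def perfcl_losses_sketch(
    local_features: torch.Tensor,
    old_local_features: torch.Tensor,
    global_features: torch.Tensor,
    old_global_features: torch.Tensor,
    initial_global_features: torch.Tensor,
    global_temperature: float = 0.5,
    local_temperature: float = 0.5,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Assumed global term: close to initial global features, away from old global features
    global_loss = two_way_contrastive(
        global_features, initial_global_features, old_global_features, global_temperature
    )
    # Assumed local term: close to old local features, away from current global features
    local_loss = two_way_contrastive(local_features, old_local_features, global_features, local_temperature)
    return global_loss, local_loss


torch.manual_seed(0)
inputs = [torch.randn(4, 8) for _ in range(5)]
global_loss, local_loss = perfcl_losses_sketch(*inputs)
print(global_loss.item(), local_loss.item())
```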

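has_contrastive_loss()

Return whether a contrastive loss is configured.

Return type:

bool

has_cosine_similarity_loss()

Return whether a cosine similarity loss is configured.

Return type:

bool

has_perfcl_loss()

Return whether a PerFCL loss is configured.

Return type:

bool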
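The has_* methods suggest a guard pattern: a training step adds an auxiliary loss only when it was configured. The sketch below mirrors that pattern with a hypothetical, simplified class (OptionalLossesSketch) and helper (training_objective); neither is part of fl4health, and the mean absolute cosine similarity used as the loss is an assumption.

```python
from typing import Callable, Optional

import torch
import torch.nn.functional as F

LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


class OptionalLossesSketch:
    """Hypothetical, simplified mirror of the optional-loss pattern (not the fl4health class)."""

    def __init__(self, cosine_loss: Optional[LossFn] = None) -> None:
        self.cosine_loss = cosine_loss

    def has_cosine_similarity_loss(self) -> bool:
        # True only when a cosine similarity loss was supplied at construction time
        return self.cosine_loss is not None

    def compute_cosine_similarity_loss(self, first: torch.Tensor, second: torch.Tensor) -> torch.Tensor:
        assert self.cosine_loss is not None, "no cosine similarity loss configured"
        return self.cosine_loss(first, second)


def abs_cosine(first: torch.Tensor, second: torch.Tensor) -> torch.Tensor:
    # Mean absolute cosine similarity between paired rows (assumed loss form)
    return F.cosine_similarity(first, second, dim=1).abs().mean()


def training_objective(
    task_loss: torch.Tensor, losses: OptionalLossesSketch, local: torch.Tensor, global_: torch.Tensor
) -> torch.Tensor:
    total = task_loss
    # Guard with has_*() so that unconfigured losses are skipped instead of raising
    if losses.has_cosine_similarity_loss():
        total = total + losses.compute_cosine_similarity_loss(local, global_)
    return total


local_feats = torch.tensor([[1.0, 0.0]])
global_feats = torch.tensor([[1.0, 0.0]])
task = torch.tensor(0.5)
without_cos = training_objective(task, OptionalLossesSketch(), local_feats, global_feats)
with_cos = training_objective(task, OptionalLossesSketch(abs_cosine), local_feats, global_feats)
print(without_cos.item(), with_cos.item())  # 0.5 1.5
```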

class CosineSimilarityLossContainer(device, cos_sim_loss_weight)

Bases: object

__init__(device, cos_sim_loss_weight)

Container to hold the different pieces associated with the cosine similarity loss.

Parameters:
  • device (torch.device) – Device to which the loss is sent and on which it is computed.

  • cos_sim_loss_weight (float) – Weight associated with the cosine similarity loss function in optimization.
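A minimal sketch of what the two parameters do, using a hypothetical dataclass (CosineSimilarityLossSketch, not the fl4health class): inputs are moved to the configured device, and the loss is scaled by cos_sim_loss_weight. The mean absolute cosine similarity used here is an assumption about the loss form.

```python
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class CosineSimilarityLossSketch:
    """Hypothetical stand-in for CosineSimilarityLossContainer."""

    device: torch.device
    cos_sim_loss_weight: float

    def weighted_loss(self, first: torch.Tensor, second: torch.Tensor) -> torch.Tensor:
        # Move inputs to the configured device so the loss is computed there
        first, second = first.to(self.device), second.to(self.device)
        # Assumed loss: mean absolute cosine similarity, scaled by the configured weight
        return self.cos_sim_loss_weight * F.cosine_similarity(first, second, dim=1).abs().mean()


config = CosineSimilarityLossSketch(device=torch.device("cpu"), cos_sim_loss_weight=2.0)
x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
y = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
# Row similarities are 1 and 0, so the mean is 0.5 and the weighted loss is 1.0
print(config.weighted_loss(x, y).item())  # 1.0
```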

class MoonContrastiveLossContainer(device, contrastive_loss_weight, temperature=0.5)

Bases: object

__init__(device, contrastive_loss_weight, temperature=0.5)

Container to hold the different pieces associated with the MOON contrastive loss.

Parameters:
  • device (torch.device) – Device to which the loss is sent and on which it is computed.

  • contrastive_loss_weight (float) – Weight associated with the contrastive loss function in optimization.

  • temperature (float, optional) – Temperature parameter on the contrastive loss function. Defaults to 0.5.
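To illustrate the role of the temperature parameter, the sketch below assumes a MOON-style loss with one positive and one negative per sample, computed as cross-entropy over similarity logits divided by the temperature. The helper contrastive_from_similarities is hypothetical, and the similarity values are made up for illustration.

```python
import torch
import torch.nn.functional as F


def contrastive_from_similarities(pos_sim: torch.Tensor, neg_sim: torch.Tensor, temperature: float) -> torch.Tensor:
    # Logits for [positive, negative] per sample, scaled by the temperature
    logits = torch.stack([pos_sim, neg_sim], dim=1) / temperature
    # The positive is class 0 for every sample
    targets = torch.zeros(pos_sim.size(0), dtype=torch.long)
    return F.cross_entropy(logits, targets)


pos = torch.tensor([0.8])  # illustrative cosine similarity to the positive
neg = torch.tensor([0.2])  # illustrative cosine similarity to the negative
losses = {t: contrastive_from_similarities(pos, neg, t).item() for t in (1.0, 0.5, 0.1)}
print(losses)
```

With the positive already closer than the negative, lowering the temperature shrinks the loss because the softmax becomes sharper; if the negative were closer, a lower temperature would instead amplify the penalty.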

class PerFclLossContainer(device, global_feature_contrastive_loss_weight, local_feature_contrastive_loss_weight, global_feature_loss_temperature=0.5, local_feature_loss_temperature=0.5)

Bases: object

__init__(device, global_feature_contrastive_loss_weight, local_feature_contrastive_loss_weight, global_feature_loss_temperature=0.5, local_feature_loss_temperature=0.5)

Container to hold the different pieces associated with the PerFCL loss.

Parameters:
  • device (torch.device) – Device to which the loss is sent and on which it is computed.

  • global_feature_contrastive_loss_weight (float) – Weight on the global feature contrastive loss function.

  • local_feature_contrastive_loss_weight (float) – Weight on the local feature contrastive loss function.

  • global_feature_loss_temperature (float, optional) – Temperature parameter on the global feature contrastive loss function. Defaults to 0.5.

  • local_feature_loss_temperature (float, optional) – Temperature parameter on the local feature contrastive loss function. Defaults to 0.5.
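The two weights let the global and local PerFCL terms contribute independently to the training objective. The sketch below assumes they are combined additively with a task loss; the function combine_perfcl_terms is hypothetical and the loss values are illustrative.

```python
import torch


def combine_perfcl_terms(
    task_loss: torch.Tensor,
    global_feature_loss: torch.Tensor,
    local_feature_loss: torch.Tensor,
    global_feature_contrastive_loss_weight: float,
    local_feature_contrastive_loss_weight: float,
) -> torch.Tensor:
    # Assumed combination: task loss plus each PerFCL term scaled by its own weight
    return (
        task_loss
        + global_feature_contrastive_loss_weight * global_feature_loss
        + local_feature_contrastive_loss_weight * local_feature_loss
    )


# 1.0 + 2.0 * 0.5 + 4.0 * 0.25 = 3.0
total = combine_perfcl_terms(torch.tensor(1.0), torch.tensor(0.5), torch.tensor(0.25), 2.0, 4.0)
print(total.item())  # 3.0
```

Keeping separate weights makes it possible to, for example, emphasize keeping local features distinct from global ones without changing how strongly the global features are anchored.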