fl4health.losses.fenda_loss_config module
- class ConstrainedFendaLossContainer(perfcl_loss_config, cosine_similarity_loss_config, contrastive_loss_config)
Bases: object
- __init__(perfcl_loss_config, cosine_similarity_loss_config, contrastive_loss_config)
Container that gathers the optional loss functions used in constrained FENDA model optimization.
- Parameters:
perfcl_loss_config (PerFclLossContainer | None) – PerFCL loss container. If None, the loss is not used.
cosine_similarity_loss_config (CosineSimilarityLossContainer | None) – Cosine similarity loss container. If None, the loss is not used.
contrastive_loss_config (MoonContrastiveLossContainer | None) – Contrastive loss container. If None, the loss is not used.
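The "None disables the loss" convention above can be sketched as follows. This is a minimal, framework-free illustration, not the fl4health implementation: `LossConfigSketch`, `ConstrainedFendaLossSketch`, and `active_losses` are hypothetical names, and each sub-container is reduced to a single weight.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LossConfigSketch:
    """Hypothetical stand-in for a loss container such as PerFclLossContainer."""

    weight: float


@dataclass
class ConstrainedFendaLossSketch:
    """Hypothetical mirror of ConstrainedFendaLossContainer: a None config disables that loss."""

    perfcl_loss_config: Optional[LossConfigSketch]
    cosine_similarity_loss_config: Optional[LossConfigSketch]
    contrastive_loss_config: Optional[LossConfigSketch]

    def active_losses(self) -> list[str]:
        # Report only the losses whose config was provided (i.e., is not None).
        configs = {
            "perfcl": self.perfcl_loss_config,
            "cosine_similarity": self.cosine_similarity_loss_config,
            "contrastive": self.contrastive_loss_config,
        }
        return [name for name, config in configs.items() if config is not None]


# Usage: enable only the cosine similarity loss.
container = ConstrainedFendaLossSketch(
    perfcl_loss_config=None,
    cosine_similarity_loss_config=LossConfigSketch(weight=1.0),
    contrastive_loss_config=None,
)
print(container.active_losses())  # ['cosine_similarity']
```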
- compute_contrastive_loss(features, positive_pairs, negative_pairs)
Compute the contrastive loss using the stored configuration, if that loss is configured.
- Parameters:
features (torch.Tensor) – features from the model
positive_pairs (torch.Tensor) – positive pair features to compare to
negative_pairs (torch.Tensor) – negative pair features to compare to.
- Returns:
The contrastive loss computed from the provided features.
- Return type:
torch.Tensor
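For intuition, a MOON-style contrastive loss can be sketched as below: cosine similarities are divided by a temperature, then a cross-entropy treats the positive pair as the target. This is a dependency-free sketch under stated assumptions, not the fl4health code. Shapes are simplified to Python lists (one positive vector and a list of negative vectors per sample), which is an assumption rather than the library's tensor layout, and whether the container applies its configured weight afterwards is not modeled. The default `temperature=0.5` mirrors the default in `MoonContrastiveLossContainer`.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def moon_contrastive_loss(
    features: list[list[float]],
    positive_pairs: list[list[float]],
    negative_pairs: list[list[list[float]]],
    temperature: float = 0.5,
) -> float:
    """Mean MOON-style contrastive loss over a batch.

    Assumed layout (illustration only, not the fl4health tensor shapes):
      features[i]       -> feature vector of sample i
      positive_pairs[i] -> one positive vector for sample i
      negative_pairs[i] -> list of negative vectors for sample i
    """
    total = 0.0
    for z, z_pos, z_negs in zip(features, positive_pairs, negative_pairs):
        pos_logit = cosine_similarity(z, z_pos) / temperature
        neg_logits = [cosine_similarity(z, z_neg) / temperature for z_neg in z_negs]
        # Cross-entropy with the positive pair as the target class:
        # -log(exp(pos) / (exp(pos) + sum(exp(neg))))
        log_denominator = math.log(math.exp(pos_logit) + sum(math.exp(n) for n in neg_logits))
        total += log_denominator - pos_logit
    return total / len(features)


# Features aligned with the positive give a small loss; swapping the roles gives a large one.
aligned = moon_contrastive_loss([[1.0, 0.0]], [[1.0, 0.0]], [[[0.0, 1.0]]])
swapped = moon_contrastive_loss([[1.0, 0.0]], [[0.0, 1.0]], [[[1.0, 0.0]]])
print(round(aligned, 4))  # 0.1269
print(round(swapped, 4))  # 2.1269
```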
- compute_cosine_similarity_loss(first_features, second_features)
Compute the cosine similarity loss using the stored configuration, if that loss is configured.
- Parameters:
first_features (torch.Tensor) – first set of features in the cosine comparison
second_features (torch.Tensor) – second set of features in the cosine comparison
- Returns:
cosine similarity loss between the provided features.
- Return type:
torch.Tensor
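One way to turn cosine similarity into a penalty between two feature sets is to average the absolute cosine similarity of paired rows, which is smallest when the rows are orthogonal. The sketch below assumes that formulation; it is an illustration, not a statement of the exact fl4health definition, and `cosine_similarity_loss` here is a hypothetical pure-Python helper.

```python
import math


def cosine_similarity(a: list[float], b: list[float], eps: float = 1e-8) -> float:
    # eps guards against division by zero for all-zero vectors.
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norm, eps)


def cosine_similarity_loss(
    first_features: list[list[float]],
    second_features: list[list[float]],
) -> float:
    """Assumed formulation: mean |cos(first_i, second_i)| over the batch."""
    sims = [abs(cosine_similarity(a, b)) for a, b in zip(first_features, second_features)]
    return sum(sims) / len(sims)


# Orthogonal rows give 0; parallel or anti-parallel rows give 1.
print(cosine_similarity_loss([[1.0, 0.0]], [[0.0, 2.0]]))   # 0.0
print(cosine_similarity_loss([[1.0, 0.0]], [[-3.0, 0.0]]))  # 1.0
```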
- compute_perfcl_loss(local_features, old_local_features, global_features, old_global_features, initial_global_features)
Compute the PerFCL loss using the stored configuration, if that loss is configured.
- Parameters:
local_features (torch.Tensor) – See PerFCL loss documentation
old_local_features (torch.Tensor) – See PerFCL loss documentation
global_features (torch.Tensor) – See PerFCL loss documentation
old_global_features (torch.Tensor) – See PerFCL loss documentation
initial_global_features (torch.Tensor) – See PerFCL loss documentation
- Returns:
The PerFCL loss terms computed from the input features.
- Return type:
tuple[torch.Tensor, torch.Tensor]
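Because this method returns a tuple of two tensors, a caller that needs a single scalar has to combine them. Whether `compute_perfcl_loss` already applies the weights stored in `PerFclLossContainer` is not stated here, so the weighted sum below is only one plausible way to combine the terms. The helper `weighted_perfcl_total` is hypothetical, and the ordering (global-feature term first, local-feature term second) is an assumption.

```python
def weighted_perfcl_total(
    perfcl_terms: tuple[float, float],
    global_feature_contrastive_loss_weight: float,
    local_feature_contrastive_loss_weight: float,
) -> float:
    """Hypothetical helper: collapse the two PerFCL terms into one scalar."""
    # ASSUMPTION: perfcl_terms is ordered (global_feature_loss, local_feature_loss).
    global_term, local_term = perfcl_terms
    return (
        global_feature_contrastive_loss_weight * global_term
        + local_feature_contrastive_loss_weight * local_term
    )


print(weighted_perfcl_total((2.0, 4.0), 1.0, 0.5))  # 4.0
```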
- class CosineSimilarityLossContainer(device, cos_sim_loss_weight)
Bases: object
- __init__(device, cos_sim_loss_weight)
Container for the components associated with the cosine similarity loss.
- Parameters:
device (torch.device) – Device to which the loss is sent and on which it is computed.
cos_sim_loss_weight (float) – Weight applied to the cosine similarity loss during optimization.
- class MoonContrastiveLossContainer(device, contrastive_loss_weight, temperature=0.5)
Bases: object
- class PerFclLossContainer(device, global_feature_contrastive_loss_weight, local_feature_contrastive_loss_weight, global_feature_loss_temperature=0.5, local_feature_loss_temperature=0.5)
Bases: object
- __init__(device, global_feature_contrastive_loss_weight, local_feature_contrastive_loss_weight, global_feature_loss_temperature=0.5, local_feature_loss_temperature=0.5)
Container for the components associated with the PerFCL loss.
- Parameters:
device (torch.device) – Device to which the loss is sent and on which it is computed.
global_feature_contrastive_loss_weight (float) – Weight on the global feature contrastive loss.
local_feature_contrastive_loss_weight (float) – Weight on the local feature contrastive loss.
global_feature_loss_temperature (float, optional) – Temperature parameter on the global contrastive loss function. Defaults to 0.5.
local_feature_loss_temperature (float, optional) – Temperature parameter on the local contrastive loss function. Defaults to 0.5.
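To build intuition for the two temperature parameters, the sketch below computes how much softmax weight the positive pair receives in a standard temperature-scaled contrastive loss. It assumes the PerFCL terms use that usual form (similarities divided by the temperature before a softmax); the similarity values are illustrative and `positive_pair_probability` is a hypothetical helper.

```python
import math


def positive_pair_probability(sim_pos: float, sim_negs: list[float], temperature: float) -> float:
    """Softmax weight of the positive pair after dividing similarities by the temperature."""
    logits = [sim_pos / temperature] + [s / temperature for s in sim_negs]
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - max_logit) for x in logits]
    return exps[0] / sum(exps)


# Same similarities, two temperatures: a lower temperature sharpens the distribution.
for tau in (0.5, 1.0):
    print(tau, round(positive_pair_probability(0.8, [0.2], tau), 4))
# 0.5 0.7685
# 1.0 0.6457
```

A lower temperature amplifies differences in similarity, so the loss concentrates more strongly on separating the positive pair from the negatives.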