Source code for fl4health.losses.fenda_loss_config

import torch

from fl4health.losses.contrastive_loss import MoonContrastiveLoss
from fl4health.losses.cosine_similarity_loss import CosineSimilarityLoss
from fl4health.losses.perfcl_loss import PerFclLoss


class PerFclLossContainer:
    """Hold a ``PerFclLoss`` function together with the weights for its two loss terms."""

    def __init__(
        self,
        device: torch.device,
        global_feature_contrastive_loss_weight: float,
        local_feature_contrastive_loss_weight: float,
        global_feature_loss_temperature: float = 0.5,
        local_feature_loss_temperature: float = 0.5,
    ) -> None:
        """
        Record the two loss weights and construct the underlying ``PerFclLoss``.

        Args:
            device (torch.device): Passed as the first positional argument to ``PerFclLoss``.
            global_feature_contrastive_loss_weight (float): Stored unchanged on the instance as the weight for the
                global feature contrastive loss term.
            local_feature_contrastive_loss_weight (float): Stored unchanged on the instance as the weight for the
                local feature contrastive loss term.
            global_feature_loss_temperature (float, optional): Passed as the second positional argument to
                ``PerFclLoss``. Defaults to 0.5.
            local_feature_loss_temperature (float, optional): Passed as the third positional argument to
                ``PerFclLoss``. Defaults to 0.5.
        """
        # The weights are kept as plain attributes; this container never applies them itself.
        self.global_feature_contrastive_loss_weight = global_feature_contrastive_loss_weight
        self.local_feature_contrastive_loss_weight = local_feature_contrastive_loss_weight

        loss_function = PerFclLoss(
            device,
            global_feature_loss_temperature,
            local_feature_loss_temperature,
        )
        self.perfcl_loss_function = loss_function
class CosineSimilarityLossContainer:
    """Hold a ``CosineSimilarityLoss`` function together with the weight associated with it."""

    def __init__(self, device: torch.device, cos_sim_loss_weight: float) -> None:
        """
        Record the loss weight and construct the underlying ``CosineSimilarityLoss``.

        Args:
            device (torch.device): Passed as the only argument to ``CosineSimilarityLoss``.
            cos_sim_loss_weight (float): Stored unchanged on the instance as the weight for the cosine similarity
                loss.
        """
        self.cos_sim_loss_weight = cos_sim_loss_weight

        loss_function = CosineSimilarityLoss(device)
        self.cos_sim_loss_function = loss_function
class MoonContrastiveLossContainer:
    """Hold a ``MoonContrastiveLoss`` function together with the weight associated with it."""

    def __init__(
        self,
        device: torch.device,
        contrastive_loss_weight: float,
        temperature: float = 0.5,
    ) -> None:
        """
        Record the loss weight and construct the underlying ``MoonContrastiveLoss``.

        Args:
            device (torch.device): Passed as the first positional argument to ``MoonContrastiveLoss``.
            contrastive_loss_weight (float): Stored unchanged on the instance as the weight for the contrastive
                loss.
            temperature (float, optional): Passed as the second positional argument to ``MoonContrastiveLoss``.
                Defaults to 0.5.
        """
        self.contrastive_loss_weight = contrastive_loss_weight

        loss_function = MoonContrastiveLoss(device, temperature)
        self.contrastive_loss_function = loss_function
class ConstrainedFendaLossContainer:
    """
    Group up to three optional loss containers and expose weighted loss computations for each.

    Any of the three containers may be ``None``. The ``has_*`` methods report which ones are present, and the
    ``compute_*`` methods call the corresponding loss function and scale its output by the stored weight(s).
    """

    def __init__(
        self,
        perfcl_loss_config: PerFclLossContainer | None,
        cosine_similarity_loss_config: CosineSimilarityLossContainer | None,
        contrastive_loss_config: MoonContrastiveLossContainer | None,
    ) -> None:
        """
        Store the provided loss containers without modification.

        Args:
            perfcl_loss_config (PerFclLossContainer | None): Container for the PerFCL loss, or ``None`` if unused.
            cosine_similarity_loss_config (CosineSimilarityLossContainer | None): Container for the cosine
                similarity loss, or ``None`` if unused. Note that it is stored as ``cos_sim_loss_config``.
            contrastive_loss_config (MoonContrastiveLossContainer | None): Container for the contrastive loss, or
                ``None`` if unused.
        """
        self.perfcl_loss_config = perfcl_loss_config
        # The attribute name is shorter than the parameter name; the rest of the class uses the short form.
        self.cos_sim_loss_config = cosine_similarity_loss_config
        self.contrastive_loss_config = contrastive_loss_config

    def has_perfcl_loss(self) -> bool:
        """
        Report whether a PerFCL loss container was provided.

        Returns:
            bool: True if ``perfcl_loss_config`` is not ``None``, False otherwise.
        """
        return self.perfcl_loss_config is not None

    def has_cosine_similarity_loss(self) -> bool:
        """
        Report whether a cosine similarity loss container was provided.

        Returns:
            bool: True if ``cos_sim_loss_config`` is not ``None``, False otherwise.
        """
        return self.cos_sim_loss_config is not None

    def has_contrastive_loss(self) -> bool:
        """
        Report whether a contrastive loss container was provided.

        Returns:
            bool: True if ``contrastive_loss_config`` is not ``None``, False otherwise.
        """
        return self.contrastive_loss_config is not None

    def compute_contrastive_loss(
        self, features: torch.Tensor, positive_pairs: torch.Tensor, negative_pairs: torch.Tensor
    ) -> torch.Tensor:
        """
        Evaluate the stored contrastive loss function and scale the result by its weight.

        Args:
            features (torch.Tensor): First argument forwarded to the contrastive loss function.
            positive_pairs (torch.Tensor): Second argument forwarded to the contrastive loss function.
            negative_pairs (torch.Tensor): Third argument forwarded to the contrastive loss function.

        Returns:
            torch.Tensor: ``contrastive_loss_weight`` multiplied by the loss function's output.

        Raises:
            AssertionError: If no contrastive loss container was provided (when assertions are enabled).
        """
        config = self.contrastive_loss_config
        assert config is not None
        unweighted_loss = config.contrastive_loss_function(features, positive_pairs, negative_pairs)
        return config.contrastive_loss_weight * unweighted_loss

    def compute_cosine_similarity_loss(
        self, first_features: torch.Tensor, second_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Evaluate the stored cosine similarity loss function and scale the result by its weight.

        Args:
            first_features (torch.Tensor): First argument forwarded to the cosine similarity loss function.
            second_features (torch.Tensor): Second argument forwarded to the cosine similarity loss function.

        Returns:
            torch.Tensor: ``cos_sim_loss_weight`` multiplied by the loss function's output.

        Raises:
            AssertionError: If no cosine similarity loss container was provided (when assertions are enabled).
        """
        config = self.cos_sim_loss_config
        assert config is not None
        unweighted_loss = config.cos_sim_loss_function(first_features, second_features)
        return config.cos_sim_loss_weight * unweighted_loss

    def compute_perfcl_loss(
        self,
        local_features: torch.Tensor,
        old_local_features: torch.Tensor,
        global_features: torch.Tensor,
        old_global_features: torch.Tensor,
        initial_global_features: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluate the stored PerFCL loss function and scale each of its two outputs by the matching weight.

        The five tensors are forwarded positionally, in the order listed, to the PerFCL loss function.

        Args:
            local_features (torch.Tensor): First argument forwarded to the PerFCL loss function.
            old_local_features (torch.Tensor): Second argument forwarded to the PerFCL loss function.
            global_features (torch.Tensor): Third argument forwarded to the PerFCL loss function.
            old_global_features (torch.Tensor): Fourth argument forwarded to the PerFCL loss function.
            initial_global_features (torch.Tensor): Fifth argument forwarded to the PerFCL loss function.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The first output of the loss function multiplied by
            ``global_feature_contrastive_loss_weight``, followed by the second output multiplied by
            ``local_feature_contrastive_loss_weight``.

        Raises:
            AssertionError: If no PerFCL loss container was provided (when assertions are enabled).
        """
        config = self.perfcl_loss_config
        assert config is not None
        # The loss function yields the global-feature term first and the local-feature term second.
        global_term, local_term = config.perfcl_loss_function(
            local_features,
            old_local_features,
            global_features,
            old_global_features,
            initial_global_features,
        )
        weighted_global_term = config.global_feature_contrastive_loss_weight * global_term
        weighted_local_term = config.local_feature_contrastive_loss_weight * local_term
        return weighted_global_term, weighted_local_term