fl4health.preprocessing.autoencoders.loss module
- class VaeLoss(latent_dim, base_loss)
Bases: _Loss
- __init__(latent_dim, base_loss)
The loss function used to train variational autoencoders (VAEs) and conditional VAEs (CVAEs). It computes the user-defined base_loss between the input and the generated output (the reconstruction), then adds the KL divergence between the estimated distribution (parameterized by mu and logvar) and the standard normal distribution.
- Parameters:
latent_dim (int) – Dimensionality of the latent space.
base_loss (_Loss) – Base loss function computed between the input and its reconstruction.
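For reference, the objective described above can be written per sample as follows, where x is the input, x̂ the reconstruction, d = latent_dim, and σ_j² = exp(logvar_j). The closed form of the KL term is the standard result for a diagonal Gaussian measured against N(0, I); how the terms are reduced over a batch depends on base_loss and the implementation and is not specified here.

```latex
\begin{align*}
\mathcal{L} &= \ell_{\mathrm{base}}(\hat{x}, x)
  + D_{\mathrm{KL}}\bigl(\mathcal{N}(\mu, \operatorname{diag}(\sigma^{2})) \,\|\, \mathcal{N}(0, I)\bigr), \\
D_{\mathrm{KL}}\bigl(\mathcal{N}(\mu, \operatorname{diag}(\sigma^{2})) \,\|\, \mathcal{N}(0, I)\bigr)
  &= -\tfrac{1}{2} \sum_{j=1}^{d} \bigl(1 + \log \sigma_j^{2} - \mu_j^{2} - \sigma_j^{2}\bigr).
\end{align*}
```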
- forward(preds, target)
Calculates the total loss.
- Parameters:
preds (torch.Tensor) – Model predictions.
target (torch.Tensor) – Target values.
- Returns:
Total loss composed of base loss and KL divergence loss.
- Return type:
torch.Tensor
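A minimal, self-contained sketch of how such a forward pass could be wired in PyTorch. It is not the fl4health source: the class name PackedVaeLossSketch, its unpack helper, and the assumption that the model output concatenates [reconstruction | logvar | mu] along dimension 1 (with logvar and mu each latent_dim wide) are hypothetical choices made for illustration.

```python
from __future__ import annotations

import torch
from torch import nn
from torch.nn.modules.loss import _Loss


class PackedVaeLossSketch(_Loss):
    """Hypothetical VAE loss for illustration only (not the fl4health implementation).

    Assumes preds has shape [batch, n_features + 2 * latent_dim], laid out as
    [reconstruction | logvar | mu] along dim 1.
    """

    def __init__(self, latent_dim: int, base_loss: _Loss) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        self.base_loss = base_loss

    def unpack(self, preds: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        d = self.latent_dim
        recon = preds[:, : -2 * d]      # reconstruction of the input (assumed layout)
        logvar = preds[:, -2 * d : -d]  # log variance of the estimated distribution
        mu = preds[:, -d:]              # mean of the estimated distribution
        return recon, mu, logvar

    @staticmethod
    def kl_to_standard_normal(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over batch and latent dims.
        return -0.5 * torch.sum(1.0 + logvar - mu.pow(2) - logvar.exp())

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        recon, mu, logvar = self.unpack(preds)
        return self.base_loss(recon, target) + self.kl_to_standard_normal(mu, logvar)


# Usage: a perfect reconstruction with mu = 1 and logvar = 0 leaves only the KL term.
torch.manual_seed(0)
target = torch.randn(4, 6)
mu = torch.ones(4, 2)
logvar = torch.zeros(4, 2)
preds = torch.cat([target, logvar, mu], dim=1)

loss_fn = PackedVaeLossSketch(latent_dim=2, base_loss=nn.MSELoss(reduction="sum"))
loss = loss_fn(preds, target)
print(loss.item())  # 4.0: each of the 8 latent entries contributes 0.5
```

Because the KL term in this sketch is summed rather than averaged, pairing it with a base loss that uses reduction="sum" keeps both terms on the same per-batch scale; with a mean-reduced base loss, the KL term would grow relative to the reconstruction term as the batch size increases.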
- standard_normal_kl_divergence_loss(mu, logvar)
Calculates the closed-form (analytical) KL divergence between the estimated distribution and the standard normal distribution.
- Parameters:
mu (torch.Tensor) – Mean of the estimated distribution.
logvar (torch.Tensor) – Log variance of the estimated distribution.
- Returns:
KL divergence loss.
- Return type:
torch.Tensor
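As a sanity check on the closed form, the sketch below compares -0.5 * sum(1 + logvar - mu^2 - exp(logvar)) with PyTorch's own torch.distributions.kl_divergence for a diagonal Gaussian against the standard normal. The helper name standard_normal_kl is hypothetical, and summing over every element is an assumed reduction.

```python
import torch
from torch.distributions import Normal, kl_divergence


def standard_normal_kl(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Closed form of KL(N(mu, exp(logvar)) || N(0, 1)), summed over every element.
    return -0.5 * torch.sum(1.0 + logvar - mu.pow(2) - logvar.exp())


torch.manual_seed(0)
mu = torch.randn(3, 4)
logvar = torch.randn(3, 4)

analytic = standard_normal_kl(mu, logvar)
# Reference value: Normal is parameterized by a standard deviation, so scale = exp(0.5 * logvar).
reference = kl_divergence(
    Normal(mu, torch.exp(0.5 * logvar)),
    Normal(torch.zeros_like(mu), torch.ones_like(mu)),
).sum()
print(torch.allclose(analytic, reference, atol=1e-4))
```

Passing exp(logvar) as the scale instead of exp(0.5 * logvar) is a common slip; the comparison above would catch it.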