fl4health.preprocessing.autoencoders.loss module
- class VaeLoss(latent_dim, base_loss)
Bases: _Loss
- __init__(latent_dim, base_loss)
The loss function used for training CVAEs and VAEs.
This loss computes the base_loss (defined by the user) between the input and its reconstruction. It then adds the KL divergence between the estimated latent distribution (parameterized by mu and logvar) and the standard normal distribution.
- Parameters:
latent_dim (int) – Dimensionality of the latent space.
base_loss (_Loss) – Base loss function between the input and reconstruction.
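The composition described above (a user-chosen reconstruction loss plus a KL term against the standard normal) can be sketched in plain PyTorch. This is a minimal sketch, not the fl4health implementation: `vae_loss_sketch` is a hypothetical helper, mu and logvar are assumed to be passed in separately rather than extracted from the model output, and the KL term is assumed to be summed over all batch and latent entries (this reference does not state the reduction). The KL term here uses PyTorch's built-in `torch.distributions.kl_divergence`, and `nn.MSELoss(reduction="sum")` is just one possible choice of base_loss.

```python
import torch
from torch import nn
from torch.distributions import Normal, kl_divergence


def vae_loss_sketch(
    x_hat: torch.Tensor,
    x: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    base_loss: nn.Module,
) -> torch.Tensor:
    # Reconstruction term: any user-chosen loss between reconstruction and input.
    reconstruction = base_loss(x_hat, x)
    # Estimated distribution N(mu, sigma^2), with sigma = exp(0.5 * logvar).
    posterior = Normal(mu, torch.exp(0.5 * logvar))
    # Standard normal N(0, 1) with the same shape as mu.
    prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
    # KL(posterior || prior), summed over batch and latent dims (assumed reduction).
    kl = kl_divergence(posterior, prior).sum()
    return reconstruction + kl


x = torch.zeros(2, 4)
x_hat = torch.ones(2, 4)
mu = torch.ones(2, 3)
logvar = torch.zeros(2, 3)
loss = vae_loss_sketch(x_hat, x, mu, logvar, nn.MSELoss(reduction="sum"))
print(loss.item())  # 11.0 (8.0 reconstruction + 6 * 0.5 KL)
```

Swapping base_loss (for example to `nn.BCELoss`) changes only the reconstruction term; the KL regularizer stays the same.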
- forward(preds, target)
Calculates the total loss.
- Parameters:
preds (torch.Tensor) – Model predictions.
target (torch.Tensor) – Target values.
- Returns:
Total loss composed of base loss and KL divergence loss.
- Return type:
torch.Tensor
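forward receives only preds and target, yet the KL term needs mu and logvar, so they must be recoverable from preds somehow. This reference does not document how. One plausible layout, assumed here purely for illustration, packs each row as [reconstruction | mu | logvar], with mu and logvar each of width latent_dim. `split_packed_preds` is a hypothetical helper, not part of the fl4health API.

```python
import torch


def split_packed_preds(preds: torch.Tensor, latent_dim: int):
    # ASSUMPTION (hypothetical layout): each row of preds is
    # [reconstruction | mu | logvar], mu and logvar both of width latent_dim.
    x_hat = preds[:, : -2 * latent_dim]
    mu = preds[:, -2 * latent_dim : -latent_dim]
    logvar = preds[:, -latent_dim:]
    return x_hat, mu, logvar


latent_dim = 2
# 2 samples, each with 2 reconstructed features + 2 * latent_dim latent values.
preds = torch.arange(12.0).reshape(2, 6)
x_hat, mu, logvar = split_packed_preds(preds, latent_dim)
print(x_hat.shape, mu.shape, logvar.shape)
```

With the pieces separated, the total loss is base_loss(x_hat, target) plus the KL term computed from mu and logvar.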
- standard_normal_kl_divergence_loss(mu, logvar)
Calculates the analytical (closed-form) KL divergence between the estimated distribution and the standard normal distribution.
- Parameters:
mu (torch.Tensor) – Mean of the estimated distribution.
logvar (torch.Tensor) – Log variance of the estimated distribution.
- Returns:
KL divergence loss.
- Return type:
torch.Tensor
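For a diagonal Gaussian N(mu, diag(sigma^2)) and the standard normal N(0, I), the KL divergence has the well-known closed form -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), where log(sigma^2) is exactly logvar. The sketch below computes it directly and cross-checks it against PyTorch's built-in Normal-to-Normal KL. `standard_normal_kl` is a hypothetical helper, and summing over all entries is an assumed reduction; this reference does not say whether the library sums or averages over the batch.

```python
import torch
from torch.distributions import Normal, kl_divergence


def standard_normal_kl(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Closed form of KL(N(mu, diag(exp(logvar))) || N(0, I)),
    # summed over all entries (assumed reduction):
    #   -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


torch.manual_seed(0)
mu = torch.randn(4, 3)
logvar = torch.randn(4, 3)

analytic = standard_normal_kl(mu, logvar)
# Cross-check against PyTorch's Normal-Normal KL, summed over entries.
reference = kl_divergence(
    Normal(mu, torch.exp(0.5 * logvar)),
    Normal(torch.zeros_like(mu), torch.ones_like(mu)),
).sum()
print(torch.allclose(analytic, reference, atol=1e-5))  # True
```

Working with logvar rather than sigma keeps the variance positive without constraining the network output, since exp(logvar) > 0 for any real logvar.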