fl4health.preprocessing.autoencoders.loss module

class VaeLoss(latent_dim, base_loss)

Bases: _Loss

__init__(latent_dim, base_loss)

Loss function for training variational autoencoders (VAEs) and conditional VAEs (CVAEs). It computes the user-defined base_loss between the input and the generated output, then adds the KL divergence between the estimated latent distribution (parameterized by mu and logvar) and the standard normal distribution.

Parameters:
  • latent_dim (int) – Dimensionality of the latent space.

  • base_loss (_Loss) – Base loss function between the input and reconstruction.

forward(preds, target)

Calculates the total loss.

Parameters:
  • preds (torch.Tensor) – Model output, which packs the predictions together with mu and logvar (see unpack_model_output).

  • target (torch.Tensor) – Target values.

Returns:

Total loss composed of base loss and KL divergence loss.

Return type:

torch.Tensor
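As an illustration of how the two terms combine, here is a minimal sketch in PyTorch. It is not the library's implementation: the function name vae_total_loss is hypothetical, the inputs are assumed to be already unpacked into reconstruction, mu, and logvar, and the KL term is assumed to be summed over all elements.

```python
import torch
from torch import nn


def vae_total_loss(recon, mu, logvar, target, base_loss):
    """Hypothetical helper: base loss plus KL divergence to N(0, I)."""
    # Reconstruction term: the user-supplied base loss.
    reconstruction = base_loss(recon, target)
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)).
    # Assumption: summed over batch and latent dimensions.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + kl


recon = torch.zeros(4, 6)    # stand-in reconstructions
target = torch.ones(4, 6)    # stand-in targets
mu = torch.zeros(4, 2)       # latent means
logvar = torch.zeros(4, 2)   # latent log variances
total = vae_total_loss(recon, mu, logvar, target, nn.MSELoss())
```

When mu and logvar are both zero, the estimated distribution already equals the standard normal, so the KL term vanishes and the total reduces to the base loss alone.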

standard_normal_kl_divergence_loss(mu, logvar)

Calculates the analytical (closed-form) KL divergence between the estimated distribution (defined by mu and logvar) and the standard normal distribution.

Parameters:
  • mu (torch.Tensor) – Mean of the estimated distribution.

  • logvar (torch.Tensor) – Log variance of the estimated distribution.

Returns:

KL divergence loss.

Return type:

torch.Tensor
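For reference, the standard closed form of this divergence for a diagonal Gaussian is given below. The symbols are assumptions for illustration: d denotes the latent dimension, and the expression is written for a single sample. How the result is reduced over a batch (sum or mean) is not stated on this page.

```latex
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\right)
  = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right),
\qquad \log \sigma_j^2 = \mathrm{logvar}_j
```

Each term is non-negative and equals zero exactly when \mu_j = 0 and \sigma_j^2 = 1.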

unpack_model_output(preds)

Unpacks the model output tensor.

Parameters:

preds (torch.Tensor) – Model output tensor containing the predictions, mu, and logvar.

Returns:

Unpacked output containing predictions, mu, and logvar.

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]
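One way such unpacking could work is by slicing a flattened, batch-first tensor. The sketch below is an assumption, not the documented behavior: the function name unpack_flat_output is hypothetical, and the column layout [predictions | mu | logvar] is chosen for illustration only.

```python
import torch


def unpack_flat_output(preds, latent_dim):
    """Hypothetical helper: split a packed 2D tensor into three parts."""
    # Assumed layout per row: [predictions | mu | logvar].
    if preds.dim() != 2:
        raise ValueError("expected a 2D, batch-first tensor")
    logvar = preds[:, -latent_dim:]               # last latent_dim columns
    mu = preds[:, -2 * latent_dim:-latent_dim]    # the latent_dim columns before them
    recon = preds[:, :-2 * latent_dim]            # everything else
    return recon, mu, logvar


latent_dim = 2
recon_in = torch.rand(4, 6)
mu_in = torch.rand(4, latent_dim)
logvar_in = torch.rand(4, latent_dim)
# Pack the three parts the way a model's forward pass might.
packed = torch.cat((recon_in, mu_in, logvar_in), dim=1)
recon, mu, logvar = unpack_flat_output(packed, latent_dim)
```

Under this layout, the model must concatenate its outputs in the same order, since the slices are positional and carry no labels.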