fl4health.model_bases.autoencoders_base module

class AbstractAe(encoder, decoder)[source]

Bases: Module, ABC

__init__(encoder, decoder)[source]

The base class for all autoencoder-based models. To define such a model, we need to specify the structure of the encoder and decoder modules. A model of this type should be able to encode data using the encoder module and to decode the encoder's output using the decoder module.

Parameters:
  • encoder (nn.Module) – Model for encoding the input

  • decoder (nn.Module) – Model for decoding the output of the encoder. This module should be compatible with the output structure of the encoder module.

abstract forward(input)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.
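
To make this concrete, the following is a minimal, purely illustrative sketch of a subclass, assuming the base class stores the modules as self.encoder and self.decoder (the layer sizes are arbitrary and not part of the library):

    import torch
    import torch.nn as nn

    from fl4health.model_bases.autoencoders_base import AbstractAe


    class TinyAe(AbstractAe):
        # Hypothetical subclass: pass the input through the encoder, then the decoder.
        # Assumes the base class exposes the modules as self.encoder and self.decoder.
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return self.decoder(self.encoder(input))


    # Arbitrary example dimensions, purely for illustration.
    encoder = nn.Sequential(nn.Linear(784, 64), nn.ReLU())
    decoder = nn.Sequential(nn.Linear(64, 784), nn.Sigmoid())
    model = TinyAe(encoder, decoder)
    reconstruction = model(torch.randn(8, 784))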

class BasicAe(encoder, decoder)[source]

Bases: AbstractAe

__init__(encoder, decoder)[source]

Standard auto-encoder structure. To define this model, we need to specify the structure of the encoder and decoder modules. A model of this type should be able to encode data using the encoder module and to decode the encoder's output using the decoder module.

Parameters:
  • encoder (nn.Module) – Model for encoding the input

  • decoder (nn.Module) – Model for decoding the output of the encoder. This module should be compatible with the output structure of the encoder module.

decode(latent_vector)[source]

Defines the forward pass associated with decoding a latent vector produced by the encoder from some input.

Parameters:

latent_vector (torch.Tensor) – Latent vector to be decoded

Returns:

Decoded tensor.

Return type:

torch.Tensor

encode(input)[source]

Defines the forward pass associated with encoding the provided input tensor. This simply reuses the forward pass of the encoder module.

Parameters:

input (torch.Tensor) – Input tensor to be encoded.

Returns:

Encoding associated with the input tensor.

Return type:

torch.Tensor

forward(input)[source]

Forward function for the BasicAe model. It simply pieces the encoding and decoding forwards together to reconstruct the input through the encoder-decoder pipeline.

Parameters:

input (torch.Tensor) – Input to pass through the encoder

Returns:

Reconstructed input after encoding and decoding with the model.

Return type:

torch.Tensor
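
A minimal usage sketch for BasicAe, assuming simple fully connected encoder and decoder modules (the dimensions are arbitrary choices for illustration):

    import torch
    import torch.nn as nn

    from fl4health.model_bases.autoencoders_base import BasicAe

    # Arbitrary illustrative architecture: 784 -> 32 -> 784.
    encoder = nn.Sequential(nn.Linear(784, 32), nn.ReLU())
    decoder = nn.Sequential(nn.Linear(32, 784), nn.Sigmoid())
    model = BasicAe(encoder, decoder)

    x = torch.randn(16, 784)
    latent = model.encode(x)               # expected shape: (16, 32)
    reconstruction = model.decode(latent)  # expected shape: (16, 784)
    # forward pieces the two steps together:
    reconstruction = model(x)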

class ConditionalVae(encoder, decoder, unpack_input_condition=None)[source]

Bases: AbstractAe

__init__(encoder, decoder, unpack_input_condition=None)[source]

Conditional Variational Auto-Encoder model.

Parameters:
  • encoder (nn.Module) – The encoder used to map input to latent space.

  • decoder (nn.Module) – The decoder used to reconstruct the input using a vector in latent space.

  • unpack_input_condition (Callable | None, optional) – Callable used to unpack the input tensor into separate input and condition tensors. Defaults to None.

decode(latent_vector, condition=None)[source]

The user can decide how the condition is used in the decoder by defining the architecture and forward function to inject the conditioning, e.g. using the condition in the middle layers of the decoder, or not using it at all; see the illustrative sketch below.

Parameters:
  • latent_vector (torch.Tensor) –

    The latent vector sampled from the distribution specified by the encoder. For CVAEs this is a vector of some fixed dimension, given the logvar and \(\mu\) generated by the encoder, perhaps using conditional information, as

    \[\mu + \epsilon \cdot \exp \left(0.5 \cdot \text{logvar} \right),\]

    where \(\epsilon \sim \mathcal{N}(\mathbf{0}, I)\)

  • condition (torch.Tensor | None, optional) – Conditioning information to be used by the decoder during the mapping of the latent_vector to an output. Defaults to None.

Returns:

Decoded tensor from the latent vector and (potentially) the conditioning vector.

Return type:

torch.Tensor
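
As referenced in the description above, one purely illustrative way to inject the condition is to concatenate it with the latent vector before decoding. The class below is a sketch under assumptions (the dimensions, the concatenation strategy, and the requirement that a condition is always supplied are all illustrative choices), not part of the library:

    import torch
    import torch.nn as nn


    class ConcatConditionDecoder(nn.Module):
        # Hypothetical decoder: concatenates the condition onto the latent vector
        # before mapping back to the output space. This sketch assumes a condition
        # is always provided.
        def __init__(self, latent_dim: int, condition_dim: int, output_dim: int) -> None:
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(latent_dim + condition_dim, 128),
                nn.ReLU(),
                nn.Linear(128, output_dim),
            )

        def forward(self, latent_vector: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
            return self.net(torch.cat([latent_vector, condition], dim=-1))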

encode(input, condition=None)[source]

The user can decide how the condition is used in the encoder by defining the architecture and forward function accordingly, e.g. using the condition in the middle layers of the encoder; see the illustrative sketch below.

Parameters:
  • input (torch.Tensor) – Input tensor

  • condition (torch.Tensor | None, optional) – Conditional information to be used by the encoder. Defaults to None.

Returns:

The mu and logvar tensors produced by the encoder, analogous to those of a standard variational auto-encoder.

Return type:

tuple[torch.Tensor, torch.Tensor]
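
Correspondingly, a purely illustrative encoder that consumes the condition and returns the (mu, logvar) pair might look like the following; the layer sizes are arbitrary assumptions:

    import torch
    import torch.nn as nn


    class ConcatConditionEncoder(nn.Module):
        # Hypothetical encoder: concatenates the condition onto the input and
        # produces the mean and log-variance of the latent distribution.
        def __init__(self, input_dim: int, condition_dim: int, latent_dim: int) -> None:
            super().__init__()
            self.hidden = nn.Sequential(nn.Linear(input_dim + condition_dim, 128), nn.ReLU())
            self.mu_layer = nn.Linear(128, latent_dim)
            self.logvar_layer = nn.Linear(128, latent_dim)

        def forward(self, input: torch.Tensor, condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            hidden = self.hidden(torch.cat([input, condition], dim=-1))
            return self.mu_layer(hidden), self.logvar_layer(hidden)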

forward(input)[source]

Forward pass for the CVAE that orchestrates the (conditional) encoding and (conditional) decoding of the provided input. This input may have conditional information packed into it. If the function unpack_input_condition is defined, it is used to unpack the conditional information from the input tensor. This conditioning information is then provided to both the encoder and decoder modules. They may or may not make use of this information, depending on how their architectures and forward functions are defined.

NOTE: The output (reconstruction) is flattened so that it can be concatenated with the mu and logvar vectors. The original shape of the flattened output can later be restored using the training data shape or the decoder structure.

NOTE: This assumes the output is “batch first”.

Parameters:

input (torch.Tensor) – Input vector.

Returns:

Reconstructed input vector, which has been flattened and concatenated with the mu and logvar tensors.

Return type:

torch.Tensor
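
Putting these pieces together, a usage sketch might look like the following. The packing scheme (condition appended to the last columns of the input tensor) and the unpack_input_condition callable that reverses it are hypothetical illustrations; the encoder and decoder are the illustrative classes sketched above.

    import torch

    from fl4health.model_bases.autoencoders_base import ConditionalVae

    INPUT_DIM, CONDITION_DIM, LATENT_DIM = 784, 10, 32


    def unpack_input_condition(packed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Hypothetical packing scheme: the condition occupies the last CONDITION_DIM columns.
        return packed[:, :INPUT_DIM], packed[:, INPUT_DIM:]


    model = ConditionalVae(
        encoder=ConcatConditionEncoder(INPUT_DIM, CONDITION_DIM, LATENT_DIM),
        decoder=ConcatConditionDecoder(LATENT_DIM, CONDITION_DIM, INPUT_DIM),
        unpack_input_condition=unpack_input_condition,
    )

    packed_input = torch.randn(16, INPUT_DIM + CONDITION_DIM)
    output = model(packed_input)
    # As documented, the output is the flattened reconstruction concatenated with the
    # mu and logvar vectors, so here it should have shape (16, INPUT_DIM + 2 * LATENT_DIM).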

sampling(mu, logvar)[source]

Sampling using the reparameterization trick to make the CVAE differentiable. This sampling produces a vector as if it were sampled from a normal distribution with mean \(\mu\) and standard deviation \(\exp(0.5 \cdot \text{logvar})\), that is, \(\mathcal{N}\left(\mu, \exp(\text{logvar}) \cdot I \right)\).

Parameters:
  • mu (torch.Tensor) – Mean of the normal distribution from which to sample.

  • logvar (torch.Tensor) – Log of the variance of the normal distribution from which to sample.

Returns:

Latent vector sampled from the appropriate normal distribution.

Return type:

torch.Tensor
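
The reparameterization itself corresponds to a single line of tensor arithmetic; a standalone sketch of the same math (not the library's internal code) is:

    import torch


    def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        # Standard deviation is exp(0.5 * logvar); epsilon is drawn from N(0, I).
        epsilon = torch.randn_like(mu)
        return mu + epsilon * torch.exp(0.5 * logvar)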

class VariationalAe(encoder, decoder)[source]

Bases: AbstractAe

__init__(encoder, decoder)[source]

Variational Auto-Encoder model base class.

Parameters:
  • encoder (nn.Module) – Encoder module defined by the user.

  • decoder (nn.Module) – Decoder module defined by the user.

decode(latent_vector)[source]

From a latent vector, this function aims to reconstruct the input that was used to generate the provided latent representation.

Parameters:

latent_vector (torch.Tensor) –

Latent vector. For VAEs this is a vector of some fixed dimension, given logvar and \(\mu\) generated by the encoder as

\[\mu + \epsilon \cdot \exp \left(0.5 \cdot \text{logvar} \right),\]

where \(\epsilon \sim \mathcal{N}(\mathbf{0}, I)\)

Returns:

Decoding from the latent vector

Return type:

torch.Tensor

encode(input)[source]

Encodes the provided input using the provided encoder. The assumption is that the encoder produces a mean and a log variance value, each of the same dimension, from the input, following the standard flow used by variational autoencoders.

Parameters:

input (torch.Tensor) – Input to be encoded

Returns:

Mean and log variance values of the same dimensionality representing the latent vector information to be used in VAE reconstruction.

Return type:

tuple[torch.Tensor, torch.Tensor]

forward(input)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.

sampling(mu, logvar)[source]

Sampling using the reparameterization trick to make the VAE differentiable. This sampling produces a vector as if it were sampled from a normal distribution with mean \(\mu\) and standard deviation \(\exp(0.5 \cdot \text{logvar})\), that is, \(\mathcal{N}\left(\mu, \exp(\text{logvar}) \cdot I \right)\).

Parameters:
  • mu (torch.Tensor) – Mean of the normal distribution from which to sample.

  • logvar (torch.Tensor) – Log of the variance of the normal distribution from which to sample.

Returns:

Latent vector sampled from the appropriate normal distribution.

Return type:

torch.Tensor
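
Finally, a usage sketch for VariationalAe that walks through the documented encode, sampling, and decode steps explicitly; the encoder and decoder modules and all dimensions are illustrative assumptions:

    import torch
    import torch.nn as nn

    from fl4health.model_bases.autoencoders_base import VariationalAe

    LATENT_DIM = 32


    class GaussianEncoder(nn.Module):
        # Hypothetical encoder producing a (mu, logvar) pair of the same dimension.
        def __init__(self) -> None:
            super().__init__()
            self.hidden = nn.Sequential(nn.Linear(784, 128), nn.ReLU())
            self.mu_layer = nn.Linear(128, LATENT_DIM)
            self.logvar_layer = nn.Linear(128, LATENT_DIM)

        def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            hidden = self.hidden(input)
            return self.mu_layer(hidden), self.logvar_layer(hidden)


    decoder = nn.Sequential(nn.Linear(LATENT_DIM, 784), nn.Sigmoid())
    model = VariationalAe(GaussianEncoder(), decoder)

    x = torch.randn(16, 784)
    mu, logvar = model.encode(x)
    latent = model.sampling(mu, logvar)
    reconstruction = model.decode(latent)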