from abc import ABC, abstractmethod
from collections.abc import Callable
import torch
import torch.nn as nn
class AbstractAe(nn.Module, ABC):
def __init__(
self,
encoder: nn.Module,
decoder: nn.Module,
) -> None:
"""
The base class for all autoencoder-based models. To define this model, we need to define the structure of
the encoder and the decoder modules. This type of model should have the capability to encode data using the
encoder module and decode the output of the encoder using the decoder module.
Args:
encoder (nn.Module): Model for encoding the input
decoder (nn.Module): Model for decoding the output of the encoder. This module should be compatible with
the output structure of the encoder module.
"""
super().__init__()
self.encoder = encoder
self.decoder = decoder
@abstractmethod
def forward(self, input: torch.Tensor) -> torch.Tensor:
# Forward is called in client classes with a single input tensor.
raise NotImplementedError
class BasicAe(AbstractAe):
def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
"""
Standard auto-encoder structure. To define this model, we need to define the structure of the encoder and
the decoder modules. This type of model should have the capability to encode data using the encoder module
and decode the output of the encoder using the decoder module.
Args:
encoder (nn.Module): Model for encoding the input
decoder (nn.Module): Model for decoding the output of the encoder. This module should be compatible with
the output structure of the encoder module.
"""
super().__init__(encoder, decoder)
def encode(self, input: torch.Tensor) -> torch.Tensor:
"""
Defines the forward associated with encoding the provided input tensor. We reuse the forward for the encoder
module.
Args:
input (torch.Tensor): Input tensor to be encoded.
Returns:
torch.Tensor: Encoding associated with the input tensor.
"""
latent_vector = self.encoder(input)
return latent_vector
def decode(self, latent_vector: torch.Tensor) -> torch.Tensor:
"""
Defines the forward associated with decoding a latent vector produced by the encoder from some input.
Args:
latent_vector (torch.Tensor): Latent vector to be decoded
Returns:
torch.Tensor: Decoded tensor.
"""
output = self.decoder(latent_vector)
return output
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Forward function for the ``BasicAe`` model. It simply pieces the encoding and decoding forwards together
to reconstruct the input through the encoder-decoder pipeline.
Args:
input (torch.Tensor): Input to pass through the encoder
Returns:
torch.Tensor: Reconstructed input after encoding and decoding with the model.
"""
z = self.encode(input)
return self.decode(z)
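# Usage sketch for ``BasicAe`` (illustrative only): the small MLP encoder and decoder below
# are assumptions for flattened 28x28 inputs, not modules defined in this file. Any
# encoder/decoder pair whose output and input shapes line up works the same way.
def _example_basic_ae() -> torch.Tensor:
    encoder = nn.Sequential(nn.Flatten(), nn.Linear(784, 64), nn.ReLU())
    decoder = nn.Sequential(nn.Linear(64, 784), nn.Sigmoid())
    model = BasicAe(encoder, decoder)
    # Reconstruct a random batch of 32 images; the output has shape (32, 784).
    return model(torch.randn(32, 1, 28, 28))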
class VariationalAe(AbstractAe):
def __init__(
self,
encoder: nn.Module,
decoder: nn.Module,
) -> None:
"""Variational Auto-Encoder model base class.
Args:
encoder (nn.Module): Encoder module defined by the user.
decoder (nn.Module): Decoder module defined by the user.
"""
super().__init__(encoder, decoder)
def encode(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Encodes the provided input using the provided encoder. The assumption is that the encoder produces a mean
and a log variance value, each of the same dimension, from the input, following the standard flow used by
variational autoencoders.
Args:
input (torch.Tensor): Input to be encoded
Returns:
tuple[torch.Tensor, torch.Tensor]: Mean and log variance values of the same dimensionality representing
the latent vector information to be used in VAE reconstruction.
"""
mu, logvar = self.encoder(input)
return mu, logvar
def decode(self, latent_vector: torch.Tensor) -> torch.Tensor:
"""
From a latent vector, this function aims to reconstruct the input that was used to generate the provided
latent representation.
Args:
latent_vector (torch.Tensor): Latent vector. For VAEs this is a vector of some fixed dimension, given
logvar and :math:`\\mu` generated by the encoder as
.. math::
\\mu + \\epsilon \\cdot \\exp \\left(0.5 \\cdot \\text{logvar} \\right),
where :math:`\\epsilon \\sim \\mathcal{N}(\\mathbf{0}, I)`
Returns:
torch.Tensor: Decoding from the latent vector
"""
output = self.decoder(latent_vector)
return output
def sampling(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
"""
Sampling using the reparameterization trick to make the VAE differentiable. This sampling produces a vector
as if it were sampled from :math:`\\mathcal{N}\\left(\\mu, \\exp(0.5 \\cdot \\text{logvar}) I \\right)`
Args:
mu (torch.Tensor): Mean of the normal distribution from which to sample.
logvar (torch.Tensor): Log of the variance of the normal distribution from which to sample.
Returns:
torch.Tensor: Latent vector sampled from the appropriate normal distribution.
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, input: torch.Tensor) -> torch.Tensor:
mu, logvar = self.encode(input)
z = self.sampling(mu, logvar)
output = self.decode(z)
# Output (reconstruction) is flattened to be concatenated with mu and logvar vectors.
# The shape of the flattened_output can later be restored using the training data shape
# or the decoder structure.
# This assumes output is "batch first".
flattened_output = output.view(output.shape[0], -1)
return torch.cat((logvar, mu, flattened_output), dim=1)
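# Usage sketch for ``VariationalAe`` (illustrative only): ``_ExampleVaeEncoder`` and the layer
# sizes below are assumptions, not part of this module. The encoder returns a (mu, logvar)
# pair, as ``encode`` expects, and the packed forward output is split back apart afterwards.
class _ExampleVaeEncoder(nn.Module):
    def __init__(self, input_dim: int = 784, latent_dim: int = 8) -> None:
        super().__init__()
        self.hidden = nn.Sequential(nn.Flatten(), nn.Linear(input_dim, 64), nn.ReLU())
        self.mu = nn.Linear(64, latent_dim)
        self.logvar = nn.Linear(64, latent_dim)

    def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        hidden = self.hidden(input)
        return self.mu(hidden), self.logvar(hidden)


def _example_variational_ae() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    latent_dim = 8
    decoder = nn.Sequential(nn.Linear(latent_dim, 784), nn.Sigmoid())
    model = VariationalAe(_ExampleVaeEncoder(latent_dim=latent_dim), decoder)
    packed = model(torch.randn(32, 1, 28, 28))
    # The forward output packs logvar, mu, and the flattened reconstruction along dim 1,
    # so they can be recovered with a split that uses the known latent dimension.
    logvar, mu, flat_reconstruction = torch.split(packed, [latent_dim, latent_dim, 784], dim=1)
    return logvar, mu, flat_reconstruction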
class ConditionalVae(AbstractAe):
def __init__(
self,
encoder: nn.Module,
decoder: nn.Module,
unpack_input_condition: Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]] | None = None,
) -> None:
"""Conditional Variational Auto-Encoder model.
Args:
encoder (nn.Module): The encoder used to map input to latent space.
decoder (nn.Module): The decoder used to reconstruct the input using a vector in latent space.
unpack_input_condition (Callable | None, optional): Function used to unpack the input tensor into separate
input and condition tensors. Defaults to None.
"""
super().__init__(encoder, decoder)
self.unpack_input_condition = unpack_input_condition
def encode(self, input: torch.Tensor, condition: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
"""
The user can decide how to use the condition in the encoder by defining the architecture and forward function
accordingly, e.g. using the condition in the middle layers of the encoder.
Args:
input (torch.Tensor): Input tensor
condition (torch.Tensor | None, optional): Conditional information to be used by the encoder. Defaults to
None.
Returns:
tuple[torch.Tensor, torch.Tensor]: mu and logvar, similar to the variational auto-encoder.
"""
mu, logvar = self.encoder(input, condition)
return mu, logvar
def decode(self, latent_vector: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor:
"""
The user can decide how to use the condition in the decoder by defining the architecture and forward function
to inject the conditioning, e.g. using the condition in the middle layers of the decoder, or not using it at all.
Args:
latent_vector (torch.Tensor): The latent vector sampled from the distribution specified by the encoder.
For CVAEs this is a vector of some fixed dimension, given logvar and :math:`\\mu` generated by the
encoder, perhaps using conditional information.
.. math::
\\mu + \\epsilon \\cdot \\exp \\left(0.5 \\cdot \\text{logvar} \\right),
where :math:`\\epsilon \\sim \\mathcal{N}(\\mathbf{0}, I)`
condition (torch.Tensor | None, optional): Conditioning information to be used by the decoder during the
mapping of the ``latent_vector`` to an output. Defaults to None.
Returns:
torch.Tensor: Decoded tensor from the latent vector and (potentially) the conditioning vector.
"""
output = self.decoder(latent_vector, condition)
return output
def sampling(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
"""
Sampling using the reparameterization trick to make the CVAE differentiable. This sampling produces a vector
as if it were sampled from :math:`\\mathcal{N}\\left(\\mu, \\exp(0.5 \\cdot \\text{logvar}) I \\right)`
Args:
mu (torch.Tensor): Mean of the normal distribution from which to sample.
logvar (torch.Tensor): Log of the variance of the normal distribution from which to sample.
Returns:
torch.Tensor: Latent vector sampled from the appropriate normal distribution.
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the CVAE that orchestrates the (conditional) encoding and (conditional) decoding from the
provided input. This input may have conditional information packed into it. If the function
``unpack_input_condition`` is defined, it is used to unpack the conditional information from the input tensor.
This conditioning information is then provided to both the encoder and decoder modules. They may or may not
make use of this information, depending on how their architectures and forwards are defined.
**NOTE**: Output (reconstruction) is flattened to be concatenated with mu and logvar vectors. The shape of the
``flattened_output`` can be later restored by having the training data shape, or the decoder structure.
**NOTE**: This assumes output is "batch first".
Args:
input (torch.Tensor): Input vector.
Returns:
torch.Tensor: Reconstructed input vector, which has been flattened and concatenated with the mu and logvar
tensors.
"""
assert self.unpack_input_condition is not None
input, condition = self.unpack_input_condition(input)
mu, logvar = self.encode(input, condition)
z = self.sampling(mu, logvar)
output = self.decode(z, condition)
flattened_output = output.view(output.shape[0], -1)
return torch.cat((logvar, mu, flattened_output), dim=1)
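# Usage sketch for ``ConditionalVae`` (illustrative only): the helpers below are assumptions,
# not part of this module. They assume the condition is packed as the last 10 columns of each
# flattened input row, and that the user-defined encoder and decoder each accept a
# (tensor, condition) pair, as ``encode`` and ``decode`` require.
class _ExampleCvaeEncoder(nn.Module):
    def __init__(self, input_dim: int = 784, condition_dim: int = 10, latent_dim: int = 8) -> None:
        super().__init__()
        self.hidden = nn.Sequential(nn.Linear(input_dim + condition_dim, 64), nn.ReLU())
        self.mu = nn.Linear(64, latent_dim)
        self.logvar = nn.Linear(64, latent_dim)

    def forward(self, input: torch.Tensor, condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        hidden = self.hidden(torch.cat((input, condition), dim=1))
        return self.mu(hidden), self.logvar(hidden)


class _ExampleCvaeDecoder(nn.Module):
    def __init__(self, latent_dim: int = 8, condition_dim: int = 10, output_dim: int = 784) -> None:
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(latent_dim + condition_dim, output_dim), nn.Sigmoid())

    def forward(self, latent_vector: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        return self.layers(torch.cat((latent_vector, condition), dim=1))


def _example_conditional_vae() -> torch.Tensor:
    # Assumed packing: the last 10 columns of each row carry the condition (e.g. a one-hot label).
    def unpack(packed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        return packed[:, :-10], packed[:, -10:]

    model = ConditionalVae(_ExampleCvaeEncoder(), _ExampleCvaeDecoder(), unpack_input_condition=unpack)
    packed_input = torch.randn(32, 784 + 10)
    # Returns logvar, mu, and the flattened reconstruction concatenated along dim 1.
    return model(packed_input)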