from abc import ABC, abstractmethod
from collections.abc import Callable
import torch
import torch.nn as nn
[docs]
class AbstractAe(nn.Module, ABC):
    """Abstract base for autoencoder-style models.

    An instance stores an encoder module and a decoder module; concrete
    subclasses decide how the two are combined by implementing ``forward``.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        """
        The base class for all autoencoder based models.

        To define this model, we need to define the structure of the encoder and the decoder modules.
        This type of model should have the capability to encode data using the encoder module and decode
        the output of the encoder using the decoder module.

        Args:
            encoder (nn.Module): Module stored as ``self.encoder``.
            decoder (nn.Module): Module stored as ``self.decoder``.
        """
        super().__init__()
        # Assigning nn.Module attributes registers them as sub-modules of this model.
        self.encoder = encoder
        self.decoder = decoder

    @abstractmethod
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the model on a single input tensor.

        Client classes call forward with exactly one tensor, so subclasses must keep this
        single-tensor signature.

        Args:
            input (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The model output, as defined by the subclass.

        Raises:
            NotImplementedError: Always, when the base implementation is invoked directly.
        """
        raise NotImplementedError
[docs]
class BasicAe(AbstractAe):
    """Plain autoencoder whose forward pass is ``decoder(encoder(input))``."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        """Store the encoder and decoder modules via the base class.

        Args:
            encoder (nn.Module): Module applied to the input in ``encode``.
            decoder (nn.Module): Module applied to the latent vector in ``decode``.
        """
        super().__init__(encoder=encoder, decoder=decoder)

    def encode(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the encoder module to ``input`` and return its result."""
        return self.encoder(input)

    def decode(self, latent_vector: torch.Tensor) -> torch.Tensor:
        """Apply the decoder module to ``latent_vector`` and return its result."""
        return self.decoder(latent_vector)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Encode ``input`` and decode the resulting latent vector.

        Args:
            input (torch.Tensor): The input tensor handed to the encoder.

        Returns:
            torch.Tensor: The decoder's output for the encoded input.
        """
        return self.decode(self.encode(input))
[docs]
class VariationalAe(AbstractAe):
    """Variational autoencoder built from a user-supplied encoder and decoder."""

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
    ) -> None:
        """Variational Auto-Encoder model base class.

        Args:
            encoder (nn.Module): Encoder module defined by the user. Its output is unpacked
                into a ``(mu, logvar)`` pair in ``encode``.
            decoder (nn.Module): Decoder module defined by the user.
        """
        super().__init__(encoder, decoder)

    def encode(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder and return its ``(mu, logvar)`` pair.

        Args:
            input (torch.Tensor): The input tensor handed to the encoder.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The two values produced by the encoder,
            in the order ``(mu, logvar)``.
        """
        mu, logvar = self.encoder(input)
        return mu, logvar

    def decode(self, latent_vector: torch.Tensor) -> torch.Tensor:
        """Run the decoder on ``latent_vector`` and return its output."""
        return self.decoder(latent_vector)

    def sampling(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw a latent sample as ``mu + eps * std`` with ``eps`` standard normal noise.

        Args:
            mu (torch.Tensor): Mean of the latent distribution.
            logvar (torch.Tensor): Log-variance of the latent distribution.

        Returns:
            torch.Tensor: The sampled latent vector.
        """
        # std = exp(0.5 * log(sigma^2)) = sigma
        std = torch.exp(0.5 * logvar)
        # Noise matches std in shape, dtype and device.
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Encode, sample, decode, and pack everything into a single 2-D tensor.

        Args:
            input (torch.Tensor): The input tensor handed to the encoder.

        Returns:
            torch.Tensor: Concatenation along dim 1 of, in this order, ``logvar``, ``mu`` and
            the flattened reconstruction. Because the flattened reconstruction is 2-D,
            ``mu`` and ``logvar`` must also be 2-D with the same batch size.
        """
        mu, logvar = self.encode(input)
        z = self.sampling(mu, logvar)
        output = self.decode(z)
        # Output (reconstruction) is flattened to be concatenated with the logvar and mu vectors.
        # The shape of the flattened_output can be later restored by having the training data shape,
        # or the decoder structure.
        # This assumes output is "batch first".
        # reshape (rather than view) also handles decoders that return a non-contiguous tensor,
        # e.g. after a permute/transpose; for contiguous tensors the result is the same.
        flattened_output = output.reshape(output.shape[0], -1)
        return torch.cat((logvar, mu, flattened_output), dim=1)
[docs]
class ConditionalVae(AbstractAe):
    """Conditional variational autoencoder whose encoder and decoder both receive a condition."""

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        unpack_input_condition: Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]] | None = None,
    ) -> None:
        """Conditional Variational Auto-Encoder model.

        Args:
            encoder (nn.Module): The encoder used to map input to latent space. It is called as
                ``encoder(input, condition)`` and its output is unpacked into ``(mu, logvar)``.
            decoder (nn.Module): The decoder used to reconstruct the input using a vector in latent space.
                It is called as ``decoder(latent_vector, condition)``.
            unpack_input_condition (Callable | None, optional): For unpacking the input and condition tensors.
                ``forward`` calls it on its single tensor argument and expects an ``(input, condition)``
                pair back. Defaults to None, in which case ``forward`` cannot be used.
        """
        super().__init__(encoder, decoder)
        self.unpack_input_condition = unpack_input_condition

    def encode(self, input: torch.Tensor, condition: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder on ``input`` and ``condition`` and return its ``(mu, logvar)`` pair.

        Args:
            input (torch.Tensor): The input tensor handed to the encoder.
            condition (torch.Tensor | None, optional): The condition handed to the encoder as its
                second argument. Defaults to None.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The two values produced by the encoder,
            in the order ``(mu, logvar)``.
        """
        # User can decide how to use the condition in the encoder,
        # ex: using the condition in the middle layers of encoder.
        mu, logvar = self.encoder(input, condition)
        return mu, logvar

    def decode(self, latent_vector: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor:
        """Run the decoder on ``latent_vector`` and ``condition`` and return its output.

        Args:
            latent_vector (torch.Tensor): The latent vector handed to the decoder.
            condition (torch.Tensor | None, optional): The condition handed to the decoder as its
                second argument. Defaults to None.

        Returns:
            torch.Tensor: The decoder output.
        """
        # User can decide how to use the condition in the decoder,
        # ex: using the condition in the middle layers of decoder, or not using it at all.
        return self.decoder(latent_vector, condition)

    def sampling(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw a latent sample as ``mu + eps * std`` with ``eps`` standard normal noise.

        Args:
            mu (torch.Tensor): Mean of the latent distribution.
            logvar (torch.Tensor): Log-variance of the latent distribution.

        Returns:
            torch.Tensor: The sampled latent vector.
        """
        # std = exp(0.5 * log(sigma^2)) = sigma
        std = torch.exp(0.5 * logvar)
        # Noise matches std in shape, dtype and device.
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Unpack the input, then encode, sample, decode, and pack the results into one 2-D tensor.

        Args:
            input (torch.Tensor): A single tensor carrying both the data and the condition; it is
                split by ``unpack_input_condition``.

        Returns:
            torch.Tensor: Concatenation along dim 1 of, in this order, ``logvar``, ``mu`` and
            the flattened reconstruction. Because the flattened reconstruction is 2-D,
            ``mu`` and ``logvar`` must also be 2-D with the same batch size.

        Raises:
            ValueError: If the model was constructed without ``unpack_input_condition``.
        """
        # Explicit check instead of ``assert`` so the validation is not stripped under ``python -O``.
        if self.unpack_input_condition is None:
            raise ValueError(
                "unpack_input_condition must be provided to ConditionalVae in order to call forward."
            )
        input, condition = self.unpack_input_condition(input)
        mu, logvar = self.encode(input, condition)
        z = self.sampling(mu, logvar)
        output = self.decode(z, condition)
        # Output (reconstruction) is flattened to be concatenated with the logvar and mu vectors.
        # The shape of the flattened_output can be later restored by having the training data shape,
        # or the decoder structure.
        # This assumes output is "batch first".
        # reshape (rather than view) also handles decoders that return a non-contiguous tensor,
        # e.g. after a permute/transpose; for contiguous tensors the result is the same.
        flattened_output = output.reshape(output.shape[0], -1)
        return torch.cat((logvar, mu, flattened_output), dim=1)