Source code for fl4health.preprocessing.autoencoders.dim_reduction

from abc import ABC
from pathlib import Path

import torch


[docs] class AutoEncoderProcessing(ABC):
[docs] def __init__( self, checkpointing_path: Path, device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), ) -> None: """Abstract class for processors that work with a pre-trained and saved autoencoder model. Args: checkpointing_path (Path): Path to the saved model. device (torch.device, optional): Device indicator for where to send the model and data samples for preprocessing. """ self.checkpointing_path = checkpointing_path self.device = device self.load_autoencoder()
[docs] def load_autoencoder(self) -> None: autoencoder = torch.load(self.checkpointing_path) autoencoder.eval() self.autoencoder = autoencoder.to(self.device)
def __repr__(self) -> str: return f"{self.__class__.__name__}"
[docs] class AeProcessor(AutoEncoderProcessing): """ Transformer processor to encode the data using basic autoencoder. """ def __call__(self, sample: torch.Tensor) -> torch.Tensor: # This transformer is called for the input samples after they are transferred into torch tensors. embedding_vector = self.autoencoder.encode(sample.to(self.device)) return embedding_vector.clone().detach()
[docs] class VaeProcessor(AutoEncoderProcessing):
[docs] def __init__( self, checkpointing_path: Path, device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), return_mu_only: bool = False, ) -> None: """ Transformer processor to encode the data using VAE encoder. Args: checkpointing_path (Path): Path to the saved model. device (torch.device, optional): Device indicator for where to send the model and data samples for preprocessing. return_mu_only (bool, optional): If true, only mu is returned. Defaults to False. """ super().__init__(checkpointing_path, device) self.return_mu_only = return_mu_only
def __call__(self, sample: torch.Tensor) -> torch.Tensor: # This transformer is called for the input samples after they are transferred into torch tensors. mu, logvar = self.autoencoder.encode(sample.to(self.device)) if self.return_mu_only: return mu.clone().detach() # By default returns cat(mu,logvar). # Concatenation is performed on the last dimension which for both the batched data and single data # is the latent space dimension. return torch.cat((mu.clone().detach(), logvar.clone().detach()), dim=-1)
[docs] class CvaeFixedConditionProcessor(AutoEncoderProcessing): """Transformer processor to encode the data using CVAE encoder with client-specific condition."""
[docs] def __init__( self, checkpointing_path: Path, condition: torch.Tensor, device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), return_mu_only: bool = False, ) -> None: """Transformer processor to encode the data using a CVAE encoder with a fixed condition. Args: checkpointing_path (Path): Path to the saved model. condition (torch.Tensor): Fixed condition tensor. device (torch.device, optional): Device indicator for where to send the model and data samples for preprocessing. return_mu_only (bool, optional): If true, only mu is returned. Defaults to False. """ super().__init__(checkpointing_path, device) self.condition = condition self.return_mu_only = return_mu_only assert ( self.condition.dim() == 1 ), f"Error: condition should be a 1D vector instead of a {self.condition.dim()}D tensor."
def __call__(self, sample: torch.Tensor) -> torch.Tensor: # Assuming batch is the first dimension if sample.dim() == 1: batch_condition = self.condition else: # If we are processing a batch of data, condition should be repeated for each sample. sample_batch_size = sample.shape[0] batch_condition = self.condition.repeat(sample_batch_size, 1) # This transformer is called for the input samples after they are transformed into torch tensors. mu, logvar = self.autoencoder.encode(sample.to(self.device), batch_condition.to(self.device)) if self.return_mu_only: return mu.clone().detach() # By default returns cat(mu,logvar) # Concatenation is performed on the last dimension which for both the batched data and single data # is the latent space dimension. return torch.cat((mu.clone().detach(), logvar.clone().detach()), dim=-1)
[docs] class CvaeVariableConditionProcessor(AutoEncoderProcessing):
[docs] def __init__( self, checkpointing_path: Path, device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), return_mu_only: bool = False, ) -> None: """ Transformer processor to encode the data using CVAE encoder with variable condition, that is each data sample can have a specific condition. Args: checkpointing_path (Path): Path to the saved model. device (torch.device, optional): Device indicator for where to send the model and data samples for preprocessing. return_mu_only (bool, optional): If true, only mu is returned. Defaults to False. """ super().__init__(checkpointing_path, device) self.return_mu_only = return_mu_only
def __call__(self, sample: torch.Tensor, condition: torch.Tensor) -> torch.Tensor: """Performs encoding. Args: sample (torch.Tensor): A single data sample or a batch of data. condition (torch.Tensor): A single condition for the given sample, or a batch of condition for a batch of data. Returns: torch.Tensor: Encoded sample(s). """ # This transformer is called for the input samples after they are transformed into torch tensors. # We assume condition and data are "batch first". if condition.size(0) > 1: # If condition is a batch assert condition.size(0) == sample.size( 0 ), f"Error: Condition shape: {condition.shape} does not match the data shape: {sample.shape}" mu, logvar = self.autoencoder.encode(sample.to(self.device), condition.to(self.device)) if self.return_mu_only: return mu.clone().detach() # By default returns cat(mu,logvar) # Concatenation is performed on the last dimension which for both the batched data and single data # is the latent space dimension. return torch.cat((mu.clone().detach(), logvar.clone().detach()), dim=-1)