Source code for fl4health.preprocessing.pca_preprocessor
from functools import partial
from pathlib import Path
import torch
from fl4health.model_bases.pca import PcaModule
from fl4health.utils.dataset import TensorDataset
class PcaPreprocessor:
    def __init__(self, checkpointing_path: Path) -> None:
        """
        Class that leverages the pre-computed principal components of a dataset to perform
        data preprocessing.

        Args:
            checkpointing_path (Path): Path to the saved principal components (the serialized PcaModule).
        """
        self.checkpointing_path = checkpointing_path
        self.pca_module: PcaModule = self.load_pca_module()
    def load_pca_module(self) -> PcaModule:
        """Load the PcaModule saved at self.checkpointing_path and set it to evaluation mode."""
        pca_module = torch.load(self.checkpointing_path)
        pca_module.eval()
        return pca_module
    def reduce_dimension(
        self,
        new_dimension: int,
        dataset: TensorDataset,
    ) -> TensorDataset:
        """
        Perform dimensionality reduction on a dataset by projecting the data onto a set of
        pre-computed principal components.

        Note that PyTorch dataloaders apply transforms lazily, so the dimensionality reduction
        is actually performed on the fly as the user iterates through a dataloader created from
        the dataset returned here.

        Args:
            new_dimension (int): New data dimension after dimensionality reduction. Equals the
                number of principal components onto which the data is projected.
            dataset (TensorDataset): Dataset containing the data whose dimension is to be reduced.

        Returns:
            TensorDataset: Dataset consisting of data with reduced dimension.
        """
        projection = partial(self.pca_module.project_lower_dim, k=new_dimension)
        dataset.update_transform(projection)
        return dataset
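

# Minimal usage sketch (not part of the original module). It assumes a PcaModule has already
# been fit and saved with torch.save to a hypothetical "pca_module.pt" checkpoint, and that
# TensorDataset wraps a data tensor with an optional targets tensor and yields (input, target)
# pairs when iterated through a dataloader.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    data, targets = torch.randn(128, 64), torch.randint(0, 2, (128,))
    dataset = TensorDataset(data, targets)

    preprocessor = PcaPreprocessor(Path("pca_module.pt"))
    reduced_dataset = preprocessor.reduce_dimension(new_dimension=16, dataset=dataset)

    # The projection transform is applied lazily, as batches are drawn from the dataloader.
    dataloader = DataLoader(reduced_dataset, batch_size=32, shuffle=True)
    for x, y in dataloader:
        # Each batch of inputs should now have the reduced dimension (here, 16).
        print(x.shape)
        break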