Source code for fl4health.clients.fed_pca_client

import random
import string
from logging import INFO
from pathlib import Path

import torch
from flwr.client.numpy_client import NumPyClient
from flwr.common import Config, NDArrays, Scalar
from flwr.common.logger import log
from torch import Tensor
from torch.utils.data import DataLoader

from fl4health.model_bases.pca import PcaModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.utils.config import narrow_dict_type


class FedPCAClient(NumPyClient):
    def __init__(
        self, data_path: Path, device: torch.device, model_save_dir: Path, client_name: str | None = None
    ) -> None:
        """
        Client that facilitates the execution of federated PCA.

        Args:
            data_path (Path): Path to the data to be used to load the data for client-side training.
            device (torch.device): Device indicator for where to send the model, batches, labels etc. Often "cpu"
                or "cuda".
            model_save_dir (Path): Directory in which to save the PCA components for later use, perhaps in
                dimensionality reduction.
            client_name (str | None, optional): Client name, mainly used for saving components. Defaults to None.
        """
        self.client_name = self.generate_hash() if client_name is None else client_name
        self.model: PcaModule
        self.initialized = False
        self.data_path = data_path
        self.model_save_dir = model_save_dir
        self.device = device
        self.train_data_tensor: Tensor
        self.val_data_tensor: Tensor
        self.num_train_samples: int
        self.num_val_samples: int
        self.parameter_exchanger = FullParameterExchanger()

    def generate_hash(self, length: int = 8) -> str:
        """
        Generates a unique hash used as an ID for the client.

        Args:
            length (int, optional): Length of the generated hash. Defaults to 8.

        Returns:
            str: Generated hash of length ``length``.
        """
        return "".join(random.choice(string.ascii_lowercase) for _ in range(length))

    def get_parameters(self, config: Config) -> NDArrays:
        """
        Sends all of the model components back to the server. The model in this case represents the principal
        components that have been computed.

        Args:
            config (Config): Configurations to allow for customization of this function's behavior.

        Returns:
            NDArrays: Parameters representing the principal components computed by the client that need to be
            aggregated in some way.
        """
        if not self.initialized:
            log(INFO, "Setting up client and providing full model parameters to the server for initialization")
            # If initialized==False, the server is requesting model parameters from which to initialize all other
            # clients. As such, get_parameters is being called before fit or evaluate, so we must call
            # setup_client first.
            self.setup_client(config)
            # Need all parameters even if normally exchanging partial
            return FullParameterExchanger().push_parameters(self.model, config=config)
        else:
            assert self.model is not None and self.parameter_exchanger is not None
            return self.parameter_exchanger.push_parameters(self.model, config=config)

    def set_parameters(self, parameters: NDArrays, config: Config) -> None:
        """
        Sets the merged principal components transferred from the server. Since federated PCA only runs for one
        round, the principal components obtained here are in fact the final result, so they are saved locally by
        each client for downstream tasks.

        Args:
            parameters (NDArrays): Aggregated principal components from the server. These are **FINAL** in the
                sense that FedPCA only runs for one round.
            config (Config): Configurations to allow for customization of this function's behavior.
        """
        self.parameter_exchanger.pull_parameters(parameters, self.model, config)
        self.save_model()

    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
        """
        User-defined method that returns a PyTorch train DataLoader and a PyTorch validation DataLoader.

        Args:
            config (Config): Configurations sent by the server to allow for customization of this function's
                behavior.

        Returns:
            tuple[DataLoader, DataLoader]: Tuple of length 2; the client train and validation loaders.

        Raises:
            NotImplementedError: To be defined in child class.
        """
        raise NotImplementedError

    def get_model(self, config: Config) -> PcaModule:
        """
        Returns an instance of the ``PcaModule``. This module is used to facilitate FedPCA training on the
        client side.

        Args:
            config (Config): Configurations sent by the server to allow for customization of this function's
                behavior.

        Returns:
            PcaModule: Module that determines how local FedPCA optimization will be performed.
        """
        low_rank = narrow_dict_type(config, "low_rank", bool)
        full_svd = narrow_dict_type(config, "full_svd", bool)
        rank_estimation = narrow_dict_type(config, "rank_estimation", int)
        return PcaModule(low_rank, full_svd, rank_estimation).to(self.device)

    def setup_client(self, config: Config) -> None:
        """
        Used to set up all of the components necessary to run FedPCA.

        Args:
            config (Config): Configurations sent by the server to allow for customization of this function's
                behavior.
        """
        self.model = self.get_model(config).to(self.device)
        train_loader, val_loader = self.get_data_loaders(config)
        self.train_data_tensor = self.get_data_tensor(train_loader).to(self.device)
        self.val_data_tensor = self.get_data_tensor(val_loader).to(self.device)

        # The following lines are type ignored because torch datasets are not "Sized", i.e. __len__ is considered
        # optionally defined. In practice, it is almost always defined and as such, we will make that assumption.
        self.num_train_samples = len(train_loader.dataset)  # type: ignore
        self.num_val_samples = len(val_loader.dataset)  # type: ignore

        self.initialized = True

    def get_data_tensor(self, data_loader: DataLoader) -> Tensor:
        """
        This function should be used to "collate" each of the dataloader batches into a single monolithic tensor
        representing all of the data in the loader.

        Args:
            data_loader (DataLoader): The dataloader that can be used to iterate through a dataset.

        Raises:
            NotImplementedError: Should be defined by the child class.

        Returns:
            Tensor: Single torch tensor representing all of the data stacked together.
        """
        raise NotImplementedError

    def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]:
        """
        Function to perform the local side of FedPCA. We don't use any parameters sent by the server, so
        ``parameters`` is ignored. We need only the ``train_data_tensor`` to do the work.

        Args:
            parameters (NDArrays): Ignored.
            config (Config): Configurations sent by the server to allow for customization of this function's
                behavior.

        Returns:
            tuple[NDArrays, int, dict[str, Scalar]]: The local principal components following the local training
            along with the number of samples in the local training dataset and the computed metrics throughout
            the fit.
        """
        if not self.initialized:
            self.setup_client(config)
        center_data = narrow_dict_type(config, "center_data", bool)

        principal_components, singular_values = self.model(self.train_data_tensor, center_data)
        self.model.set_principal_components(principal_components, singular_values)

        cumulative_explained_variance = self.model.compute_cumulative_explained_variance()
        explained_variance_ratios = self.model.compute_explained_variance_ratios()
        metrics: dict[str, Scalar] = {
            "cumulative_explained_variance": cumulative_explained_variance,
            "top_explained_variance_ratio": explained_variance_ratios[0].item(),
        }
        return (self.get_parameters(config), self.num_train_samples, metrics)

    def evaluate(self, parameters: NDArrays, config: Config) -> tuple[float, int, dict[str, Scalar]]:
        """
        Evaluate merged principal components on the local validation set.

        Args:
            parameters (NDArrays): Server-merged principal components.
            config (Config): Configurations sent by the server to allow for customization of this function's
                behavior.

        Returns:
            tuple[float, int, dict[str, Scalar]]: A loss associated with the evaluation, the number of samples in
            the validation/test set and the ``metric_values`` associated with evaluation.
        """
        if not self.initialized:
            self.setup_client(config)
        self.model.set_data_mean(self.model.maybe_reshape(self.train_data_tensor))
        self.set_parameters(parameters, config)
        num_components_eval = (
            narrow_dict_type(config, "num_components_eval", int) if "num_components_eval" in config.keys() else None
        )
        val_data_tensor_prepared = self.model.center_data(self.model.maybe_reshape(self.val_data_tensor)).to(
            self.device
        )
        reconstruction_loss = self.model.compute_reconstruction_error(val_data_tensor_prepared, num_components_eval)
        projection_variance = self.model.compute_projection_variance(val_data_tensor_prepared, num_components_eval)
        metrics: dict[str, Scalar] = {"projection_variance": projection_variance}
        return (reconstruction_loss, self.num_val_samples, metrics)

    def save_model(self) -> None:
        """
        Method to save the FedPCA computed principal components to disk. These can be reloaded to allow for
        dimensionality reduction in subsequent FL training.
        """
        final_model_save_path = self.model_save_dir / f"client_{self.client_name}_pca.pt"
        torch.save(self.model, final_model_save_path)
        log(INFO, f"Model parameters saved to {final_model_save_path}.")
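

# Example (not part of the library source): a minimal, hypothetical sketch of how FedPCAClient might be
# subclassed. The base class leaves get_data_loaders and get_data_tensor unimplemented, so a concrete client must
# supply both. The names below (TabularPcaClient, the tensor file assumed at self.data_path) and the sample config
# values are illustrative assumptions only, not the project's own implementation.
from torch.utils.data import TensorDataset  # additional import needed only for this sketch


class TabularPcaClient(FedPCAClient):
    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
        # Hypothetical: the data path is assumed to hold a single (n_samples, n_features) tensor saved with
        # torch.save. Split it 80/20 into train and validation loaders.
        features = torch.load(self.data_path)
        split = int(0.8 * len(features))
        train_loader = DataLoader(TensorDataset(features[:split]), batch_size=64, shuffle=False)
        val_loader = DataLoader(TensorDataset(features[split:]), batch_size=64, shuffle=False)
        return train_loader, val_loader

    def get_data_tensor(self, data_loader: DataLoader) -> Tensor:
        # Stack every batch back into one monolithic (n_samples, n_features) tensor, as the base class expects.
        return torch.cat([batch[0] for batch in data_loader], dim=0)


# An illustrative server config carrying the keys this client reads via narrow_dict_type (placeholder values;
# see PcaModule for the exact semantics of low_rank, full_svd and rank_estimation):
# {"low_rank": False, "full_svd": True, "rank_estimation": 10, "center_data": True, "num_components_eval": 5}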