import random
import string
from logging import INFO
from pathlib import Path
import torch
from flwr.client.numpy_client import NumPyClient
from flwr.common import Config, NDArrays, Scalar
from flwr.common.logger import log
from torch import Tensor
from torch.utils.data import DataLoader
from fl4health.model_bases.pca import PcaModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.utils.config import narrow_dict_type
class FedPCAClient(NumPyClient):
    """Flower NumPyClient that performs client-side (federated) PCA for a single round."""

    def __init__(self, data_path: Path, device: torch.device, model_save_path: Path) -> None:
        """
        Client that facilitates the execution of federated PCA.

        Args:
            data_path (Path): path to the data to be used to load the data for client-side training
            device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or
                'cuda'
            model_save_path (Path): Path to save the PCA components for use later, perhaps in dimensionality reduction
        """
        # Stable random identifier for this client, used when saving artifacts.
        self.client_name = self.generate_hash()
        self.model: PcaModule
        self.initialized = False
        self.data_path = data_path
        self.model_save_path = model_save_path
        self.device = device
        self.train_data_tensor: Tensor
        self.val_data_tensor: Tensor
        self.num_train_samples: int
        self.num_val_samples: int
        self.parameter_exchanger = FullParameterExchanger()

    def generate_hash(self, length: int = 8) -> str:
        """
        Generate a random lowercase-ASCII identifier.

        NOTE: this uses ``random``, so it is not cryptographically secure. It is only
        intended for human-readable identification of clients/artifacts.

        Args:
            length (int): Number of characters in the identifier. Defaults to 8.

        Returns:
            str: Random string of ``length`` lowercase letters.
        """
        return "".join(random.choice(string.ascii_lowercase) for _ in range(length))

    def get_parameters(self, config: Config) -> NDArrays:
        """
        Return the current model parameters (principal components) as NDArrays.

        Args:
            config (Config): Configuration sent by the server.

        Returns:
            NDArrays: Model parameters extracted by the parameter exchanger.
        """
        if not self.initialized:
            log(INFO, "Setting up client and providing full model parameters to the server for initialization")
            # If initialized==False, the server is requesting model parameters from which to initialize all other
            # clients. As such get_parameters is being called before fit or evaluate, so we must call
            # setup_client first.
            self.setup_client(config)
            # Need all parameters even if normally exchanging partial
            return FullParameterExchanger().push_parameters(self.model, config=config)
        assert self.model is not None and self.parameter_exchanger is not None
        return self.parameter_exchanger.push_parameters(self.model, config=config)

    def set_parameters(self, parameters: NDArrays, config: Config) -> None:
        """
        Sets the merged principal components transferred from the server.

        Since federated PCA only runs for one round, the principal components obtained here
        are in fact the final result, so they are saved locally by each client for downstream tasks.

        Args:
            parameters (NDArrays): Server-merged principal components.
            config (Config): Configuration sent by the server.
        """
        self.parameter_exchanger.pull_parameters(parameters, self.model, config)
        self.save_model()

    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
        """
        User defined method that returns a PyTorch Train DataLoader
        and a PyTorch Validation DataLoader.

        Args:
            config (Config): Configuration sent by the server.

        Raises:
            NotImplementedError: Must be overridden by subclasses.
        """
        raise NotImplementedError

    def get_model(self, config: Config) -> PcaModule:
        """
        Returns an instance of the PcaModule, already moved to ``self.device``.

        Args:
            config (Config): Must contain "low_rank" (bool), "full_svd" (bool) and
                "rank_estimation" (int) entries used to construct the module.

        Returns:
            PcaModule: Freshly constructed PCA module on the client's device.
        """
        low_rank = narrow_dict_type(config, "low_rank", bool)
        full_svd = narrow_dict_type(config, "full_svd", bool)
        rank_estimation = narrow_dict_type(config, "rank_estimation", int)
        return PcaModule(low_rank, full_svd, rank_estimation).to(self.device)

    def setup_client(self, config: Config) -> None:
        """
        One-time client initialization: build the model, load data onto the device
        and record dataset sizes. Sets ``self.initialized`` when done.

        Args:
            config (Config): Configuration sent by the server.
        """
        # get_model already moves the module to self.device; no second .to() needed.
        self.model = self.get_model(config)
        train_loader, val_loader = self.get_data_loaders(config)
        self.train_data_tensor = self.get_data_tensor(train_loader).to(self.device)
        self.val_data_tensor = self.get_data_tensor(val_loader).to(self.device)
        # The following lines are type ignored because torch datasets are not "Sized"
        # IE __len__ is considered optionally defined. In practice, it is almost always defined
        # and as such, we will make that assumption.
        self.num_train_samples = len(train_loader.dataset)  # type: ignore
        self.num_val_samples = len(val_loader.dataset)  # type: ignore
        self.initialized = True

    def get_data_tensor(self, data_loader: DataLoader) -> Tensor:
        """
        User defined method that converts a DataLoader's contents into a single tensor.

        Args:
            data_loader (DataLoader): Loader whose data should be collected into one tensor.

        Raises:
            NotImplementedError: Must be overridden by subclasses.
        """
        raise NotImplementedError

    def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]:
        """
        Perform PCA using the locally held dataset.

        Args:
            parameters (NDArrays): Initial parameters from the server (unused here; PCA is
                computed directly from local data).
            config (Config): Must contain a "center_data" (bool) entry.

        Returns:
            tuple[NDArrays, int, dict[str, Scalar]]: Local principal components, the number
            of training samples, and explained-variance metrics.
        """
        if not self.initialized:
            self.setup_client(config)
        center_data = narrow_dict_type(config, "center_data", bool)
        principal_components, singular_values = self.model(self.train_data_tensor, center_data)
        self.model.set_principal_components(principal_components, singular_values)
        cumulative_explained_variance = self.model.compute_cumulative_explained_variance()
        explained_variance_ratios = self.model.compute_explained_variance_ratios()
        metrics: dict[str, Scalar] = {
            "cumulative_explained_variance": cumulative_explained_variance,
            "top_explained_variance_ratio": explained_variance_ratios[0].item(),
        }
        return (self.get_parameters(config), self.num_train_samples, metrics)

    def evaluate(self, parameters: NDArrays, config: dict[str, Scalar]) -> tuple[float, int, dict[str, Scalar]]:
        """
        Evaluate merged principal components on the local validation set.

        Args:
            parameters (NDArrays): Server-merged principal components.
            config (dict[str, Scalar]): Config file.

        Returns:
            tuple[float, int, dict[str, Scalar]]: Reconstruction loss, number of validation
            samples, and a metrics dict containing the projection variance.
        """
        if not self.initialized:
            self.setup_client(config)
        # NOTE(review): data mean is (re)computed from the training tensor so that
        # center_data below is well defined even when fit was never called locally.
        self.model.set_data_mean(self.model.maybe_reshape(self.train_data_tensor))
        self.set_parameters(parameters, config)
        # num_components_eval is optional: when absent, evaluate with all components.
        num_components_eval = (
            narrow_dict_type(config, "num_components_eval", int) if "num_components_eval" in config else None
        )
        val_data_tensor_prepared = self.model.center_data(self.model.maybe_reshape(self.val_data_tensor)).to(
            self.device
        )
        reconstruction_loss = self.model.compute_reconstruction_error(val_data_tensor_prepared, num_components_eval)
        projection_variance = self.model.compute_projection_variance(val_data_tensor_prepared, num_components_eval)
        metrics: dict[str, Scalar] = {"projection_variance": projection_variance}
        return (reconstruction_loss, self.num_val_samples, metrics)

    def save_model(self) -> None:
        """
        Persist the fitted PCA module to ``model_save_path`` for downstream use.

        Uses the stable ``self.client_name`` generated in ``__init__`` (rather than a fresh
        random hash per call) so repeated saves overwrite one file instead of accumulating
        differently named copies.
        """
        final_model_save_path = Path(self.model_save_path) / f"client_{self.client_name}_pca.pt"
        torch.save(self.model, final_model_save_path)
        log(INFO, f"Model parameters saved to {final_model_save_path}.")