"""Implementation of the clients and the Client enumeration."""
from enum import Enum
from typing import List
import torch
from fl4health.clients.basic_client import BasicClient
from fl4health.clients.fed_prox_client import FedProxClient
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.dataset import TensorDataset
from fl4health.utils.sampler import DirichletLabelBasedSampler
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader
from florist.api.clients.optimizers import Optimizer
from florist.api.models.abstract import LocalDataModel
from florist.api.servers.strategies import Strategy
[docs]
class LocalDataClient(BasicClient): # type: ignore[misc]
"""Implementation of a client that uses a model with data stored locally."""
[docs]
def set_model(self, model: LocalDataModel) -> None:
"""
Set the model to be used for training with local data.
:param model: (LocalModel) An instance of the model to be used for training.
"""
self.model = model
[docs]
def set_optimizer_type(self, optimizer_type: Optimizer) -> None:
"""
Set the type of the optimizer to be used for training.
:param optimizer_type: (Optimizer) A value of the Optimizer enumeration with the type of
the optimizer to be used for training.
"""
self.optimizer_type = optimizer_type
[docs]
def get_model(self, config: Config) -> torch.nn.Module:
"""
Return the model for training with local data.
:param config: (Config) the Config object for this client.
:return: (torch.nn.Module) An instance of the model.
"""
assert isinstance(self.model, LocalDataModel), f"Model {self.model} is not a subclass of LocalModel."
return self.model
[docs]
def get_optimizer(self, config: Config) -> torch.optim.Optimizer: # type: ignore
"""
Return the optimizer for the model.
:param config: (Config) the Config object for this client.
:return: (torch.optim.Optimizer) An instance of torch.optim.Optimizer with the configurations defined
by self.optimizer_type.
"""
assert isinstance(self.model, LocalDataModel), f"Model {self.model} is not a subclass of LocalModel."
assert self.optimizer_type, "self.optimizer_type is None."
return Optimizer.get(self.optimizer_type, self.model.parameters())
[docs]
def get_data_loaders(self, config: Config) -> tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]:
"""
Return the data loader for the model with local data.
:param config: (Config) the Config object for this client.
:return: (Tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]) a tuple with the train data loader
and validation data loader respectively.
"""
assert isinstance(self.model, LocalDataModel), f"Model {self.model} is not a subclass of LocalModel."
assert self.data_path, "self.data_path is empty None."
return self.model.get_data_loaders(self.data_path, int(config["batch_size"]))
[docs]
def get_criterion(self, config: Config) -> _Loss:
"""
Return the loss for the model.
:param config: (Config) the Config object for this client.
:return: (torch.nn.modules.loss._Loss) an instance of torch.nn.modules.loss._Loss that has been
defined by the local model.
"""
assert isinstance(self.model, LocalDataModel), f"Model {self.model} is not a subclass of LocalModel."
return self.model.get_criterion()
[docs]
class FedProxLocalDataClient(FedProxClient, LocalDataClient): # type: ignore[misc]
"""Implementation of the FedProx client that uses a model with data stored locally."""
[docs]
def get_data_loaders(self, config: Config) -> tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]:
"""
Return the data loader for FedProx on model with data stored locally.
:param config: (Config) the Config object for this client.
:return: (Tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]) a tuple with the train data loader
and validation data loader respectively.
"""
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1)
batch_size = narrow_dict_type(config, "batch_size", int)
assert isinstance(self.model, LocalDataModel), f"Model {self.model} is not a subclass of LocalModel."
assert self.data_path is not None, "self.data_path is None."
return self.model.get_data_loaders(self.data_path, batch_size, sampler)
[docs]
class Client(Enum):
"""Enumeration of supported clients."""
FEDAVG = "FedAvg"
FEDPROX = "FedProx"
[docs]
def get_client_class(self) -> type[LocalDataClient]:
"""
Return the class for this client.
:return: (type[LocalDataClient]) A subclass of LocalDataClient corresponding to the this client.
:raises ValueError: if the client is not supported.
"""
if self == Client.FEDAVG:
return LocalDataClient
if self == Client.FEDPROX:
return FedProxLocalDataClient
raise ValueError(f"Client {self.value} not supported.")
[docs]
@classmethod
def list(cls) -> list[str]:
"""
List all the supported clients.
:return: (list[str]) a list of supported clients.
"""
return [client.value for client in Client]
[docs]
@classmethod
def list_by_strategy(cls, strategy: Strategy) -> List[str]:
"""
List all the supported clients given a strategy.
:param strategy: (Strategy) the strategy to find the supported clients.
:return: (list[str]) a list of supported clients for the given strategy.
"""
if strategy == Strategy.FEDAVG:
return [Client.FEDAVG.value]
if strategy == Strategy.FEDPROX:
return [Client.FEDPROX.value]
raise ValueError(f"Strategy {strategy} not supported.")