"""Definitions for the strategies, strategy enumeration and server constructors."""
from __future__ import annotations
from enum import Enum
from functools import partial
from typing import Any, Callable, TypeAlias
import torch
from fl4health.client_managers.base_sampling_manager import SimpleClientManager
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.servers.adaptive_constraint_servers.fedprox_server import FedProxServer
from fl4health.servers.base_server import FlServer
from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from flwr.common import Scalar
from flwr.common.parameter import ndarrays_to_parameters
from flwr.server.strategy import FedAvg
from florist.api.servers.config_parsers import ConfigParser
# Signature shared by the server builder functions in this module:
# (model, number of clients, reporters, server configuration) -> FlServer.
GetServerFunction: TypeAlias = Callable[[torch.nn.Module, int, list[BaseReporter], dict[str, Any]], FlServer]
# A callable that receives an int (the current server round when it is built from
# fit_config_function) and returns a configuration dictionary.
ConfigFn: TypeAlias = Callable[[int], dict[str, Scalar]]
[docs]
class Strategy(Enum):
    """The strategies that can be picked for training."""

    FEDAVG = "FedAvg"
    FEDPROX = "FedProx"

    def get_config_parser(self) -> ConfigParser:
        """
        Return the config parser for this strategy.

        :return: (ConfigParser) ``ConfigParser.BASIC`` for FedAvg and ``ConfigParser.FEDPROX``
            for FedProx.
        :raises ValueError: if the strategy is not supported.
        """
        # Enum members are singletons, so identity is the idiomatic comparison.
        if self is Strategy.FEDAVG:
            return ConfigParser.BASIC
        if self is Strategy.FEDPROX:
            return ConfigParser.FEDPROX
        # Only reachable if a member is added to the enum without a branch above.
        raise ValueError(f"Strategy {self.value} not supported.")

    def get_server_factory(self) -> "ServerFactory":
        """
        Return the server factory instance for this strategy.

        :return: (ServerFactory) A ServerFactory instance that can be used to construct
            the FL server for the given strategy.
        :raises ValueError: if the strategy is not supported.
        """
        if self is Strategy.FEDAVG:
            return ServerFactory(get_server_function=get_fedavg_server)
        if self is Strategy.FEDPROX:
            return ServerFactory(get_server_function=get_fedprox_server)
        # Only reachable if a member is added to the enum without a branch above.
        raise ValueError(f"Strategy {self.value} not supported.")

    @classmethod
    def list(cls) -> list[str]:
        """
        List all the supported strategies.

        :return: (list[str]) The values of all the enum members, in definition order.
        """
        return [strategy.value for strategy in cls]
[docs]
class ServerFactory:
    """Factory class that will provide the server constructor."""

    def __init__(self, get_server_function: GetServerFunction):
        """
        Initialize a ServerFactory.

        :param get_server_function: (GetServerFunction) The function that will be used to produce
            the server constructor.
        """
        self.get_server_function = get_server_function

    def get_server_constructor(
        self,
        model: torch.nn.Module,
        n_clients: int,
        reporters: list[BaseReporter],
        server_config: dict[str, Scalar],
    ) -> Callable[[Any], FlServer]:
        """
        Make the server constructor based on the self.get_server_function.

        The given arguments are bound to ``self.get_server_function`` with ``functools.partial``;
        the server itself is only built when the returned callable is invoked.

        NOTE(review): the return annotation advertises one positional argument, although all
        four parameters of GetServerFunction are already bound here. Confirm against the callers
        before tightening it.

        :param model: (torch.nn.Module) The model object.
        :param n_clients: (int) The number of clients participating in the FL training.
        :param reporters: (list[BaseReporter]) A list of reporters to be passed to the FL server.
        :param server_config: (dict[str, Scalar]) A dictionary with the server configuration values.
        :return: (Callable[[Any], FlServer]) A callable function that will construct an FL server.
        """
        return partial(self.get_server_function, model, n_clients, reporters, server_config)

    def __eq__(self, other: object) -> bool:
        """
        Check if the self instance is equal to the given other instance.

        Two factories are equal when they wrap equal ``get_server_function`` callables.

        :param other: (object) The other instance to compare it to.
        :return: (bool) True if the instances are the same, False otherwise. NotImplemented is
            returned for objects of a different class so Python can try the reflected comparison.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.get_server_function == other.get_server_function

    def __hash__(self) -> int:
        """
        Hash the factory consistently with ``__eq__``.

        Defining ``__eq__`` alone sets ``__hash__`` to None, which makes instances unusable in
        sets and as dictionary keys. Hashing the wrapped function keeps equal factories in the
        same hash bucket.

        :return: (int) The hash of the wrapped ``get_server_function``.
        """
        return hash(self.get_server_function)
[docs]
def fit_config_function(server_config: dict[str, Scalar], current_server_round: int) -> dict[str, Scalar]:
    """
    Build the config dictionary for a given server round.

    :param server_config: (dict[str, Scalar]) A dictionary with the server configuration.
        It is copied, never modified.
    :param current_server_round: (int) The current server round.
    :return: (dict[str, Scalar]) A new dictionary with every entry of ``server_config`` plus
        the ``current_server_round`` key, which overrides any entry of the same name.
    """
    round_config = dict(server_config)
    round_config["current_server_round"] = current_server_round
    return round_config
[docs]
def get_fedavg_server(
    model: torch.nn.Module,
    n_clients: int,
    reporters: list[BaseReporter],
    server_config: dict[str, Scalar],
) -> FlServer:
    """
    Build an FL server that uses the FedAvg strategy.

    :param model: (torch.nn.Module) The torch.nn.Module instance for the model.
    :param n_clients: (int) the number of clients participating in the FL training. It is used
        as the minimum number of fit, evaluate and available clients.
    :param reporters: (list[BaseReporter]) A list of reporters to be passed to the FL server.
    :param server_config: (dict[str, Scalar]) A dictionary with the server configuration values.
    :return: (FlServer) An FlServer instance configured with FedAvg strategy.
    """
    # One config callback is shared by the fit and evaluate rounds.
    round_config_fn: ConfigFn = partial(fit_config_function, server_config)
    # The initial parameters are the model's state dict tensors, moved to CPU as numpy arrays.
    model_weights = [tensor.cpu().numpy() for tensor in model.state_dict().values()]
    fedavg_strategy = FedAvg(
        min_fit_clients=n_clients,
        min_evaluate_clients=n_clients,
        min_available_clients=n_clients,
        on_fit_config_fn=round_config_fn,
        on_evaluate_config_fn=round_config_fn,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=ndarrays_to_parameters(model_weights),
    )
    return FlServer(
        strategy=fedavg_strategy,
        client_manager=SimpleClientManager(),
        reporters=reporters,
        fl_config=server_config,
    )
[docs]
def get_fedprox_server(
    model: torch.nn.Module,
    n_clients: int,
    reporters: list[BaseReporter],
    server_config: dict[str, Scalar],
) -> FlServer:
    """
    Build an FL server that uses the FedProx strategy.

    :param model: (torch.nn.Module) The torch.nn.Module instance for the model.
    :param n_clients: (int) the number of clients participating in the FL training. It is used
        as the minimum number of fit, evaluate and available clients.
    :param reporters: (list[BaseReporter]) A list of reporters to be passed to the FL server.
    :param server_config: (dict[str, Scalar]) A dictionary with the server configuration values.
        The ``adapt_proximal_weight``, ``initial_proximal_weight``, ``proximal_weight_delta``
        and ``proximal_weight_patience`` keys are read from it.
    :return: (FlServer) An FlServer instance configured with FedProx strategy.
    """
    round_config_fn: ConfigFn = partial(fit_config_function, server_config)
    # The initial parameters are the model's state dict tensors, moved to CPU as numpy arrays.
    model_weights = [tensor.cpu().numpy() for tensor in model.state_dict().values()]
    fedprox_strategy = FedAvgWithAdaptiveConstraint(
        min_fit_clients=n_clients,
        min_evaluate_clients=n_clients,
        # Server waits for min_available_clients before starting FL rounds
        min_available_clients=n_clients,
        on_fit_config_fn=round_config_fn,
        # We use the same fit config function, as nothing changes for eval
        on_evaluate_config_fn=round_config_fn,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=ndarrays_to_parameters(model_weights),
        # The proximal weight settings are looked up directly, so a missing key raises KeyError.
        adapt_loss_weight=server_config["adapt_proximal_weight"],
        initial_loss_weight=server_config["initial_proximal_weight"],
        loss_weight_delta=server_config["proximal_weight_delta"],
        loss_weight_patience=server_config["proximal_weight_patience"],
    )
    return FedProxServer(
        client_manager=SimpleClientManager(),
        strategy=fedprox_strategy,
        reporters=reporters,
        fl_config=server_config,
    )