Source code for florist.api.servers.utils

"""Utilities functions and definitions for starting a server."""

from functools import partial
from typing import Callable, Dict, Union

from fl4health.client_managers.base_sampling_manager import SimpleClientManager
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.server.base_server import FlServer
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from flwr.common.parameter import ndarrays_to_parameters
from flwr.server.strategy import FedAvg
from torch import nn


FitConfigFn = Callable[[int], Dict[str, Union[bool, bytes, float, int, str]]]


[docs] def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]: """ Return a dictionary used to configure the server's fit function. :param batch_size: (int) the size of the batch of samples. :param local_epochs: (int) the number of local epochs the clients will run. :param current_server_round: (int) the current server round :return: (Dict[str, int]) A dictionary to the used at the config for the fit function. """ return { "batch_size": batch_size, "current_server_round": current_server_round, "local_epochs": local_epochs, }
[docs] def get_server( model: nn.Module, reporters: list[BaseReporter], fit_config: Callable[[int, int, int], Dict[str, int]] = fit_config, n_clients: int = 2, batch_size: int = 8, local_epochs: int = 1, ) -> FlServer: """ Return a server instance with FedAvg aggregation strategy. :param model: (torch.nn.Model) the model the server and clients will be using. :param fit_config: (Callable[[int, int, int], Dict[str, int]]) the function to configure the fit method. :param n_clients: (int) the number of clients that will participate on training. Optional, default is 2. :param batch_size: (int) the size of the batch of samples. Optional, default is 8. :param local_epochs: (int) the number of local epochs the clients will run. Optional, default is 1. :param reporters: (list[fl4health.reporting.base_reporter.BaseReporter]) An optional metrics reporter instance. Default is None. :return: (fl4health.server.base_server.FlServer) An instance of FlServer with FedAvg as strategy. """ fit_config_fn: FitConfigFn = partial(fit_config, batch_size, local_epochs) # type: ignore initial_model_parameters = ndarrays_to_parameters([val.cpu().numpy() for _, val in model.state_dict().items()]) strategy = FedAvg( min_fit_clients=n_clients, min_evaluate_clients=n_clients, min_available_clients=n_clients, on_fit_config_fn=fit_config_fn, on_evaluate_config_fn=fit_config_fn, fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, initial_parameters=initial_model_parameters, ) client_manager = SimpleClientManager() return FlServer(strategy=strategy, client_manager=client_manager, reporters=reporters)