# Source code for fl4health.servers.base_server

import datetime
from collections.abc import Callable, Sequence
from logging import DEBUG, ERROR, INFO, WARNING

import torch.nn as nn
from flwr.common import EvaluateRes, Parameters
from flwr.common.logger import log
from flwr.common.typing import Code, Config, GetParametersIns, Scalar
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.history import History
from flwr.server.server import EvaluateResultsAndFailures, FitResultsAndFailures, Server, evaluate_clients
from flwr.server.strategy import Strategy

from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.reporting.reports_manager import ReportsManager
from fl4health.servers.polling import poll_clients
from fl4health.strategies.strategy_with_poll import StrategyWithPolling
from fl4health.utils.config import narrow_dict_type_and_set_attribute
from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, MetricPrefix
from fl4health.utils.parameter_extraction import get_all_model_parameters
from fl4health.utils.random import generate_hash
from fl4health.utils.typing import EvaluateFailures, FitFailures


class FlServer(Server):
    """
    Base server for the library, strapping additional machinery onto the flwr ``Server``:
    state/model checkpointing, reporting, client polling, and client-failure handling.
    """

    def __init__(
        self,
        client_manager: ClientManager,
        fl_config: Config,
        strategy: Strategy | None = None,
        reporters: Sequence[BaseReporter] | None = None,
        checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None,
        on_init_parameters_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        server_name: str | None = None,
        accept_failures: bool = True,
    ) -> None:
        """
        Args:
            client_manager (ClientManager): Determines the mechanism by which clients are sampled by the
                server, if they are to be sampled at all.
            fl_config (Config): The configuration used to set up federated training. In most cases it should
                be the "source of truth" for how FL training/evaluation should proceed (for example, the
                config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy).
                NOTE: This config is DISTINCT from the flwr server config, which is extremely minimal.
            strategy (Strategy | None, optional): The aggregation strategy used by the server to handle
                client updates and other information potentially sent by the participating clients. If None,
                the strategy is FedAvg as set by the flwr Server. Defaults to None.
            reporters (Sequence[BaseReporter] | None, optional): Sequence of FL4Health reporters which the
                server should send data to before and after each round. Defaults to None.
            checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): Handles both
                model checkpointing (saving model artifacts to be used or evaluated after training) and state
                checkpointing (preserving training state, including models, so that interrupted FL training
                may be restarted). If None, no checkpointing or state preservation happens. Defaults to None.
            on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used
                to configure how one asks a client to provide parameters from which to initialize all other
                clients, by providing a Config dictionary. If None, a blank config is sent with the parameter
                request (default flower-server behavior). Defaults to None.
            server_name (str | None, optional): An optional string name to uniquely identify the server. This
                name is also used as part of any state checkpointing done by the server. Defaults to None.
            accept_failures (bool, optional): Whether the server should accept failures during training or
                evaluation from clients. If False, the server shuts down all clients and throws an exception.
                Defaults to True.
        """
        super().__init__(client_manager=client_manager, strategy=strategy)
        self.fl_config = fl_config

        # Fall back to a no-op checkpoint/state module when none is supplied.
        self.checkpoint_and_state_module = (
            checkpoint_and_state_module
            if checkpoint_and_state_module is not None
            else BaseServerCheckpointAndStateModule(
                model=None, parameter_exchanger=None, model_checkpointers=None, state_checkpointer=None
            )
        )

        self.on_init_parameters_config_fn = on_init_parameters_config_fn
        if server_name is not None:
            self.server_name = server_name
        else:
            self.server_name = generate_hash()
        self.state_checkpoint_name = f"server_{self.server_name}_state.pt"
        self.accept_failures = accept_failures

        # Assigned during fit (or when loading server state); declared for clarity.
        self.current_round: int
        self.history: History

        # Initialize reporters with server name information.
        self.reports_manager = ReportsManager(reporters)
        self.reports_manager.initialize(id=self.server_name)
        self._log_fl_config()
[docs] def update_before_fit(self, num_rounds: int, timeout: float | None) -> None: """ Hook method to allow the server to do some work before starting the fit process. In the base server, it is a no-op function, but it can be overridden in child classes for custom functionality. For example, the NnUNetServer class uses this method to ask a client to initialize the global nnunet plans if one is not provided in the config. This can only be done after the clients have started up and are ready to train. Args: num_rounds (int): The number of server rounds of FL to be performed timeout (float | None, optional): The server's timeout parameter. Useful if one is requesting information from a client. Defaults to None, which indicates indefinite timeout. """ pass
[docs] def report_centralized_eval(self, history: History, num_rounds: int) -> None: if len(history.losses_centralized) == 0: return # Parse and report history for loss and metrics on centralized validation set. for round in range(num_rounds): self.reports_manager.report( {"val - loss - centralized": history.losses_centralized[round][1]}, round + 1, ) round_metrics = {} for metric, vals in history.metrics_centralized.items(): round_metrics.update({metric: vals[round][1]}) self.reports_manager.report({"eval_round_metrics_centralized": round_metrics}, round + 1)
[docs] def fit_with_per_round_checkpointing(self, num_rounds: int, timeout: float | None) -> tuple[History, float]: """ Runs federated learning for a number of rounds. Heavily based on the fit method from the base server provided by flower (flwr.server.server.Server) except that it is resilient to preemptions. It accomplishes this by checkpointing the server state each round. In the case of preemption, when the server is restarted it will load from the most recent checkpoint. Args: num_rounds (int): The number of rounds to perform federated learning. timeout (float | None): The timeout for clients to return results in a given FL round. Returns: tuple[History, float]: The first element of the tuple is a history object containing the losses and metrics computed during training and validation. The second element of the tuple is the elapsed time in seconds. """ # Attempt to load the server state if it exists. If the state checkpoint exists, update history, server # round and model accordingly state_load_success = self._load_server_state() if state_load_success: log(INFO, "Server state checkpoint successfully loaded.") else: log(INFO, "Initializing server state and global parameters") self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout) self.history = History() self.current_round = 1 if self.current_round == 1: log(INFO, "Evaluating initial parameters") res = self.strategy.evaluate(0, parameters=self.parameters) if res is not None: log( INFO, "initial parameters (loss, other metrics): %s, %s", res[0], res[1], ) self.history.add_loss_centralized(server_round=0, loss=res[0]) self.history.add_metrics_centralized(server_round=0, metrics=res[1]) # Run federated learning for num_rounds log(INFO, "FL starting") start_time = datetime.datetime.now() while self.current_round < (num_rounds + 1): # Train model and replace previous global model res_fit = self.fit_round(server_round=self.current_round, timeout=timeout) if res_fit: parameters_prime, 
fit_metrics, _ = res_fit # fit_metrics_aggregated if parameters_prime: self.parameters = parameters_prime self.history.add_metrics_distributed_fit(server_round=self.current_round, metrics=fit_metrics) # Evaluate model using strategy implementation res_cen = self.strategy.evaluate(self.current_round, parameters=self.parameters) if res_cen is not None: loss_cen, metrics_cen = res_cen log( INFO, "fit progress: (%s, %s, %s, %s)", self.current_round, loss_cen, metrics_cen, (datetime.datetime.now() - start_time).total_seconds(), ) self.history.add_loss_centralized(server_round=self.current_round, loss=loss_cen) self.history.add_metrics_centralized(server_round=self.current_round, metrics=metrics_cen) # Evaluate model on a sample of available clients res_fed = self.evaluate_round(server_round=self.current_round, timeout=timeout) if res_fed: loss_fed, evaluate_metrics_fed, _ = res_fed if loss_fed: self.history.add_loss_distributed(server_round=self.current_round, loss=loss_fed) self.history.add_metrics_distributed(server_round=self.current_round, metrics=evaluate_metrics_fed) self.current_round += 1 # Save checkpoint after training and testing self._save_server_state() # Bookkeeping end_time = datetime.datetime.now() elapsed_time = end_time - start_time log(INFO, "FL finished in %s", str(elapsed_time)) return self.history, elapsed_time.total_seconds()
[docs] def fit(self, num_rounds: int, timeout: float | None) -> tuple[History, float]: """ Run federated learning for a number of rounds. This function also allows the server to perform some operations prior to fitting starting. This is useful, for example, if you need to communicate with the clients to initialize anything prior to FL starting (see nnunet server for an example) Args: num_rounds (int): Number of server rounds to run. timeout (float | None): The amount of time in seconds that the server will wait for results from the clients selected to participate in federated training. Returns: tuple[History, float]: The first element of the tuple is a history object containing the full set of FL training results, including things like aggregated loss and metrics. Tuple also contains the elapsed time in seconds for the round. """ start_time = datetime.datetime.now() self.reports_manager.report( { "fit_start": str(start_time), "host_type": "server", } ) self.update_before_fit(num_rounds, timeout) if self.checkpoint_and_state_module.state_checkpointer is not None: history, elapsed_time = self.fit_with_per_round_checkpointing(num_rounds, timeout) else: history, elapsed_time = super().fit(num_rounds, timeout) end_time = datetime.datetime.now() self.reports_manager.report( { "fit_elapsed_time": round((end_time - start_time).total_seconds()), "fit_end": str(end_time), "num_rounds": num_rounds, "host_type": "server", } ) # WARNING: This will not work with wandb. Wandb reporting must be done live. self.report_centralized_eval(history, num_rounds) return history, elapsed_time
[docs] def fit_round( self, server_round: int, timeout: float | None, ) -> tuple[Parameters | None, dict[str, Scalar], FitResultsAndFailures] | None: """ This function is called at each round of federated training. The flow is generally the same as a flower server, where clients are sampled and client side training is requested from the clients that are chosen. This function simply adds a bit of logging, post processing of the results Args: server_round (int): Current round number of the FL training. Begins at 1 timeout (float | None): Time that the server should wait (in seconds) for responses from the clients. Defaults to None, which indicates indefinite timeout. Returns: tuple[Parameters | None, dict[str, Scalar], FitResultsAndFailures] | None: The results of training on the client sit. The first set of parameters are the AGGREGATED parameters from the strategy. The second is a dictionary of AGGREGATED metrics. The third component holds the individual (non-aggregated) parameters, loss, and metrics for successful and unsuccessful client-side training. """ round_start = datetime.datetime.now() fit_round_results = super().fit_round(server_round, timeout) round_end = datetime.datetime.now() self.reports_manager.report( { "fit_round_start": str(round_start), "fit_round_end": str(round_end), "fit_round_time_elapsed": round((round_end - round_start).total_seconds()), }, server_round, ) if fit_round_results is not None: _, metrics, fit_results_and_failures = fit_round_results self.reports_manager.report({"fit_round_metrics": metrics}, server_round) failures = fit_results_and_failures[1] if fit_results_and_failures else None if failures and not self.accept_failures: self._log_client_failures(failures) self._terminate_after_unacceptable_failures(timeout) return fit_round_results
[docs] def shutdown(self) -> None: """ Currently just records termination of the server process and disconnects and reporters that need to be. """ self.reports_manager.report({"shutdown": str(datetime.datetime.now())}) self.reports_manager.shutdown()
[docs] def poll_clients_for_sample_counts(self, timeout: float | None) -> list[int]: """ Poll clients for sample counts from their training set, if you want to use this functionality your strategy needs to inherit from the StrategyWithPolling ABC and implement a configure_poll function. Args: timeout (float | None): Timeout for how long the server will wait for clients to report counts. If none then the server waits indefinitely. Returns: list[int]: The number of training samples held by each client in the pool of available clients. """ # Poll clients for sample counts, if you want to use this functionality your strategy needs to inherit from # the StrategyWithPolling ABC and implement a configure_poll function log(INFO, "Polling Clients for sample counts") assert isinstance(self.strategy, StrategyWithPolling) client_instructions = self.strategy.configure_poll(server_round=1, client_manager=self._client_manager) results, _ = poll_clients( client_instructions=client_instructions, max_workers=self.max_workers, timeout=timeout, ) sample_counts: list[int] = [ int(get_properties_res.properties["num_train_samples"]) for (_, get_properties_res) in results ] log(INFO, f"Polling complete: Retrieved {len(sample_counts)} sample counts") return sample_counts
[docs] def evaluate_round( self, server_round: int, timeout: float | None, ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: # By default the checkpointing works off of the aggregated evaluation loss from each of the clients # NOTE: parameter aggregation occurs **before** evaluation, so the parameters held by the server have been # updated prior to this function being called. start_time = datetime.datetime.now() eval_round_results = self._evaluate_round(server_round, timeout) end_time = datetime.datetime.now() if eval_round_results: loss_aggregated, metrics_aggregated, (_, failures) = eval_round_results if failures and not self.accept_failures: self._log_client_failures(failures) self._terminate_after_unacceptable_failures(timeout) if loss_aggregated: self._maybe_checkpoint(loss_aggregated, metrics_aggregated, server_round) # Report evaluation results report_data = { "val - loss - aggregated": loss_aggregated, "round": server_round, "eval_round_start": str(start_time), "eval_round_end": str(end_time), "eval_round_time_elapsed": round((end_time - start_time).total_seconds()), } if self.fl_config.get("local_epochs", None) is not None: report_data["fit_epoch"] = server_round * self.fl_config["local_epochs"] elif self.fl_config.get("local_steps", None) is not None: report_data["fit_step"] = server_round * self.fl_config["local_steps"] self.reports_manager.report(report_data, server_round) if len(metrics_aggregated) > 0: self.reports_manager.report( {"eval_round_metrics_aggregated": metrics_aggregated}, server_round, ) return eval_round_results
def _log_fl_config(self) -> None: log(INFO, "FL Configuration:") if self.fl_config else log(INFO, "FL Config is Empty") for config_key, config_value in self.fl_config.items(): if not isinstance(config_value, bytes): log(INFO, f"Key: {config_key} Value: {config_value!r}") def _save_server_state(self) -> None: """ Save server checkpoint consisting of model, history, server round, metrics reporter and server name. This method can be overridden to add any necessary state to the checkpoint. The model will be injected into the ckpt state by the checkpoint module """ other_state_to_save = { "history": self.history, "current_round": self.current_round, "reports_manager": self.reports_manager, "server_name": self.server_name, } self.checkpoint_and_state_module.save_state( state_checkpoint_name=self.state_checkpoint_name, server_parameters=self.parameters, other_state=other_state_to_save, ) def _load_server_state(self) -> bool: """ Load server checkpoint consisting of model, history, server name, current round and metrics reporter. The method can be overridden to add any necessary state when loading the checkpoint. """ # Attempt to load the server state if it exists. This variable will be None if it does not. 
server_state = self.checkpoint_and_state_module.maybe_load_state(self.state_checkpoint_name) if server_state is None: return False narrow_dict_type_and_set_attribute(self, server_state, "server_name", "server_name", str) narrow_dict_type_and_set_attribute(self, server_state, "current_round", "current_round", int) narrow_dict_type_and_set_attribute(self, server_state, "reports_manager", "reports_manager", ReportsManager) narrow_dict_type_and_set_attribute(self, server_state, "history", "history", History) narrow_dict_type_and_set_attribute( self, server_state, "model", "parameters", nn.Module, func=get_all_model_parameters ) # Needed for when _hydrate_model_for_checkpointing is called narrow_dict_type_and_set_attribute(self, server_state, "model", "server_model", nn.Module) self.parameters = get_all_model_parameters(server_state["model"]) return True def _terminate_after_unacceptable_failures(self, timeout: float | None) -> None: assert not self.accept_failures # First we shutdown all clients involved in the FL training/evaluation if they can be. self.disconnect_all_clients(timeout=timeout) # Throw an exception alerting the user to failures on the client-side causing termination self.shutdown() raise ValueError( f"The server encountered failures from the clients and accept_failures is set to {self.accept_failures}" ) def _log_client_failures(self, failures: FitFailures | EvaluateFailures) -> None: log( ERROR, f"There were {len(failures)} failures in the fitting process. This will result in termination of " "the FL process", ) for failure in failures: if isinstance(failure, BaseException): log( ERROR, "An exception was returned instead of any failed results. As such the client ID is unknown. " "Please check the client logs to determine which failed.\n" f"The exception thrown was {repr(failure)}", ) else: client_proxy, _ = failure log( ERROR, f"Client {client_proxy.cid} failed but did not return an exception. 
Partial results were received", ) def _maybe_checkpoint( self, loss_aggregated: float, metrics_aggregated: dict[str, Scalar], server_round: int, ) -> None: """ This function simply runs the maybe_checkpoint functionality of the checkpoint_and_state_module. If additional functionality is desired, this function may be overridden. Args: loss_aggregated (float): aggregated loss value that can be used to determine whether to checkpoint metrics_aggregated (dict[str, Scalar]): aggregated metrics from each of the clients for checkpointing server_round (int): What round of federated training we're on. This is just for logging purposes. """ self.checkpoint_and_state_module.maybe_checkpoint(self.parameters, loss_aggregated, metrics_aggregated) def _get_initial_parameters(self, server_round: int, timeout: float | None) -> Parameters: """ Get initial parameters from one of the available clients. This function is the same as the parent function in the flower server class except that we make use of the on_parameter_initialization_config_fn to provide a non-empty config to a client when requesting parameters from which to initialize all other clients. NOTE: The default behavior of flower servers is to simply send over a blank config, but this is insufficient for certain uses, where the client requires additional information from the server. This is needed, for example in nnUnet-based Servers. 
An issue has been logged with flower: https://github.com/adap/flower/issues/3770 """ # Server-side parameter initialization parameters: Parameters | None = self.strategy.initialize_parameters(client_manager=self._client_manager) if parameters is not None: log(INFO, "Using initial global parameters provided by strategy") return parameters # Get initial parameters from one of the clients log(INFO, "Requesting initial parameters from one random client") random_client = self._client_manager.sample(1)[0] if self.on_init_parameters_config_fn is None: # An empty configuration is the default for Flower servers ins = GetParametersIns(config={}) else: ins = GetParametersIns(config=self.on_init_parameters_config_fn(server_round)) get_parameters_res = random_client.get_parameters(ins=ins, timeout=timeout, group_id=server_round) if get_parameters_res.status.code == Code.OK: log(INFO, "Received initial parameters from one random client") else: log( WARNING, "Failed to receive initial parameters from the client. 
Empty initial parameters will be used.", ) return get_parameters_res.parameters def _unpack_metrics( self, results: list[tuple[ClientProxy, EvaluateRes]] ) -> tuple[list[tuple[ClientProxy, EvaluateRes]], list[tuple[ClientProxy, EvaluateRes]]]: val_results = [] test_results = [] for client_proxy, eval_res in results: val_metrics = { k: v for k, v in eval_res.metrics.items() if not k.startswith(MetricPrefix.TEST_PREFIX.value) } test_metrics = {k: v for k, v in eval_res.metrics.items() if k.startswith(MetricPrefix.TEST_PREFIX.value)} if len(test_metrics) > 0: assert TEST_LOSS_KEY in test_metrics and TEST_NUM_EXAMPLES_KEY in test_metrics, ( f"'{TEST_NUM_EXAMPLES_KEY}' and '{TEST_LOSS_KEY}' keys must be present in " "test_metrics dictionary for aggregation" ) # Remove loss and num_examples from test_metrics if they exist test_loss = float(test_metrics.pop(TEST_LOSS_KEY)) test_num_examples = int(test_metrics.pop(TEST_NUM_EXAMPLES_KEY)) test_eval_res = EvaluateRes(eval_res.status, test_loss, test_num_examples, test_metrics) test_results.append((client_proxy, test_eval_res)) val_eval_res = EvaluateRes(eval_res.status, eval_res.loss, eval_res.num_examples, val_metrics) val_results.append((client_proxy, val_eval_res)) return val_results, test_results def _handle_result_aggregation( self, server_round: int, results: list[tuple[ClientProxy, EvaluateRes]], failures: list[tuple[ClientProxy, EvaluateRes] | BaseException], ) -> tuple[float | None, dict[str, Scalar]]: val_results, test_results = self._unpack_metrics(results) # Aggregate the validation results val_aggregated_result: tuple[ float | None, dict[str, Scalar], ] = self.strategy.aggregate_evaluate(server_round, val_results, failures) val_loss_aggregated, val_metrics_aggregated = val_aggregated_result # Aggregate the test results if they are present if len(test_results) > 0: test_aggregated_result: tuple[ float | None, dict[str, Scalar], ] = self.strategy.aggregate_evaluate(server_round, test_results, failures) 
test_loss_aggregated, test_metrics_aggregated = test_aggregated_result for key, value in test_metrics_aggregated.items(): val_metrics_aggregated[key] = value if test_loss_aggregated is not None: val_metrics_aggregated[f"{MetricPrefix.TEST_PREFIX.value} loss - aggregated"] = test_loss_aggregated return val_loss_aggregated, val_metrics_aggregated def _evaluate_round( self, server_round: int, timeout: float | None, ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: """Validate current global model on a number of clients.""" # Get clients and their respective instructions from strategy client_instructions = self.strategy.configure_evaluate( server_round=server_round, parameters=self.parameters, client_manager=self._client_manager, ) if not client_instructions: log(INFO, "evaluate_round %s: no clients selected, cancel", server_round) return None log( DEBUG, "evaluate_round %s: strategy sampled %s clients (out of %s)", server_round, len(client_instructions), self._client_manager.num_available(), ) # Collect `evaluate` results from all clients participating in this round # flwr sets group_id to server_round by default, so we follow that convention results, failures = evaluate_clients( client_instructions, max_workers=self.max_workers, timeout=timeout, group_id=server_round, ) log( DEBUG, "evaluate_round %s received %s results and %s failures for Validation", server_round, len(results), len(failures), ) val_loss_aggregated, val_metrics_aggregated = self._handle_result_aggregation(server_round, results, failures) return val_loss_aggregated, val_metrics_aggregated, (results, failures)