def __init__(
    self,
    *,
    client_manager: ClientManager,
    strategy: Strategy | None = None,
    server_model: nn.Module | None = None,
    checkpointer: LatestTorchModuleCheckpointer | None = None,
    parameter_exchanger: ParameterExchanger | None = None,
    reporters: Sequence[BaseReporter] | None = None,
    server_name: str | None = None,
) -> None:
    """
    ``ModelMergeServer`` provides functionality to fetch client weights, perform a simple average, and
    redistribute the merged weights to the clients for evaluation. Optionally, it can perform server-side
    evaluation as well.

    Args:
        client_manager (ClientManager): Determines the mechanism by which clients are sampled by the
            server, if they are to be sampled at all.
        strategy (Strategy | None, optional): The aggregation strategy to be used by the server to handle
            client updates sent by the participating clients. Must be a ``ModelMergeStrategy``.
        checkpointer (LatestTorchModuleCheckpointer | None, optional): To be provided if the server should
            perform server-side checkpointing on the merged model. If None, then no server-side
            checkpointing is performed. Defaults to None.
        server_model (nn.Module | None): Optional model to be hydrated with parameters from the model
            merge if doing server-side checkpointing. Must only be provided if a checkpointer is also
            provided. Defaults to None.
        parameter_exchanger (ParameterExchanger | None, optional): A parameter exchanger used to
            facilitate server-side model checkpointing if a checkpointer has been defined. If not
            provided, then checkpointing will not be done unless the ``_hydrate_model_for_checkpointing``
            function is overridden. Because the server only sees numpy arrays, the parameter exchanger is
            used to insert the numpy arrays into a provided model. Defaults to None.
        reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters to which
            the server should send data before and after each round. Defaults to None.
        server_name (str | None): An optional string name to uniquely identify the server. If not
            provided, a hash is generated. Defaults to None.
    """
    assert isinstance(strategy, ModelMergeStrategy)
    # Either all of server_model, checkpointer, and parameter_exchanger are provided, or none of them are.
    assert (server_model is None and checkpointer is None and parameter_exchanger is None) or (
        server_model is not None and checkpointer is not None and parameter_exchanger is not None
    )
    super().__init__(client_manager=client_manager, strategy=strategy)
    self.checkpointer = checkpointer
    self.server_model = server_model
    self.parameter_exchanger = parameter_exchanger
    self.server_name = server_name if server_name is not None else generate_hash()
    # Initialize reporters with server name information.
    self.reports_manager = ReportsManager(reporters)
    self.reports_manager.initialize(id=self.server_name)
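# --- Usage sketch (illustrative; not part of the class) ---
# A minimal sketch of constructing a ModelMergeServer with server-side
# checkpointing enabled. The FL4Health import paths, the FullParameterExchanger
# choice, the constructor arguments of LatestTorchModuleCheckpointer and
# ModelMergeStrategy, and TwoLayerNet are all assumptions for illustration,
# not taken from this module.
import torch
import torch.nn as nn
from flwr.server.client_manager import SimpleClientManager

from fl4health.checkpointing.checkpointer import LatestTorchModuleCheckpointer  # assumed path
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger  # assumed path
from fl4health.servers.model_merge_server import ModelMergeServer  # assumed path
from fl4health.strategies.model_merge_strategy import ModelMergeStrategy  # assumed path


class TwoLayerNet(nn.Module):  # hypothetical architecture matching the clients' models
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


# server_model, checkpointer, and parameter_exchanger must all be provided
# together (or all omitted), per the assertion in __init__ above.
server = ModelMergeServer(
    client_manager=SimpleClientManager(),
    strategy=ModelMergeStrategy(),  # assumed no-argument construction; configure as needed
    server_model=TwoLayerNet(),
    checkpointer=LatestTorchModuleCheckpointer("/tmp/checkpoints", "merged_model.pkl"),  # assumed signature
    parameter_exchanger=FullParameterExchanger(),
)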
def fit(self, num_rounds: int, timeout: float | None) -> tuple[History, float]:
    """
    Performs a fit round in which the local client weights are evaluated on their test set, uploaded to
    the server, and averaged, then redistributed to the clients for evaluation. Optionally, evaluation of
    the merged model can be performed on the server side as well.

    Args:
        num_rounds (int): Not used.
        timeout (float | None): Timeout in seconds that the server should wait for the clients to
            respond. If None, then it will wait indefinitely for the minimum number of clients to
            respond.

    Returns:
        tuple[History, float]: The first element of the tuple is a ``History`` object containing the
        aggregated metrics returned from the clients. The second element is the elapsed time for the
        round in seconds.
    """
    self.reports_manager.report({"host_type": "server", "fit_start": datetime.datetime.now()})
    history = History()

    # Run federated model merging.
    log(INFO, "Federated Model Merging Starting")
    start_time = timeit.default_timer()

    res_fit = self.fit_round(server_round=1, timeout=timeout)
    if res_fit is not None:
        parameters_prime, fit_metrics, _ = res_fit  # fit_metrics_aggregated
        if parameters_prime:
            self.parameters = parameters_prime
        history.add_metrics_distributed_fit(server_round=1, metrics=fit_metrics)
    else:
        log(WARNING, "Federated Model Merging Failed")

    res_fed = self.evaluate_round(server_round=1, timeout=timeout)
    if res_fed is not None:
        # Ignore the loss, as one is not defined in model merging.
        _, evaluate_metrics_fed, _ = res_fed
        if evaluate_metrics_fed is not None:
            history.add_metrics_distributed(server_round=1, metrics=evaluate_metrics_fed)

    # Evaluate the merged model using the strategy implementation.
    res_cen = self.strategy.evaluate(1, parameters=self.parameters)
    if res_cen is not None:
        # Ignore the loss, as one is not defined in model merging.
        _, metrics_cen = res_cen
        history.add_metrics_centralized(server_round=1, metrics=metrics_cen)

    # Checkpoint based on a dummy aggregated loss and metrics, since we are using a
    # LatestTorchModuleCheckpointer, which will always checkpoint if server_model,
    # parameter_exchanger, and checkpointer are all not None.
    self._maybe_checkpoint(loss_aggregated=0.0, metrics_aggregated={}, server_round=1)

    self.reports_manager.report(
        data={
            "fit_end": datetime.datetime.now(),
            "metrics_centralized": history.metrics_centralized,
            "losses_centralized": history.losses_centralized,
            "host_type": "server",
        }
    )

    # Bookkeeping
    end_time = timeit.default_timer()
    elapsed = end_time - start_time
    log(INFO, "Federated Model Merging Finished in %s", elapsed)
    return history, elapsed
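# --- Usage sketch (illustrative; continues the construction sketch above) ---
# Driving a single model-merge round directly. num_rounds is ignored by this
# fit implementation (merging is a single round), and the 30-second timeout is
# an arbitrary choice for illustration.
history, elapsed = server.fit(num_rounds=1, timeout=30.0)
print(f"Client fit metrics: {history.metrics_distributed_fit}")
print(f"Client evaluation metrics: {history.metrics_distributed}")
print(f"Model merging round completed in {elapsed:.2f}s")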
def _hydrate_model_for_checkpointing(self) -> nn.Module:
    """
    Method used for converting server parameters into a torch model that can be checkpointed.

    Returns:
        nn.Module: Torch model to be checkpointed by a torch checkpointer.
    """
    assert self.server_model is not None, (
        "Model hydration has been called but no server_model is defined to hydrate. The functionality of "
        "_hydrate_model_for_checkpointing can be overridden if checkpointing without a torch architecture "
        "is possible and desired."
    )
    assert self.parameter_exchanger is not None, (
        "Model hydration has been called but no parameter_exchanger is defined to hydrate with. The "
        "functionality of _hydrate_model_for_checkpointing can be overridden if checkpointing without a "
        "parameter exchanger is possible and desired."
    )
    model_ndarrays = parameters_to_ndarrays(self.parameters)
    self.parameter_exchanger.pull_parameters(model_ndarrays, self.server_model)
    return self.server_model

def _maybe_checkpoint(
    self, loss_aggregated: float, metrics_aggregated: dict[str, Scalar], server_round: int
) -> None:
    """
    Method to checkpoint the merged model on the server side if the ``checkpointer``, ``server_model``,
    and ``parameter_exchanger`` provided at initialization are all not None.

    Args:
        loss_aggregated (float): Not used.
        metrics_aggregated (dict[str, Scalar]): Not used.
        server_round (int): Not used.
    """
    if self.checkpointer and self.server_model and self.parameter_exchanger:
        model = self._hydrate_model_for_checkpointing()
        self.checkpointer.maybe_checkpoint(model, loss_aggregated, metrics_aggregated)
    else:
        attribute_dict = {
            "checkpointer": self.checkpointer,
            "server_model": self.server_model,
            "parameter_exchanger": self.parameter_exchanger,
        }
        error_str = " and ".join([key for key, val in attribute_dict.items() if val is None])
        log(
            WARNING,
            "All of checkpointer, server_model, and parameter_exchanger must not be None to perform "
            f"server-side checkpointing. {error_str} is None.",
        )
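# --- Override sketch (illustrative; not part of the class) ---
# The assertions in _hydrate_model_for_checkpointing note that it can be
# overridden when checkpointing is possible without a stored server_model or
# parameter exchanger. A minimal sketch, assuming the server's ndarrays line
# up one-to-one with the entries of a freshly constructed model's state_dict
# (as a full parameter exchange would produce); build_architecture is a
# hypothetical helper that returns the target nn.Module.
import torch
import torch.nn as nn
from flwr.common import parameters_to_ndarrays


class CustomModelMergeServer(ModelMergeServer):
    def _hydrate_model_for_checkpointing(self) -> nn.Module:
        model = build_architecture()  # hypothetical: constructs the target architecture
        model_ndarrays = parameters_to_ndarrays(self.parameters)
        # Rebuild the state_dict positionally from the merged ndarrays.
        state_dict = {
            name: torch.from_numpy(arr)
            for (name, _), arr in zip(model.state_dict().items(), model_ndarrays)
        }
        # strict=True surfaces any mismatch between the ndarrays and the architecture.
        model.load_state_dict(state_dict, strict=True)
        return model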