class PyTorchFLTask(BaseFLTask):
    """Federated-learning task built from a PyTorch trainer/tester pair.

    Stores a ``trainer`` and a ``tester`` callable together with their
    signature specs, and uses them to build a ``PyTorchFlowerServer``
    (see ``server``) or a ``PyTorchFlowerClient`` (see ``client``).
    """

    # Declared private attribute; no method of this class assigns it.
    _client: NumPyClient = PrivateAttr()
    _trainer: Callable = PrivateAttr()
    _trainer_spec: TrainerSignatureSpec = PrivateAttr()
    _tester: Callable = PrivateAttr()
    _tester_spec: TesterSignatureSpec = PrivateAttr()

    def __init__(
        self,
        trainer: Callable,
        trainer_spec: TrainerSignatureSpec,
        tester: Callable,
        tester_spec: TesterSignatureSpec,
        **kwargs: Any,
    ) -> None:
        """Create the task from a trainer, a tester and their specs.

        Args:
            trainer: Training callable stored on the task.
            trainer_spec: Signature spec of ``trainer``; its parameter names
                are used to look up the model and data loaders in ``client``
                and ``server``.
            tester: Evaluation callable stored on the task.
            tester_spec: Signature spec of ``tester``.
            **kwargs: Forwarded unchanged to ``BaseFLTask.__init__``.

        Warns:
            UnequalNetParamWarning: The trainer and tester name their model
                parameter differently; the trainer's name is the one used.
        """
        if trainer_spec.net_parameter != tester_spec.net_parameter:
            msg = (
                "`trainer`'s model parameter name is not the same as that for `tester`. "
                "Will use the name supplied in `trainer`."
            )
            # stacklevel=2 attributes the warning to whoever constructed the
            # task instead of to this __init__ line.
            warnings.warn(msg, UnequalNetParamWarning, stacklevel=2)

        super().__init__(**kwargs)
        self._trainer = trainer
        self._trainer_spec = trainer_spec
        self._tester = tester
        self._tester_spec = tester_spec

    @property
    def training_loop(self) -> Callable:
        """The trainer callable supplied at construction."""
        return self._trainer

    @classmethod
    def from_trainer_and_tester(cls, trainer: Callable, tester: Callable) -> Self:
        """Build a task from callables that carry their own signature specs.

        The specs are read from the ``__fl_task_trainer_config`` attribute of
        ``trainer`` and the ``__fl_task_tester_config`` attribute of ``tester``.

        Raises:
            MissingTrainerSpec: ``trainer`` has no ``__fl_task_trainer_config``.
            MissingTesterSpec: ``tester`` has no ``__fl_task_tester_config``.
        """
        # extract trainer spec
        try:
            trainer_spec: TrainerSignatureSpec = getattr(
                trainer, "__fl_task_trainer_config"
            )
        except AttributeError as err:
            msg = "Cannot extract `TrainerSignatureSpec` from supplied `trainer`."
            raise MissingTrainerSpec(msg) from err

        # extract tester spec
        try:
            tester_spec: TesterSignatureSpec = getattr(
                tester, "__fl_task_tester_config"
            )
        except AttributeError as err:
            msg = "Cannot extract `TesterSignatureSpec` from supplied `tester`."
            raise MissingTesterSpec(msg) from err

        return cls(
            trainer=trainer,
            trainer_spec=trainer_spec,
            tester=tester,
            tester_spec=tester_spec,
        )

    @classmethod
    def from_configs(cls, trainer_cfg: Any, tester_cfg: Any) -> Any:
        """Delegate construction from configs to ``BaseFLTask.from_configs``."""
        return super().from_configs(trainer_cfg, tester_cfg)

    def _pop_net(self, kwargs: dict[str, Any]) -> Any:
        """Remove and return the model from ``kwargs``.

        The model is looked up under the trainer spec's ``net_parameter`` name.
        ``kwargs`` is mutated in place.

        Raises:
            MissingRequiredNetParam: The model key is absent from ``kwargs``.
        """
        net_param = self._trainer_spec.net_parameter
        if net_param not in kwargs:
            msg = f"Please pass in a model using the model param name {net_param}."
            raise MissingRequiredNetParam(msg)
        return kwargs.pop(net_param)

    def server(
        self,
        strategy: Strategy | None = None,
        client_manager: ClientManager | None = None,
        **kwargs: Any,
    ) -> PyTorchFlowerServer | None:
        """Build a ``PyTorchFlowerServer`` for this task.

        Args:
            strategy: Strategy to use. When ``None``, a ``FedAvg`` strategy with
                ``fraction_evaluate=1.0`` is created, seeded with the weights of
                the model supplied in ``kwargs``.
            client_manager: Client manager to use. Defaults to a new
                ``SimpleClientManager``.
            **kwargs: When ``strategy`` is ``None``, must contain the model under
                the trainer spec's ``net_parameter`` name. All other keys are
                ignored.

        Returns:
            The constructed ``PyTorchFlowerServer``.

        Raises:
            MissingRequiredNetParam: ``strategy`` is ``None`` and no model was
                supplied.
        """
        if strategy is None:
            model = self._pop_net(kwargs)
            parameters = ndarrays_to_parameters(_get_weights(model))
            strategy = FedAvg(
                fraction_evaluate=1.0,
                initial_parameters=parameters,
            )

        if client_manager is None:
            client_manager = SimpleClientManager()

        return PyTorchFlowerServer(client_manager=client_manager, strategy=strategy)

    def client(self, **kwargs: Any) -> Client | None:
        """Build a ``PyTorchFlowerClient`` wrapping this task's trainer/tester.

        The model, training data and validation data are popped from
        ``kwargs``. Their keys are taken from the trainer spec's
        ``net_parameter``, ``train_data_param`` and ``val_data_param``. Any
        keys that remain are passed to the bundle as ``extra_train_kwargs``.

        Returns:
            The constructed ``PyTorchFlowerClient``.

        Raises:
            MissingRequiredNetParam: No model was supplied.
            KeyError: The training or validation data argument is missing.
        """
        # validate kwargs and build bundle
        net = self._pop_net(kwargs)
        trainloader = kwargs.pop(self._trainer_spec.train_data_param)
        valloader = kwargs.pop(self._trainer_spec.val_data_param)

        bundle = BaseFLTaskBundle(
            net=net,
            trainloader=trainloader,
            valloader=valloader,
            extra_train_kwargs=kwargs,
            extra_test_kwargs={},  # TODO make this functional or get rid of it
            trainer=self._trainer,
            tester=self._tester,
        )
        return PyTorchFlowerClient(task_bundle=bundle)

    def simulate(self, num_clients: int, **kwargs: Any) -> Any:
        """Not implemented for this task type.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError