Source code for florist.api.routes.server.training
"""FastAPI routes for training."""importasyncioimportloggingfromjsonimportJSONDecodeErrorfromthreadingimportThreadfromtypingimportAny,ListimportrequestsfromfastapiimportAPIRouter,Depends,Requestfromfastapi.responsesimportJSONResponsefrommotor.motor_asyncioimportAsyncIOMotorClientfromflorist.api.clients.clientsimportClientfromflorist.api.clients.optimizersimportOptimizerfromflorist.api.db.configimportDatabaseConfigfromflorist.api.db.server_entitiesimportClientInfo,Job,JobStatusfromflorist.api.launchers.localimportlaunch_local_serverfromflorist.api.models.modelsimportModelfromflorist.api.monitoring.metricsimportget_from_redis,get_subscriber,wait_for_metricfromflorist.api.routes.server.authimportcheck_default_user_token,get_client_tokenfromflorist.api.servers.config_parsersimportConfigParserrouter=APIRouter()LOGGER=logging.getLogger("uvicorn.error")START_CLIENT_API="api/client/start"CHECK_CLIENT_STATUS_API="api/client/check_status"
@router.post("/start", dependencies=[Depends(check_default_user_token)])
async def start(job_id: str, request: Request) -> JSONResponse:
    """
    Start FL training for a job id by starting a FL server and its clients.

    The job is moved to IN_PROGRESS as soon as it is found in NOT_STARTED state. Any failure
    after that point moves it to FINISHED_WITH_ERROR and stores the error message on the job.

    :param job_id: (str) The id of the Job record in the DB which contains the information
        necessary to start training.
    :param request: (fastapi.Request) the FastAPI request object.
    :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the server and
        the clients in the format below. The UUIDs can be used to pull metrics from Redis.
            {
                "server_uuid": <server uuid>,
                "client_uuids": [<client_uuid_1>, <client_uuid_2>, ..., <client_uuid_n>],
            }
        If not successful, returns the appropriate error code with a JSON with the format below:
            {"error": <error message>}
        Validation failures are reported with status 400, any other failure with status 500.
    """
    job = None

    try:
        job = await Job.find_by_id(job_id, request.app.database)

        # Validation errors are raised explicitly as AssertionError (instead of using `assert`
        # statements) so the checks are still enforced when Python runs with -O, while the
        # `except AssertionError` handler below keeps mapping them to a 400 response.
        if job is None:
            raise AssertionError(f"Job with id {job_id} not found.")

        if job.status != JobStatus.NOT_STARTED:
            raise AssertionError(f"Job status ({job.status.value}) is not NOT_STARTED")
        await job.set_status(JobStatus.IN_PROGRESS, request.app.database)

        if job.model is None:
            raise AssertionError("Missing Job information: model")
        if job.strategy is None:
            raise AssertionError("Missing Job information: strategy")
        if job.optimizer is None:
            raise AssertionError("Missing Job information: optimizer")
        if job.server_config is None:
            raise AssertionError("Missing Job information: server_config")
        if job.client is None:
            raise AssertionError("Missing Job information: client")
        if job.client.value not in Client.list_by_strategy(job.strategy):
            raise AssertionError(f"Client {job.client} not valid for strategy {job.strategy}.")
        if job.clients_info is None or len(job.clients_info) == 0:
            raise AssertionError("Missing Job information: clients_info")
        if job.server_address is None:
            raise AssertionError("Missing Job information: server_address")
        if job.redis_address is None:
            raise AssertionError("Missing Job information: redis_address")

        model_class = job.model.get_model_class()
        config_parser = job.strategy.get_config_parser()
        server_factory = job.strategy.get_server_factory()

        try:
            config_parser_class = ConfigParser.class_for_parser(config_parser)
            server_config = config_parser_class.parse(job.server_config)
        except JSONDecodeError as err:
            # Reported as a validation error (400) rather than an internal error (500)
            raise AssertionError("server_config is not a valid json string.") from err

        # Start the server
        server_uuid, server_process, server_log_file_path = launch_local_server(
            model=model_class(),
            server_config=server_config,
            server_factory=server_factory,
            server_address=job.server_address,
            n_clients=len(job.clients_info),
            redis_address=job.redis_address,
        )

        await job.set_server_log_file_path(server_log_file_path, request.app.database)

        # The clients are only started after the server has reported the `fit_start` metric
        wait_for_metric(server_uuid, "fit_start", job.redis_address, logger=LOGGER)

        # Start the clients
        client_uuids: List[str] = [
            _start_client(job.server_address, job.client, job.model, job.optimizer, client_info, request)
            for client_info in job.clients_info
        ]

        await job.set_uuids(server_uuid, client_uuids, request.app.database)
        await job.set_server_pid(str(server_process.pid), request.app.database)

        # Start the server training listener and client training listeners as threads to update
        # the job's metrics and status once the training is done
        server_listener_thread = Thread(target=asyncio.run, args=(server_training_listener(job),))
        server_listener_thread.daemon = True
        server_listener_thread.start()
        for client_info in job.clients_info:
            client_listener_thread = Thread(target=asyncio.run, args=(client_training_listener(job, client_info),))
            client_listener_thread.daemon = True
            client_listener_thread.start()

        # Return the UUIDs
        return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids})

    except AssertionError as err:
        if job is not None:
            await job.set_status(JobStatus.FINISHED_WITH_ERROR, request.app.database)
            await job.set_error_message(str(err), request.app.database)
        return JSONResponse(content={"error": str(err)}, status_code=400)

    except Exception as ex:
        LOGGER.exception(ex)
        if job is not None:
            await job.set_status(JobStatus.FINISHED_WITH_ERROR, request.app.database)
            await job.set_error_message(str(ex), request.app.database)
        return JSONResponse({"error": str(ex)}, status_code=500)
async def _refresh_client_metrics(job: Job, client_uuid: str, redis_address: str, database: Any) -> bool:
    """
    Fetch the current metrics of a client from Redis and save them to the job in the database.

    :param job: (Job) The job that has this client's metrics.
    :param client_uuid: (str) The UUID of the client whose metrics will be fetched.
    :param redis_address: (str) The address of the Redis instance that holds the client's metrics.
    :param database: The database handle used to persist the metrics.
    :return: (bool) True if the metrics contain `shutdown`, False otherwise (including when
        there are no metrics in Redis yet).
    """
    client_metrics = get_from_redis(client_uuid, redis_address)
    LOGGER.debug(f"Client listener: Current metrics for client {client_uuid}: {client_metrics}")

    if client_metrics is None:
        return False

    LOGGER.info(f"Client listener: Updating client metrics for client {client_uuid} on job {job.id}")
    await job.set_client_metrics(client_uuid, client_metrics, database)
    LOGGER.info(f"Client listener: Client metrics for client {client_uuid} on {job.id} have been updated.")

    return "shutdown" in client_metrics


async def client_training_listener(job: Job, client_info: ClientInfo) -> None:
    """
    Listen to the Redis' channel that reports updates on the training process of a FL client.

    Keeps consuming updates to the channel until it finds `shutdown` in the client metrics.
    The database client opened by this function is always closed before it returns, including
    when an error is raised while listening or while saving the metrics.

    :param job: (Job) The job that has this client's metrics.
    :param client_info: (ClientInfo) The ClientInfo with the client_uuid to listen to.
    """
    LOGGER.info(f"Starting listener for client messages from job {job.id} at channel {client_info.uuid}")

    assert client_info.uuid is not None, "client_info.uuid is None."

    db_client: AsyncIOMotorClient[Any] = AsyncIOMotorClient(DatabaseConfig.mongodb_uri)
    try:
        database = db_client[DatabaseConfig.mongodb_db_name]

        # check if training has already finished before start listening
        if await _refresh_client_metrics(job, client_info.uuid, client_info.redis_address, database):
            return

        subscriber = get_subscriber(client_info.uuid, client_info.redis_address)
        # TODO add a max retries mechanism, maybe?
        for message in subscriber.listen():  # type: ignore[no-untyped-call]
            if message["type"] != "message":
                continue
            # The contents of the message do not matter, we just use it to get notified
            if await _refresh_client_metrics(job, client_info.uuid, client_info.redis_address, database):
                return
    finally:
        # Release the connection on every exit path, not only on the successful ones
        db_client.close()
async def _record_server_metrics(job: Job, server_metrics: Any, database: Any) -> bool:
    """
    Save the server metrics to the job and close the job if the training has finished.

    :param job: (Job) The job the server metrics belong to.
    :param server_metrics: The metrics fetched from Redis, or None if there are none yet.
    :param database: The database handle used to persist the metrics and the job status.
    :return: (bool) True if `fit_end` was found in the metrics and the job status has been set to
        FINISHED_SUCCESSFULLY, False otherwise.
    """
    if server_metrics is None:
        return False

    LOGGER.info(f"Server listener: Updating server metrics for job {job.id}")
    await job.set_server_metrics(server_metrics, database)
    LOGGER.info(f"Server listener: Server metrics for {job.id} have been updated.")

    if "fit_end" not in server_metrics:
        return False

    LOGGER.info(f"Server listener: Training finished for job {job.id}")
    await job.set_status(JobStatus.FINISHED_SUCCESSFULLY, database)
    LOGGER.info(f"Server listener: Job {job.id} status have been set to {job.status.value}.")
    return True


async def server_training_listener(job: Job) -> None:
    """
    Listen to the Redis' channel that reports updates on the training process of a FL server.

    Keeps consuming updates to the channel until it finds `fit_end` in the server metrics,
    then closes the job with FINISHED_SUCCESSFULLY and saves the server's metrics to the job
    in the database. The database client opened by this function is always closed before it
    returns, including when an error is raised while listening or while saving the metrics.

    :param job: (Job) The job with the server_uuid to listen to.
    """
    LOGGER.info(f"Starting listener for server messages from job {job.id} at channel {job.server_uuid}")

    assert job.server_uuid is not None, "job.server_uuid is None."
    assert job.redis_address is not None, "job.redis_address is None."

    db_client: AsyncIOMotorClient[Any] = AsyncIOMotorClient(DatabaseConfig.mongodb_uri)
    try:
        database = db_client[DatabaseConfig.mongodb_db_name]

        # check if training has already finished before start listening
        server_metrics = get_from_redis(job.server_uuid, job.redis_address)
        LOGGER.debug(f"Server listener: Current metrics for job {job.id}: {server_metrics}")
        if await _record_server_metrics(job, server_metrics, database):
            return

        subscriber = get_subscriber(job.server_uuid, job.redis_address)
        # TODO add a max retries mechanism, maybe?
        for message in subscriber.listen():  # type: ignore[no-untyped-call]
            if message["type"] != "message":
                continue
            # The contents of the message do not matter, we just use it to get notified
            server_metrics = get_from_redis(job.server_uuid, job.redis_address)
            LOGGER.debug(f"Server listener: Message received for job {job.id}. Metrics: {server_metrics}")
            if await _record_server_metrics(job, server_metrics, database):
                return
    finally:
        # Release the connection on every exit path, not only on the successful ones
        db_client.close()
def _start_client(
    server_address: str,
    client: Client,
    model: Model,
    optimizer: Optimizer,
    client_info: ClientInfo,
    request: Request,
) -> str:
    """
    Start a client by calling the start endpoint of its service.

    :param server_address: (str) the address of the server the client needs to report to.
    :param client: (Client) the client to start; its value is sent as the `client` parameter.
    :param model: (Model) the model to train; its value is sent as the `model` parameter.
    :param optimizer: (Optimizer) the optimizer to use; its value is sent as the `optimizer` parameter.
    :param client_info: (ClientInfo) an instance of ClientInfo with the information needed to start
        the client (service address, data path and redis address).
    :param request: (fastapi.Request) the FastAPI request object, used to obtain the client's token.
    :return: (str) the client's UUID.
    :raises Exception: if the client's response is not valid JSON, does not have status 200, or
        does not contain a string under the `uuid` key.
    """
    parameters = {
        "server_address": server_address,
        "client": client.value,
        "model": model.value,
        "optimizer": optimizer.value,
        "data_path": client_info.data_path,
        "redis_address": client_info.redis_address,
    }
    token = get_client_token(client_info, request)

    # NOTE(review): no timeout is passed, so this call can block indefinitely if the client
    # service never answers; consider adding one once an acceptable value is agreed on.
    response = requests.get(
        url=f"http://{client_info.service_address}/{START_CLIENT_API}",
        params=parameters,
        headers={"Authorization": f"Bearer {token.access_token}"},
    )

    try:
        json_response = response.json()
    except ValueError as err:
        # The JSON decoding errors raised by `response.json()` are ValueError subclasses.
        # Surface the status code and raw body instead of an opaque decoding error.
        raise Exception(
            f"Client response returned {response.status_code} with a non-JSON body. Response: {response.text}"
        ) from err

    if response.status_code != 200:
        raise Exception(f"Client response returned {response.status_code}. Response: {json_response}")

    # Guard against valid JSON that is not an object (e.g. `null` or a bare string)
    if not isinstance(json_response, dict) or "uuid" not in json_response:
        raise Exception(f"Client response did not return a UUID. Response: {json_response}")

    client_uuid = json_response["uuid"]
    if not isinstance(client_uuid, str):
        raise Exception(f"Client UUID is not a string: {client_uuid}")

    return client_uuid