Source code for florist.api.routes.server.training

"""FastAPI routes for training."""

import logging
from json import JSONDecodeError
from typing import Any, Dict, List

import requests
from fastapi import APIRouter, BackgroundTasks, Request
from fastapi.responses import JSONResponse
from pymongo.database import Database

from florist.api.db.entities import ClientInfo, Job, JobStatus
from florist.api.monitoring.metrics import get_from_redis, get_subscriber, wait_for_metric
from florist.api.servers.common import Model
from florist.api.servers.config_parsers import ConfigParser
from florist.api.servers.launch import launch_local_server


router = APIRouter()

LOGGER = logging.getLogger("uvicorn.error")

START_CLIENT_API = "api/client/start"
CHECK_CLIENT_STATUS_API = "api/client/check_status"


@router.post("/start")
async def start(job_id: str, request: Request, background_tasks: BackgroundTasks) -> JSONResponse:
    """
    Start FL training for a job id by starting a FL server and its clients.

    :param job_id: (str) The id of the Job record in the DB which contains the information
        necessary to start training.
    :param request: (fastapi.Request) The FastAPI request object.
    :param background_tasks: (BackgroundTasks) A BackgroundTasks instance to launch the training
        listeners, which will update the progress of the training job.
    :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the
        server and the clients in the format below. The UUIDs can be used to pull metrics from
        Redis.

        {
            "server_uuid": <server uuid>,
            "client_uuids": [<client_uuid_1>, <client_uuid_2>, ..., <client_uuid_n>],
        }

        If not successful, returns the appropriate error code with a JSON in the format below:

        {"error": <error message>}
    """
    job = None

    try:
        job = await Job.find_by_id(job_id, request.app.database)
        assert job is not None, f"Job with id {job_id} not found."
        assert job.status == JobStatus.NOT_STARTED, f"Job status ({job.status.value}) is not NOT_STARTED"

        await job.set_status(JobStatus.IN_PROGRESS, request.app.database)

        if job.config_parser is None:
            job.config_parser = ConfigParser.BASIC

        assert job.model is not None, "Missing Job information: model"
        assert job.server_config is not None, "Missing Job information: server_config"
        assert job.clients_info is not None and len(job.clients_info) > 0, "Missing Job information: clients_info"
        assert job.server_address is not None, "Missing Job information: server_address"
        assert job.redis_host is not None, "Missing Job information: redis_host"
        assert job.redis_port is not None, "Missing Job information: redis_port"

        try:
            config_parser = ConfigParser.class_for_parser(job.config_parser)
            server_config = config_parser.parse(job.server_config)
        except JSONDecodeError as err:
            raise AssertionError("server_config is not a valid json string.") from err

        model_class = Model.class_for_model(job.model)

        # Start the server
        server_uuid, _ = launch_local_server(
            model=model_class(),
            n_clients=len(job.clients_info),
            server_address=job.server_address,
            redis_host=job.redis_host,
            redis_port=job.redis_port,
            **server_config,
        )
        wait_for_metric(server_uuid, "fit_start", job.redis_host, job.redis_port, logger=LOGGER)

        # Start the clients
        client_uuids: List[str] = []
        for client_info in job.clients_info:
            parameters = {
                "server_address": job.server_address,
                "client": client_info.client.value,
                "data_path": client_info.data_path,
                "redis_host": client_info.redis_host,
                "redis_port": client_info.redis_port,
            }
            response = requests.get(url=f"http://{client_info.service_address}/{START_CLIENT_API}", params=parameters)
            json_response = response.json()
            LOGGER.debug(f"Client response: {json_response}")

            if response.status_code != 200:
                raise Exception(f"Client response returned {response.status_code}. Response: {json_response}")

            if "uuid" not in json_response:
                raise Exception(f"Client response did not return a UUID. Response: {json_response}")

            client_uuids.append(json_response["uuid"])

        await job.set_uuids(server_uuid, client_uuids, request.app.database)

        # Start the server training listener as a background task to update
        # the job's metrics and status once the training is done
        background_tasks.add_task(server_training_listener, job, request.app.synchronous_database)
        for client_info in job.clients_info:
            background_tasks.add_task(client_training_listener, job, client_info, request.app.synchronous_database)

        # Return the UUIDs
        return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids})

    except AssertionError as err:
        if job is not None:
            await job.set_status(JobStatus.FINISHED_WITH_ERROR, request.app.database)
        return JSONResponse(content={"error": str(err)}, status_code=400)

    except Exception as ex:
        LOGGER.exception(ex)
        if job is not None:
            await job.set_status(JobStatus.FINISHED_WITH_ERROR, request.app.database)
        return JSONResponse({"error": str(ex)}, status_code=500)
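
# --- Usage sketch (not part of the module) ---
# A minimal example of calling the endpoint above, assuming a FLorist server is
# reachable at http://localhost:8000 and that this router is mounted under
# "/api/server/training". Both the mount prefix and the job id are illustrative
# assumptions, not facts taken from this module; `job_id` must reference a Job
# record whose status is NOT_STARTED.
#
#     import requests
#
#     response = requests.post(
#         "http://localhost:8000/api/server/training/start",  # hypothetical mount prefix
#         params={"job_id": "<job id>"},  # placeholder for a real Job id in the DB
#     )
#     response.raise_for_status()
#     uuids = response.json()
#     print(uuids["server_uuid"], uuids["client_uuids"])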

def client_training_listener(job: Job, client_info: ClientInfo, database: Database[Dict[str, Any]]) -> None:
    """
    Listen to the Redis channel that reports updates on the training process of a FL client.

    Keeps consuming updates from the channel until it finds `shutdown` in the client metrics.

    :param job: (Job) The job that has this client's metrics.
    :param client_info: (ClientInfo) The ClientInfo with the client_uuid to listen to.
    :param database: (pymongo.database.Database) An instance of the database to save the
        information into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be
        marked as async because of limitations with FastAPI's BackgroundTasks.
    """
    LOGGER.info(f"Starting listener for client messages from job {job.id} at channel {client_info.uuid}")

    assert client_info.uuid is not None, "client_info.uuid is None."

    # Check if training has already finished before we start listening
    client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port)
    LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}")
    if client_metrics is not None:
        LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}")
        job.set_client_metrics(client_info.uuid, client_metrics, database)
        LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on job {job.id} have been updated.")
        if "shutdown" in client_metrics:
            return

    subscriber = get_subscriber(client_info.uuid, client_info.redis_host, client_info.redis_port)
    # TODO add a max retries mechanism, maybe?
    for message in subscriber.listen():  # type: ignore[no-untyped-call]
        if message["type"] == "message":
            # The contents of the message do not matter; we just use it to get notified
            client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port)
            LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}")
            if client_metrics is not None:
                LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}")
                job.set_client_metrics(client_info.uuid, client_metrics, database)
                LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on job {job.id} have been updated.")
                if "shutdown" in client_metrics:
                    return
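
# --- Illustrative sketch (not part of the module) ---
# `get_subscriber` comes from florist.api.monitoring.metrics and its internals
# are not shown here. As an assumption, a subscriber with the behavior consumed
# above (a `listen()` generator yielding dicts with a "type" key) matches
# redis-py's standard pub/sub API, with each UUID doubling as a channel name:
#
#     import redis
#
#     def example_subscriber(channel: str, host: str, port: int) -> redis.client.PubSub:
#         # Hypothetical stand-in for get_subscriber, not the real implementation.
#         client = redis.Redis(host=host, port=port)
#         pubsub = client.pubsub()
#         pubsub.subscribe(channel)  # listen() will now yield subscribe/message events
#         return pubsub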

def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> None:
    """
    Listen to the Redis channel that reports updates on the training process of a FL server.

    Keeps consuming updates from the channel until it finds `fit_end` in the server metrics,
    then closes the job with status FINISHED_SUCCESSFULLY and saves both the clients' and the
    server's metrics to the job in the database.

    :param job: (Job) The job with the server_uuid to listen to.
    :param database: (pymongo.database.Database) An instance of the database to save the
        information into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be
        marked as async because of limitations with FastAPI's BackgroundTasks.
    """
    LOGGER.info(f"Starting listener for server messages from job {job.id} at channel {job.server_uuid}")

    assert job.server_uuid is not None, "job.server_uuid is None."
    assert job.redis_host is not None, "job.redis_host is None."
    assert job.redis_port is not None, "job.redis_port is None."

    # Check if training has already finished before we start listening
    server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port)
    LOGGER.debug(f"Listener: Current metrics for job {job.id}: {server_metrics}")
    if server_metrics is not None:
        LOGGER.info(f"Listener: Updating server metrics for job {job.id}")
        job.set_server_metrics(server_metrics, database)
        LOGGER.info(f"Listener: Server metrics for job {job.id} have been updated.")
        if "fit_end" in server_metrics:
            LOGGER.info(f"Listener: Training finished for job {job.id}")
            job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database)
            LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.")
            return

    subscriber = get_subscriber(job.server_uuid, job.redis_host, job.redis_port)
    # TODO add a max retries mechanism, maybe?
    for message in subscriber.listen():  # type: ignore[no-untyped-call]
        if message["type"] == "message":
            # The contents of the message do not matter; we just use it to get notified
            server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port)
            LOGGER.debug(f"Listener: Message received for job {job.id}. Metrics: {server_metrics}")
            if server_metrics is not None:
                LOGGER.info(f"Listener: Updating server metrics for job {job.id}")
                job.set_server_metrics(server_metrics, database)
                LOGGER.info(f"Listener: Server metrics for job {job.id} have been updated.")
                if "fit_end" in server_metrics:
                    LOGGER.info(f"Listener: Training finished for job {job.id}")
                    job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database)
                    LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.")
                    return
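
# --- Illustrative sketch (not part of the module) ---
# Both listeners treat the metrics returned by `get_from_redis` as a dict-like
# object (e.g. `"fit_end" in server_metrics`). The real implementation lives in
# florist.api.monitoring.metrics; a hypothetical version, assuming the metrics
# are stored as a JSON string under the UUID key, could look like:
#
#     import json
#     from typing import Any, Dict, Optional
#
#     import redis
#
#     def example_get_from_redis(uuid: str, host: str, port: int) -> Optional[Dict[str, Any]]:
#         # Hypothetical stand-in for get_from_redis, not the real implementation.
#         client = redis.Redis(host=host, port=port)
#         raw = client.get(uuid)
#         return None if raw is None else json.loads(raw)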