Source code for florist.api.routes.server.training

"""FastAPI routes for training."""

import asyncio
import logging
from json import JSONDecodeError
from threading import Thread
from typing import Any, List

import requests
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from motor.motor_asyncio import AsyncIOMotorClient

from florist.api.db.config import DATABASE_NAME, MONGODB_URI
from florist.api.db.entities import ClientInfo, Job, JobStatus
from florist.api.monitoring.metrics import get_from_redis, get_subscriber, wait_for_metric
from florist.api.servers.common import Model
from florist.api.servers.config_parsers import ConfigParser
from florist.api.servers.launch import launch_local_server


# Router on which the training endpoints defined in this module are registered.
router = APIRouter()

# Logger named "uvicorn.error"; presumably chosen so that these messages show up
# alongside uvicorn's own output -- confirm against the deployment's logging config.
LOGGER = logging.getLogger("uvicorn.error")

# Path appended to a client's service address to ask that client to start training.
START_CLIENT_API = "api/client/start"
# Path of the client's status-check endpoint. NOTE(review): not referenced by the
# code visible in this module; presumably used elsewhere.
CHECK_CLIENT_STATUS_API = "api/client/check_status"


def _start_client(server_address: str, client_info: ClientInfo) -> str:
    """
    Ask a single FL client to start training and return the UUID it reports.

    :param server_address: (str) The address of the FL server the client should connect to.
    :param client_info: (ClientInfo) The information about the client to be started.
    :return: (str) The UUID returned by the client's start endpoint.
    :raises Exception: If the client does not answer with status 200 or if its
        response does not contain a UUID.
    """
    parameters = {
        "server_address": server_address,
        "client": client_info.client.value,
        "data_path": client_info.data_path,
        "redis_host": client_info.redis_host,
        "redis_port": client_info.redis_port,
    }
    response = requests.get(url=f"http://{client_info.service_address}/{START_CLIENT_API}", params=parameters)

    try:
        json_response = response.json()
    except ValueError:
        # The body is not valid JSON (e.g. a plain-text or HTML error page). Fall back to
        # the raw text so the errors below still report the status code and the payload
        # instead of surfacing an opaque JSON decoding error.
        json_response = response.text
    LOGGER.debug(f"Client response: {json_response}")

    if response.status_code != 200:
        raise Exception(f"Client response returned {response.status_code}. Response: {json_response}")

    # Only a JSON object can carry the "uuid" key; this also prevents a substring match
    # against a raw text body or an index error on a JSON list.
    if not isinstance(json_response, dict) or "uuid" not in json_response:
        raise Exception(f"Client response did not return a UUID. Response: {json_response}")

    client_uuid: str = json_response["uuid"]
    return client_uuid


@router.post("/start")
async def start(job_id: str, request: Request) -> JSONResponse:
    """
    Start FL training for a job id by starting a FL server and its clients.

    The job must be in the NOT_STARTED status. Once that check passes the job is moved to
    IN_PROGRESS and any later failure moves it to FINISHED_WITH_ERROR. A job that fails the
    status check (i.e. it is already running or finished) is left untouched.

    NOTE(review): `wait_for_metric` and the HTTP calls to the clients are synchronous, so
    this coroutine does not yield to the event loop while they run.

    :param job_id: (str) The id of the Job record in the DB which contains the information
        necessary to start training.
    :param request: (fastapi.Request) the FastAPI request object.
    :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the server and
        the clients in the format below. The UUIDs can be used to pull metrics from Redis.
            {
                "server_uuid": <server uuid>,
                "client_uuids": [<client_uuid_1>, <client_uuid_2>, ..., <client_uuid_n>],
            }
        If not successful, returns the appropriate error code with a JSON with the format below:
            {"error": <error message>}
        Validation failures are reported with status 400, any other failure with status 500.
    """
    job = None
    # Becomes True once the job is confirmed to be NOT_STARTED, i.e. once this request is the
    # one responsible for the job's status. Without it, a request for a job that is already
    # running or finished would overwrite that job's status with FINISHED_WITH_ERROR.
    job_claimed = False

    try:
        job = await Job.find_by_id(job_id, request.app.database)

        # NOTE(review): validation relies on `assert`, which is stripped when Python runs
        # with -O; the AssertionError handler below turns these into 400 responses.
        assert job is not None, f"Job with id {job_id} not found."
        assert job.status == JobStatus.NOT_STARTED, f"Job status ({job.status.value}) is not NOT_STARTED"
        job_claimed = True

        await job.set_status(JobStatus.IN_PROGRESS, request.app.database)

        if job.config_parser is None:
            job.config_parser = ConfigParser.BASIC

        assert job.model is not None, "Missing Job information: model"
        assert job.server_config is not None, "Missing Job information: server_config"
        assert job.clients_info is not None and len(job.clients_info) > 0, "Missing Job information: clients_info"
        assert job.server_address is not None, "Missing Job information: server_address"
        assert job.redis_host is not None, "Missing Job information: redis_host"
        assert job.redis_port is not None, "Missing Job information: redis_port"

        try:
            config_parser = ConfigParser.class_for_parser(job.config_parser)
            server_config = config_parser.parse(job.server_config)
        except JSONDecodeError as err:
            # Report a malformed server config as a validation error (400)
            raise AssertionError("server_config is not a valid json string.") from err

        model_class = Model.class_for_model(job.model)

        # Start the server
        server_uuid, _ = launch_local_server(
            model=model_class(),
            n_clients=len(job.clients_info),
            server_address=job.server_address,
            redis_host=job.redis_host,
            redis_port=job.redis_port,
            **server_config,
        )
        # Block until the server has reported the "fit_start" metric before contacting the clients
        wait_for_metric(server_uuid, "fit_start", job.redis_host, job.redis_port, logger=LOGGER)

        # Start the clients
        client_uuids: List[str] = [
            _start_client(job.server_address, client_info) for client_info in job.clients_info
        ]

        await job.set_uuids(server_uuid, client_uuids, request.app.database)

        # Start the server training listener and client training listeners as threads to update
        # the job's metrics and status once the training is done
        server_listener_thread = Thread(target=asyncio.run, args=(server_training_listener(job),))
        server_listener_thread.start()
        for client_info in job.clients_info:
            client_listener_thread = Thread(target=asyncio.run, args=(client_training_listener(job, client_info),))
            client_listener_thread.start()

        # Return the UUIDs
        return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids})

    except AssertionError as err:
        if job is not None and job_claimed:
            await job.set_status(JobStatus.FINISHED_WITH_ERROR, request.app.database)
        return JSONResponse(content={"error": str(err)}, status_code=400)

    except Exception as ex:
        LOGGER.exception(ex)
        if job is not None and job_claimed:
            await job.set_status(JobStatus.FINISHED_WITH_ERROR, request.app.database)
        return JSONResponse({"error": str(ex)}, status_code=500)
async def _sync_client_metrics(job: Job, client_info: ClientInfo, client_uuid: str, database: Any) -> bool:
    """
    Read the client's current metrics from Redis and, if there are any, save them to the job.

    :param job: (Job) The job that has this client's metrics.
    :param client_info: (ClientInfo) The ClientInfo with the Redis connection details of the client.
    :param client_uuid: (str) The client's UUID, which is also the key of its metrics in Redis.
    :param database: The database handle passed on to `job.set_client_metrics`.
    :return: (bool) True if the metrics contain `shutdown`, False otherwise.
    """
    client_metrics = get_from_redis(client_uuid, client_info.redis_host, client_info.redis_port)
    LOGGER.debug(f"Client listener: Current metrics for client {client_uuid}: {client_metrics}")

    if client_metrics is None:
        return False

    LOGGER.info(f"Client listener: Updating client metrics for client {client_uuid} on job {job.id}")
    await job.set_client_metrics(client_uuid, client_metrics, database)
    LOGGER.info(f"Client listener: Client metrics for client {client_uuid} on {job.id} have been updated.")

    return "shutdown" in client_metrics


async def client_training_listener(job: Job, client_info: ClientInfo) -> None:
    """
    Listen to the Redis' channel that reports updates on the training process of a FL client.

    Keeps consuming updates to the channel until it finds `shutdown` in the client metrics.
    Every time metrics are available they are saved to the job in the database.

    The database client opened by this function is closed on every exit path, including
    when an error is raised while reading from Redis or writing to the database.

    :param job: (Job) The job that has this client's metrics.
    :param client_info: (ClientInfo) The ClientInfo with the client_uuid to listen to.
    :raises AssertionError: If `client_info.uuid` is None.
    """
    LOGGER.info(f"Starting listener for client messages from job {job.id} at channel {client_info.uuid}")

    assert client_info.uuid is not None, "client_info.uuid is None."

    # This function opens its own database client instead of receiving one from the caller
    db_client: AsyncIOMotorClient[Any] = AsyncIOMotorClient(MONGODB_URI)
    try:
        database = db_client[DATABASE_NAME]

        # check if training has already finished before start listening
        if await _sync_client_metrics(job, client_info, client_info.uuid, database):
            return

        subscriber = get_subscriber(client_info.uuid, client_info.redis_host, client_info.redis_port)
        # TODO add a max retries mechanism, maybe?
        for message in subscriber.listen():  # type: ignore[no-untyped-call]
            if message["type"] != "message":
                continue
            # The contents of the message do not matter, we just use it to get notified
            if await _sync_client_metrics(job, client_info, client_info.uuid, database):
                return
    finally:
        db_client.close()
async def _record_server_metrics(job: Job, server_metrics: Any, database: Any) -> bool:
    """
    Save the given server metrics to the job and close the job if training has ended.

    :param job: (Job) The job the metrics belong to.
    :param server_metrics: The metrics as read from Redis, or None if there are none yet.
    :param database: The database handle passed on to the job's update methods.
    :return: (bool) True if the metrics contain `fit_end` (in which case the job status has
        been set to FINISHED_SUCCESSFULLY), False otherwise.
    """
    if server_metrics is None:
        return False

    LOGGER.info(f"Server listener: Updating server metrics for job {job.id}")
    await job.set_server_metrics(server_metrics, database)
    LOGGER.info(f"Server listener: Server metrics for {job.id} have been updated.")

    if "fit_end" not in server_metrics:
        return False

    LOGGER.info(f"Server listener: Training finished for job {job.id}")
    await job.set_status(JobStatus.FINISHED_SUCCESSFULLY, database)
    LOGGER.info(f"Server listener: Job {job.id} status have been set to {job.status.value}.")
    return True


async def server_training_listener(job: Job) -> None:
    """
    Listen to the Redis' channel that reports updates on the training process of a FL server.

    Keeps consuming updates to the channel until it finds `fit_end` in the server metrics,
    then closes the job with FINISHED_SUCCESSFULLY. Every time server metrics are available
    they are saved to the job in the database.

    The database client opened by this function is closed on every exit path, including
    when an error is raised while reading from Redis or writing to the database.

    :param job: (Job) The job with the server_uuid to listen to.
    :raises AssertionError: If the job's `server_uuid`, `redis_host` or `redis_port` is None.
    """
    LOGGER.info(f"Starting listener for server messages from job {job.id} at channel {job.server_uuid}")

    assert job.server_uuid is not None, "job.server_uuid is None."
    assert job.redis_host is not None, "job.redis_host is None."
    assert job.redis_port is not None, "job.redis_port is None."

    # This function opens its own database client instead of receiving one from the caller
    db_client: AsyncIOMotorClient[Any] = AsyncIOMotorClient(MONGODB_URI)
    try:
        database = db_client[DATABASE_NAME]

        # check if training has already finished before start listening
        server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port)
        LOGGER.debug(f"Server listener: Current metrics for job {job.id}: {server_metrics}")
        if await _record_server_metrics(job, server_metrics, database):
            return

        subscriber = get_subscriber(job.server_uuid, job.redis_host, job.redis_port)
        # TODO add a max retries mechanism, maybe?
        for message in subscriber.listen():  # type: ignore[no-untyped-call]
            if message["type"] != "message":
                continue
            # The contents of the message do not matter, we just use it to get notified
            server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port)
            LOGGER.debug(f"Server listener: Message received for job {job.id}. Metrics: {server_metrics}")
            if await _record_server_metrics(job, server_metrics, database):
                return
    finally:
        db_client.close()