"""FastAPI routes for training."""
import logging
from json import JSONDecodeError
from typing import Any, Dict, List
import requests
from fastapi import APIRouter, BackgroundTasks, Request
from fastapi.responses import JSONResponse
from pymongo.database import Database
from florist.api.db.entities import ClientInfo, Job, JobStatus
from florist.api.monitoring.metrics import get_from_redis, get_subscriber, wait_for_metric
from florist.api.servers.common import Model
from florist.api.servers.config_parsers import ConfigParser
from florist.api.servers.launch import launch_local_server
# Router for the training endpoints; mounted by the main FastAPI application.
router = APIRouter()
# Reuse uvicorn's error logger so messages appear in the web server's log output.
LOGGER = logging.getLogger("uvicorn.error")
# Relative endpoint paths (no leading slash) on the client services; they are
# appended to each client's `service_address` when issuing HTTP requests.
START_CLIENT_API = "api/client/start"
CHECK_CLIENT_STATUS_API = "api/client/check_status"
@router.post("/start")
async def start(job_id: str, request: Request, background_tasks: BackgroundTasks) -> JSONResponse:
    """
    Start FL training for a job id by starting a FL server and its clients.

    :param job_id: (str) The id of the Job record in the DB which contains the information
        necessary to start training.
    :param request: (fastapi.Request) the FastAPI request object.
    :param background_tasks: (BackgroundTasks) A BackgroundTasks instance to launch the training listener,
        which will update the progress of the training job.
    :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the server and
        the clients in the format below. The UUIDs can be used to pull metrics from Redis.
            {
                "server_uuid": <client uuid>,
                "client_uuids": [<client_uuid_1>, <client_uuid_2>, ..., <client_uuid_n>],
            }
        If not successful, returns the appropriate error code with a JSON with the format below:
            {"error": <error message>}
    """
    job = None
    try:
        job = await Job.find_by_id(job_id, request.app.database)
        assert job is not None, f"Job with id {job_id} not found."
        assert job.status == JobStatus.NOT_STARTED, f"Job status ({job.status.value}) is not NOT_STARTED"
        await job.set_status(JobStatus.IN_PROGRESS, request.app.database)

        # Default to the basic parser when the job does not specify one.
        if job.config_parser is None:
            job.config_parser = ConfigParser.BASIC

        # Validate that all the information needed to launch the server and clients is present.
        assert job.model is not None, "Missing Job information: model"
        assert job.server_config is not None, "Missing Job information: server_config"
        assert job.clients_info is not None and len(job.clients_info) > 0, "Missing Job information: clients_info"
        assert job.server_address is not None, "Missing Job information: server_address"
        assert job.redis_host is not None, "Missing Job information: redis_host"
        assert job.redis_port is not None, "Missing Job information: redis_port"

        try:
            config_parser = ConfigParser.class_for_parser(job.config_parser)
            server_config = config_parser.parse(job.server_config)
        except JSONDecodeError as err:
            raise AssertionError("server_config is not a valid json string.") from err

        model_class = Model.class_for_model(job.model)

        # Start the server
        server_uuid, _ = launch_local_server(
            model=model_class(),
            n_clients=len(job.clients_info),
            server_address=job.server_address,
            redis_host=job.redis_host,
            redis_port=job.redis_port,
            **server_config,
        )
        # Block until the server reports it is ready to accept clients.
        wait_for_metric(server_uuid, "fit_start", job.redis_host, job.redis_port, logger=LOGGER)

        # Start the clients
        client_uuids: List[str] = []
        for client_info in job.clients_info:
            parameters = {
                "server_address": job.server_address,
                "client": client_info.client.value,
                "data_path": client_info.data_path,
                "redis_host": client_info.redis_host,
                "redis_port": client_info.redis_port,
            }
            # NOTE(review): no timeout is set here, so a hung client service would block
            # this request indefinitely — consider adding one.
            response = requests.get(url=f"http://{client_info.service_address}/{START_CLIENT_API}", params=parameters)
            # Check the HTTP status BEFORE parsing the body: a failing client may return
            # a non-JSON payload, and calling .json() first would mask the real error
            # with a JSONDecodeError.
            if response.status_code != 200:
                raise Exception(f"Client response returned {response.status_code}. Response: {response.text}")
            try:
                json_response = response.json()
            except JSONDecodeError as err:
                raise Exception(f"Client response is not valid JSON. Response: {response.text}") from err
            LOGGER.debug(f"Client response: {json_response}")
            if "uuid" not in json_response:
                raise Exception(f"Client response did not return a UUID. Response: {json_response}")
            client_uuids.append(json_response["uuid"])

        await job.set_uuids(server_uuid, client_uuids, request.app.database)

        # Start the server training listener as a background task to update
        # the job's metrics and status once the training is done
        background_tasks.add_task(server_training_listener, job, request.app.synchronous_database)
        for client_info in job.clients_info:
            background_tasks.add_task(client_training_listener, job, client_info, request.app.synchronous_database)

        # Return the UUIDs
        return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids})

    except AssertionError as err:
        # Validation failures are reported as a 400 (bad request) and the job is closed with an error.
        if job is not None:
            await job.set_status(JobStatus.FINISHED_WITH_ERROR, request.app.database)
        return JSONResponse(content={"error": str(err)}, status_code=400)

    except Exception as ex:
        # Anything else (launch failures, client errors) is a 500; the job is closed with an error.
        LOGGER.exception(ex)
        if job is not None:
            await job.set_status(JobStatus.FINISHED_WITH_ERROR, request.app.database)
        return JSONResponse({"error": str(ex)}, status_code=500)
def client_training_listener(job: Job, client_info: ClientInfo, database: Database[Dict[str, Any]]) -> None:
    """
    Listen to the Redis' channel that reports updates on the training process of a FL client.

    Keeps consuming updates to the channel until it finds `shutdown` in the client metrics.

    :param job: (Job) The job that has this client's metrics.
    :param client_info: (ClientInfo) The ClientInfo with the client_uuid to listen to.
    :param database: (pymongo.database.Database) An instance of the database to save the information
        into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async
        because of limitations with FastAPI's BackgroundTasks.
    """
    LOGGER.info(f"Starting listener for client messages from job {job.id} at channel {client_info.uuid}")
    assert client_info.uuid is not None, "client_info.uuid is None."

    # check if training has already finished before start listening
    client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port)
    LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}")
    if _save_client_metrics(job, client_info, client_metrics, database):
        return

    subscriber = get_subscriber(client_info.uuid, client_info.redis_host, client_info.redis_port)
    try:
        # TODO add a max retries mechanism, maybe?
        for message in subscriber.listen():  # type: ignore[no-untyped-call]
            if message["type"] != "message":
                continue
            # The contents of the message do not matter, we just use it to get notified
            client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port)
            LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}")
            if _save_client_metrics(job, client_info, client_metrics, database):
                return
    finally:
        # Release the pubsub connection; without this it would leak on every listener exit.
        subscriber.close()  # type: ignore[no-untyped-call]


def _save_client_metrics(
    job: Job,
    client_info: ClientInfo,
    client_metrics: Any,
    database: Database[Dict[str, Any]],
) -> bool:
    """
    Save the given client metrics into the job, if there are any.

    :param job: (Job) The job to save the metrics into.
    :param client_info: (ClientInfo) The client the metrics belong to; its uuid must not be None.
    :param client_metrics: The metrics payload as returned by `get_from_redis`, or None if none are available yet.
    :param database: (pymongo.database.Database) A SYNCHRONOUS database instance.
    :return: (bool) True if `shutdown` is present in the metrics (i.e. the client has finished), False otherwise.
    """
    if client_metrics is None:
        return False
    assert client_info.uuid is not None, "client_info.uuid is None."
    LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}")
    job.set_client_metrics(client_info.uuid, client_metrics, database)
    LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.")
    return "shutdown" in client_metrics
def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> None:
    """
    Listen to the Redis' channel that reports updates on the training process of a FL server.

    Keeps consuming updates to the channel until it finds `fit_end` in the server metrics,
    then closes the job with FINISHED_SUCCESSFULLY and saves both the clients and server's metrics
    to the job in the database.

    :param job: (Job) The job with the server_uuid to listen to.
    :param database: (pymongo.database.Database) An instance of the database to save the information
        into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async
        because of limitations with FastAPI's BackgroundTasks.
    """
    LOGGER.info(f"Starting listener for server messages from job {job.id} at channel {job.server_uuid}")
    assert job.server_uuid is not None, "job.server_uuid is None."
    assert job.redis_host is not None, "job.redis_host is None."
    assert job.redis_port is not None, "job.redis_port is None."

    # check if training has already finished before start listening
    server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port)
    LOGGER.debug(f"Listener: Current metrics for job {job.id}: {server_metrics}")
    if _save_server_metrics(job, server_metrics, database):
        return

    subscriber = get_subscriber(job.server_uuid, job.redis_host, job.redis_port)
    try:
        # TODO add a max retries mechanism, maybe?
        for message in subscriber.listen():  # type: ignore[no-untyped-call]
            if message["type"] != "message":
                continue
            # The contents of the message do not matter, we just use it to get notified
            server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port)
            LOGGER.debug(f"Listener: Message received for job {job.id}. Metrics: {server_metrics}")
            if _save_server_metrics(job, server_metrics, database):
                return
    finally:
        # Release the pubsub connection; without this it would leak on every listener exit.
        subscriber.close()  # type: ignore[no-untyped-call]


def _save_server_metrics(job: Job, server_metrics: Any, database: Database[Dict[str, Any]]) -> bool:
    """
    Save the given server metrics into the job and close it if training has finished.

    If `fit_end` is present in the metrics, the job status is set to FINISHED_SUCCESSFULLY.

    :param job: (Job) The job to save the metrics into.
    :param server_metrics: The metrics payload as returned by `get_from_redis`, or None if none are available yet.
    :param database: (pymongo.database.Database) A SYNCHRONOUS database instance.
    :return: (bool) True if `fit_end` is present in the metrics (i.e. training has finished), False otherwise.
    """
    if server_metrics is None:
        return False
    LOGGER.info(f"Listener: Updating server metrics for job {job.id}")
    job.set_server_metrics(server_metrics, database)
    LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.")
    if "fit_end" not in server_metrics:
        return False
    LOGGER.info(f"Listener: Training finished for job {job.id}")
    job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database)
    LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.")
    return True