"""Definitions for the MongoDB database entities."""
import json
import uuid
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional
from fastapi.encoders import jsonable_encoder
from motor.motor_asyncio import AsyncIOMotorDatabase
from pydantic import BaseModel, Field
from pymongo.database import Database
from pymongo.results import UpdateResult
from florist.api.clients.common import Client
from florist.api.servers.common import Model
from florist.api.servers.config_parsers import ConfigParser
# Name of the MongoDB collection that holds the job records; every query and
# update in this module goes through `database[JOB_COLLECTION_NAME]`.
JOB_COLLECTION_NAME = "job"
# Not referenced in the visible part of this module -- presumably the default
# `limit` callers pass to lookups such as `Job.find_by_status` (TODO confirm).
MAX_RECORDS_TO_FETCH = 1000
class JobStatus(Enum):
    """Closed set of states a Job can be in."""

    NOT_STARTED = "NOT_STARTED"
    IN_PROGRESS = "IN_PROGRESS"
    FINISHED_WITH_ERROR = "FINISHED_WITH_ERROR"
    FINISHED_SUCCESSFULLY = "FINISHED_SUCCESSFULLY"

    @classmethod
    def list(cls) -> List[str]:
        """
        Collect the string value of every status, in declaration order.

        :return: (List[str]) the values of all the job statuses.
        """
        # Iterating an Enum class yields its members in definition order.
        return [member.value for member in cls]
class ClientInfo(BaseModel):
    """
    Define the information of an FL client.

    `client`, `service_address`, `data_path`, `redis_host` and `redis_port` are
    required; `uuid` and `metrics` are optional. `metrics` holds a JSON-encoded
    string (see the example in `Config`).
    """

    # `uuid.uuid4` returns a `uuid.UUID` instance; convert it so the generated
    # default matches the declared `str` type. The lambda resolves `uuid` from
    # the module globals, so it is not affected by the `uuid` field below.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
    client: Client = Field(...)
    service_address: str = Field(...)
    data_path: str = Field(...)
    redis_host: str = Field(...)
    redis_port: str = Field(...)
    uuid: Optional[Annotated[str, Field(...)]]
    metrics: Optional[Annotated[str, Field(...)]]

    class Config:
        """MongoDB config for the ClientInfo DB entity."""

        # Lets instances be built with `id=...` as well as with the `_id` alias.
        allow_population_by_field_name = True
        schema_extra = {
            "example": {
                "client": "MNIST",
                "service_address": "localhost:8001",
                "data_path": "path/to/data",
                "redis_host": "localhost",
                "redis_port": "6380",
                "uuid": "0c316680-1375-4e07-84c3-a732a2e6d03f",
                "metrics": '{"host_type": "client", "initialized": "2024-03-25 11:20:56.819569", "rounds": {"1": {"fit_start": "2024-03-25 11:20:56.827081"}}}',
            },
        }
class Job(BaseModel):
    """
    Define the Job DB entity.

    Besides the data fields, the class offers helpers to read and write job
    records in the `JOB_COLLECTION_NAME` collection. The `async` helpers take a
    motor `AsyncIOMotorDatabase`; the `*_sync` / metrics helpers take a pymongo
    `Database`. Every write helper updates the in-memory instance as well as
    the stored record and asserts that the update hit exactly one record.
    """

    # `uuid.uuid4` returns a `uuid.UUID` instance; convert it so the generated
    # default matches the declared `str` type and the string `_id` that
    # `create` stores and that the update helpers use as their filter.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
    status: JobStatus = Field(default=JobStatus.NOT_STARTED)
    model: Optional[Annotated[Model, Field(...)]]
    server_address: Optional[Annotated[str, Field(...)]]
    server_config: Optional[Annotated[str, Field(...)]]
    config_parser: Optional[Annotated[ConfigParser, Field(...)]]
    server_uuid: Optional[Annotated[str, Field(...)]]
    server_metrics: Optional[Annotated[str, Field(...)]]
    redis_host: Optional[Annotated[str, Field(...)]]
    redis_port: Optional[Annotated[str, Field(...)]]
    clients_info: Optional[Annotated[List[ClientInfo], Field(...)]]

    @classmethod
    async def find_by_id(cls, job_id: str, database: AsyncIOMotorDatabase[Any]) -> Optional["Job"]:
        """
        Find a job in the database by its id.

        :param job_id: (str) the job's id.
        :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored.
        :return: (Optional[Job]) An instance of the job record with the given ID, or `None` if it can't be found.
        """
        job_collection = database[JOB_COLLECTION_NAME]
        result = await job_collection.find_one({"_id": job_id})
        if result is None:
            return None
        return cls(**result)

    @classmethod
    async def find_by_status(cls, status: JobStatus, limit: int, database: AsyncIOMotorDatabase[Any]) -> List["Job"]:
        """
        Return all jobs with the given status.

        :param status: (JobStatus) The status of the jobs to be returned.
        :param limit: (int) the limit amount of records that should be returned.
        :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored.
        :return: (List[Job]) The list of jobs with the given status in the database.
        """
        # Encode the enum the same way `create` encodes it when saving a record,
        # so the filter compares like with like.
        encoded_status = jsonable_encoder(status)
        job_collection = database[JOB_COLLECTION_NAME]
        result = await job_collection.find({"status": encoded_status}).to_list(limit)
        assert isinstance(result, list)
        return [cls(**record) for record in result]

    async def create(self, database: AsyncIOMotorDatabase[Any]) -> str:
        """
        Save this instance under a new record in the database.

        :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored.
        :return: (str) the new job record's id.
        """
        json_job = jsonable_encoder(self)
        result = await database[JOB_COLLECTION_NAME].insert_one(json_job)
        assert isinstance(result.inserted_id, str)
        return result.inserted_id

    async def set_uuids(self, server_uuid: str, client_uuids: List[str], database: AsyncIOMotorDatabase[Any]) -> None:
        """
        Save the server and clients' UUIDs in the database under the current job's id.

        The i-th entry of `client_uuids` is assigned to the i-th entry of `self.clients_info`.

        :param server_uuid: (str) the server_uuid to be saved in the database.
        :param client_uuids: (List[str]) the list of client_uuids to be saved in the database.
        :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored.
        :raises AssertionError: if `self.clients_info` is `None` or differs in length from `client_uuids`,
            or if any of the updates does not affect exactly one record.
        """
        assert self.clients_info is not None and len(self.clients_info) == len(client_uuids), (
            "self.clients_info and client_uuids must have the same length "
            f"({'None' if self.clients_info is None else len(self.clients_info)}!={len(client_uuids)})."
        )

        job_collection = database[JOB_COLLECTION_NAME]

        self.server_uuid = server_uuid
        update_result = await job_collection.update_one({"_id": self.id}, {"$set": {"server_uuid": server_uuid}})
        assert_updated_successfully(update_result)

        # One update per client, addressing the array element by its position.
        for i, (client_info, client_uuid) in enumerate(zip(self.clients_info, client_uuids)):
            client_info.uuid = client_uuid
            update_result = await job_collection.update_one(
                {"_id": self.id}, {"$set": {f"clients_info.{i}.uuid": client_uuid}}
            )
            assert_updated_successfully(update_result)

    async def set_status(self, status: JobStatus, database: AsyncIOMotorDatabase[Any]) -> None:
        """
        Save the status in the database under the current job's id.

        :param status: (JobStatus) the status to be saved in the database.
        :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored.
        :raises AssertionError: if the update does not affect exactly one record.
        """
        job_collection = database[JOB_COLLECTION_NAME]
        self.status = status
        update_result = await job_collection.update_one({"_id": self.id}, {"$set": {"status": status.value}})
        assert_updated_successfully(update_result)

    def set_status_sync(self, status: JobStatus, database: Database[Dict[str, Any]]) -> None:
        """
        Sync function to save the status in the database under the current job's id.

        :param status: (JobStatus) the status to be saved in the database.
        :param database: (pymongo.database.Database) The database where the job collection is stored.
        :raises AssertionError: if the update does not affect exactly one record.
        """
        job_collection = database[JOB_COLLECTION_NAME]
        self.status = status
        update_result = job_collection.update_one({"_id": self.id}, {"$set": {"status": status.value}})
        assert_updated_successfully(update_result)

    def set_server_metrics(
        self,
        server_metrics: Dict[str, Any],
        database: Database[Dict[str, Any]],
    ) -> None:
        """
        Sync function to save the server's metrics in the database under the current job's id.

        The metrics are stored as a JSON-encoded string.

        :param server_metrics: (Dict[str, Any]) the server metrics to be saved.
        :param database: (pymongo.database.Database) The database where the job collection is stored.
        :raises AssertionError: if the update does not affect exactly one record.
        """
        job_collection = database[JOB_COLLECTION_NAME]
        self.server_metrics = json.dumps(server_metrics)
        update_result = job_collection.update_one({"_id": self.id}, {"$set": {"server_metrics": self.server_metrics}})
        assert_updated_successfully(update_result)

    def set_client_metrics(
        self,
        client_uuid: str,
        client_metrics: Dict[str, Any],
        database: Database[Dict[str, Any]],
    ) -> None:
        """
        Sync function to save a client's metrics in the database under the current job's id.

        The metrics are stored as a JSON-encoded string on every entry of
        `self.clients_info` whose uuid equals `client_uuid`.

        :param client_uuid: (str) the uuid of the client that produced the metrics.
        :param client_metrics: (Dict[str, Any]) the client's metrics to be saved.
        :param database: (pymongo.database.Database) The database where the job collection is stored.
        :raises AssertionError: if `client_uuid` is not found in `self.clients_info`, or if an update does
            not affect exactly one record.
        """
        known_uuids = [c.uuid for c in self.clients_info] if self.clients_info is not None else None
        assert (
            self.clients_info is not None and client_uuid in known_uuids
        ), f"client uuid {client_uuid} is not in clients_info ({known_uuids})"

        job_collection = database[JOB_COLLECTION_NAME]
        # Serialize once; the same string is written for every matching client.
        encoded_metrics = json.dumps(client_metrics)

        for i, client_info in enumerate(self.clients_info):
            if client_uuid == client_info.uuid:
                client_info.metrics = encoded_metrics
                update_result = job_collection.update_one(
                    {"_id": self.id}, {"$set": {f"clients_info.{i}.metrics": client_info.metrics}}
                )
                assert_updated_successfully(update_result)

    class Config:
        """MongoDB config for the Job DB entity."""

        # Lets instances be built with `id=...` as well as with the `_id` alias.
        allow_population_by_field_name = True
        schema_extra = {
            "example": {
                "_id": "066de609-b04a-4b30-b46c-32537c7f1f6e",
                "status": "NOT_STARTED",
                "model": "MNIST",
                "server_address": "localhost:8000",
                "server_config": '{"n_server_rounds": 3, "batch_size": 8, "local_epochs": 1}',
                "server_uuid": "d73243cf-8b89-473b-9607-8cd0253a101d",
                "server_metrics": '{"host_type": "server", "fit_start": "2024-04-23 15:33:12.865604", "rounds": {"1": {"fit_start": "2024-04-23 15:33:12.869001"}}}',
                "redis_host": "localhost",
                "redis_port": "6379",
                "clients_info": [
                    {
                        "client": "MNIST",
                        "service_address": "localhost:8001",
                        "data_path": "path/to/data",
                        "redis_host": "localhost",
                        "redis_port": "6380",
                        # The ClientInfo field is named `uuid`, not `client_uuid`.
                        "uuid": "0c316680-1375-4e07-84c3-a732a2e6d03f",
                    },
                ],
            },
        }
def assert_updated_successfully(update_result: UpdateResult) -> None:
    """
    Check that an update reported exactly one affected record and an ok status.

    :param update_result: (pymongo.results.UpdateResult) the result object from an update.
    :raises AssertionError: if the raw result is not a dict, its 'n' is not 1, its 'nModified'
        is neither 0 nor 1, or its 'ok' is not 1.
    """
    outcome = update_result.raw_result
    assert isinstance(outcome, dict)

    # Checked one at a time, in this order, so the first failing field is the one reported.
    # NOTE(review): 'n' is presumably the matched-record count and 'nModified' the number
    # actually changed (0 when the stored value was already equal) -- confirm against pymongo.
    assert outcome["n"] == 1, f"UpdateResult's 'n' is not 1 ({update_result})"
    assert outcome["nModified"] in (0, 1), f"UpdateResult's 'nModified' is not 1 or 0 ({update_result})"
    assert outcome["ok"] == 1, f"UpdateResult's 'ok' is not 1 ({update_result})"