Source code for florist.tests.integration.api.db.test_entities

import json
import re
from unittest.mock import ANY
from pytest import raises

from florist.api.db.entities import Job, JobStatus
from florist.tests.integration.api.utils import mock_request


async def test_job_create_success(mock_request) -> None:
    """Check that ``Job.create`` returns the new job's id as a string."""
    job = get_test_job()

    created_id = await job.create(mock_request.app.database)

    assert isinstance(created_id, str)
async def test_job_find_by_id_success(mock_request) -> None:
    """Check that ``Job.find_by_id`` returns a job equal to the one created."""
    database = mock_request.app.database
    job = get_test_job()
    created_id = await job.create(database)

    # Build the expected job: it carries the id returned by ``create`` and the
    # client ids are matched loosely with ``ANY``.
    job.id = created_id
    for index in (0, 1):
        job.clients_info[index].id = ANY

    found_job = await Job.find_by_id(created_id, database)

    assert job == found_job
async def test_job_find_by_id_not_found(mock_request) -> None:
    """Check that ``Job.find_by_id`` returns ``None`` for an id that was never created."""
    database = mock_request.app.database

    missing_job = await Job.find_by_id("does-not-exist", database)

    assert missing_job is None
async def test_job_find_by_status_success(mock_request) -> None:
    """Check that ``Job.find_by_status`` only returns jobs with the requested status.

    One finished job is created; querying for finished jobs must return it,
    and querying for not-started jobs must return nothing.
    """
    database = mock_request.app.database
    job = get_test_job()
    job.status = JobStatus.FINISHED_SUCCESSFULLY
    created_id = await job.create(database)

    # Build the expected job; client ids are matched loosely with ``ANY``.
    job.id = created_id
    for index in (0, 1):
        job.clients_info[index].id = ANY

    finished_jobs = await Job.find_by_status(JobStatus.FINISHED_SUCCESSFULLY, 10, database)
    assert len(finished_jobs) == 1
    assert job == finished_jobs[0]

    not_started_jobs = await Job.find_by_status(JobStatus.NOT_STARTED, 10, database)
    assert len(not_started_jobs) == 0
async def test_job_find_by_status_with_limit_success(mock_request) -> None:
    """Check that ``Job.find_by_status`` returns no more jobs than the given limit.

    Four finished jobs are created, then a query with a limit of 3 must
    return exactly three jobs.
    """
    for _ in range(4):  # the loop counter itself is never needed
        test_job = get_test_job()
        test_job.status = JobStatus.FINISHED_SUCCESSFULLY
        await test_job.create(mock_request.app.database)

    result_jobs = await Job.find_by_status(JobStatus.FINISHED_SUCCESSFULLY, 3, mock_request.app.database)

    assert len(result_jobs) == 3
async def test_set_uuids_success(mock_request) -> None:
    """Check that ``set_uuids`` stores the new server and client uuids.

    After the call, the job fetched by id must equal the local job once the
    same uuids have been set on it.
    """
    database = mock_request.app.database
    job = get_test_job()
    created_id = await job.create(database)
    job.id = created_id
    for index in (0, 1):
        job.clients_info[index].id = ANY

    new_server_uuid = "a-different-server-uuid"
    new_client_uuids = ["a-different-client-uuid-1", "a-different-client-uuid-2"]

    await job.set_uuids(new_server_uuid, new_client_uuids, database)
    stored_job = await Job.find_by_id(created_id, database)

    # Apply the same changes to the local job so it can serve as the expected value.
    job.server_uuid = new_server_uuid
    job.clients_info[0].uuid = new_client_uuids[0]
    job.clients_info[1].uuid = new_client_uuids[1]

    assert stored_job == job
async def test_set_uuids_fail_clients_info_is_none(mock_request) -> None:
    """Check that ``set_uuids`` raises ``AssertionError`` when ``clients_info`` is ``None``."""
    database = mock_request.app.database
    job = get_test_job()
    job.clients_info = None
    created_id = await job.create(database)
    job.id = created_id

    new_server_uuid = "a-different-server-uuid"
    new_client_uuids = ["a-different-client-uuid-1", "a-different-client-uuid-2"]
    expected_message = "self.clients_info and client_uuids must have the same length (None!=2)."

    # ``re.escape`` is needed because ``match`` treats its argument as a regex
    # and the message contains parentheses and dots.
    with raises(AssertionError, match=re.escape(expected_message)):
        await job.set_uuids(new_server_uuid, new_client_uuids, database)
async def test_set_uuids_fail_clients_info_is_not_same_length(mock_request) -> None:
    """Check that ``set_uuids`` raises ``AssertionError`` on a client uuid count mismatch.

    The job has two clients but only one client uuid is supplied.
    """
    database = mock_request.app.database
    job = get_test_job()
    created_id = await job.create(database)
    job.id = created_id
    for index in (0, 1):
        job.clients_info[index].id = ANY

    new_server_uuid = "a-different-server-uuid"
    new_client_uuids = ["a-different-client-uuid-1"]
    expected_message = "self.clients_info and client_uuids must have the same length (2!=1)."

    # ``re.escape`` is needed because ``match`` treats its argument as a regex.
    with raises(AssertionError, match=re.escape(expected_message)):
        await job.set_uuids(new_server_uuid, new_client_uuids, database)
async def test_set_uuids_fail_update_result(mock_request) -> None:
    """Check that ``set_uuids`` raises ``AssertionError`` for a job this test never created.

    ``create`` is deliberately not called, so the update is expected to fail
    with the "UpdateResult's 'n' is not 1" message.
    """
    database = mock_request.app.database
    job = get_test_job()
    job.id = str(job.id)

    new_server_uuid = "a-different-server-uuid"
    new_client_uuids = ["a-different-client-uuid-1", "a-different-client-uuid-2"]
    expected_message = "UpdateResult's 'n' is not 1"

    with raises(AssertionError, match=re.escape(expected_message)):
        await job.set_uuids(new_server_uuid, new_client_uuids, database)
async def test_set_status_success(mock_request) -> None:
    """Check that ``set_status`` stores the new status.

    After the call, the job fetched by id must equal the local job once the
    same status has been set on it.
    """
    database = mock_request.app.database
    job = get_test_job()
    created_id = await job.create(database)
    job.id = created_id
    for index in (0, 1):
        job.clients_info[index].id = ANY

    new_status = JobStatus.IN_PROGRESS

    await job.set_status(new_status, database)
    stored_job = await Job.find_by_id(created_id, database)

    # Apply the same change to the local job so it can serve as the expected value.
    job.status = new_status

    assert stored_job == job
async def test_set_status_fail_update_result(mock_request) -> None:
    """Check that ``set_status`` raises ``AssertionError`` for a job this test never created.

    ``create`` is deliberately not called, so the update is expected to fail
    with the "UpdateResult's 'n' is not 1" message.
    """
    database = mock_request.app.database
    job = get_test_job()
    job.id = str(job.id)

    expected_message = "UpdateResult's 'n' is not 1"

    with raises(AssertionError, match=re.escape(expected_message)):
        await job.set_status(JobStatus.IN_PROGRESS, database)
async def test_set_server_metrics_success(mock_request) -> None:
    """Check that ``set_server_metrics`` stores the given metrics.

    The fetched job is expected to hold the metrics as the JSON string
    produced by ``json.dumps``.
    """
    database = mock_request.app.database
    job = get_test_job()
    created_id = await job.create(database)
    job.id = created_id
    for index in (0, 1):
        job.clients_info[index].id = ANY

    new_server_metrics = {"test-server": 123}

    await job.set_server_metrics(new_server_metrics, database)
    stored_job = await Job.find_by_id(created_id, database)

    # The expected job holds the JSON-encoded form of the metrics.
    job.server_metrics = json.dumps(new_server_metrics)

    assert stored_job == job
async def test_set_server_metrics_fail_update_result(mock_request) -> None:
    """Check that ``set_server_metrics`` raises ``AssertionError`` for a job this test never created.

    ``create`` is deliberately not called, so the update is expected to fail
    with the "UpdateResult's 'n' is not 1" message.
    """
    database = mock_request.app.database
    job = get_test_job()
    job.id = str(job.id)

    new_server_metrics = {"test-server": 123}
    expected_message = "UpdateResult's 'n' is not 1"

    with raises(AssertionError, match=re.escape(expected_message)):
        await job.set_server_metrics(new_server_metrics, database)
async def test_set_client_metrics_success(mock_request) -> None:
    """Check that ``set_client_metrics`` stores metrics on the targeted client.

    Metrics are set for the second client's uuid; the fetched job is expected
    to hold them on that client as the JSON string produced by ``json.dumps``.
    """
    database = mock_request.app.database
    job = get_test_job()
    created_id = await job.create(database)
    job.id = created_id
    for index in (0, 1):
        job.clients_info[index].id = ANY

    new_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}]

    await job.set_client_metrics(job.clients_info[1].uuid, new_client_metrics, database)
    stored_job = await Job.find_by_id(created_id, database)

    # Only the second client is expected to carry the JSON-encoded metrics.
    job.clients_info[1].metrics = json.dumps(new_client_metrics)

    assert stored_job == job
async def test_set_metrics_fail_clients_info_is_none(mock_request) -> None:
    """Check that ``set_client_metrics`` raises ``AssertionError`` for an unknown client uuid.

    NOTE(review): the test name mentions ``clients_info`` being ``None``, but
    the scenario exercised here is a client uuid that is not among the job's
    clients. The name is kept so selecting the test by name keeps working.
    """
    database = mock_request.app.database
    job = get_test_job()
    created_id = await job.create(database)
    job.id = created_id

    unknown_client_uuid = "client-id-that-does-not-exist"
    new_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}]
    expected_message = (
        f"client uuid {unknown_client_uuid} is not in clients_info "
        f"(['{job.clients_info[0].uuid}', '{job.clients_info[1].uuid}'])"
    )

    # ``re.escape`` is needed because the message contains regex metacharacters
    # (brackets and parentheses).
    with raises(AssertionError, match=re.escape(expected_message)):
        await job.set_client_metrics(unknown_client_uuid, new_client_metrics, database)
async def test_set_client_metrics_fail_update_result(mock_request) -> None:
    """Check that ``set_client_metrics`` raises ``AssertionError`` for a job this test never created.

    ``create`` is deliberately not called, so the update is expected to fail
    with the "UpdateResult's 'n' is not 1" message.
    """
    database = mock_request.app.database
    job = get_test_job()
    job.id = str(job.id)

    new_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}]
    expected_message = "UpdateResult's 'n' is not 1"

    with raises(AssertionError, match=re.escape(expected_message)):
        await job.set_client_metrics(job.clients_info[0].uuid, new_client_metrics, database)
def get_test_job() -> Job:
    """Build a fresh ``Job`` populated with placeholder values and two clients.

    Returns:
        A new ``Job`` instance with status ``NOT_STARTED``, a JSON-encoded
        server config and two ``MNIST`` client entries.
    """
    server_config = json.dumps(
        {
            "n_server_rounds": 2,
            "batch_size": 8,
            "local_epochs": 1,
        }
    )
    first_client = {
        "client": "MNIST",
        "service_address": "test-service-address-1",
        "data_path": "test-data-path-1",
        "redis_host": "test-redis-host-1",
        "redis_port": "test-redis-port-1",
        "uuid": "test-client-uuids-1",
        "metrics": "test-client-metrics-1",
    }
    second_client = {
        "client": "MNIST",
        "service_address": "test-service-address-2",
        "data_path": "test-data-path-2",
        "redis_host": "test-redis-host-2",
        "redis_port": "test-redis-port-2",
        "uuid": "test-client-uuids-2",
        "metrics": "test-client-metrics-2",
    }
    return Job(
        status="NOT_STARTED",
        model="MNIST",
        server_address="test-server-address",
        server_config=server_config,
        config_parser="BASIC",
        redis_host="test-redis-host",
        redis_port="test-redis-port",
        server_uuid="test-server-uuid",
        server_metrics="test-server-metrics",
        clients_info=[first_client, second_client],
    )