Source code for florist.tests.unit.api.routes.server.test_training

import asyncio
import json
from pytest import raises
from typing import Dict, Any, Tuple
from unittest.mock import Mock, patch, ANY, call

from florist.api.db.entities import Job, JobStatus, JOB_COLLECTION_NAME
from florist.api.models.mnist import MnistNet
from florist.api.routes.server.training import (
    client_training_listener,
    start,
    server_training_listener,
    CHECK_CLIENT_STATUS_API,
)


@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.routes.server.training.requests")
@patch("florist.api.db.entities.Job.set_status")
@patch("florist.api.db.entities.Job.set_uuids")
async def test_start_success(
    mock_set_uuids: Mock,
    mock_set_status: Mock,
    mock_requests: Mock,
    mock_redis: Mock,
    mock_launch_local_server: Mock,
) -> None:
    """Happy path for ``start``.

    Verifies the 200 response body, the arguments used to launch the local
    server, the Redis lookup, the start request sent to every client, the
    status/UUID updates on the job and the background listeners scheduled.
    """
    # Arrange
    job_id = "test-job-id"
    server_config, job_dict, job_collection, fastapi_request = _setup_test_job_and_mocks()
    database = fastapi_request.app.database
    sync_database = fastapi_request.app.synchronous_database

    server_uuid = "test-server-uuid"
    mock_launch_local_server.return_value = (server_uuid, None)

    # The stored server metrics already carry the "fit_start" key.
    redis_connection = Mock()
    redis_connection.get.return_value = b'{"fit_start": null}'
    mock_redis.Redis.return_value = redis_connection

    # One HTTP response mock is shared by both clients; each .json() call
    # yields the next client's UUID.
    client_uuids = ["test-client-1-uuid", "test-client-2-uuid"]
    client_response = Mock()
    client_response.status_code = 200
    client_response.json.side_effect = [{"uuid": client_uuid} for client_uuid in client_uuids]
    mock_requests.get.return_value = client_response

    background_tasks = Mock()

    # Act
    response = await start(job_id, fastapi_request, background_tasks)

    # Assert
    assert response.status_code == 200
    assert json.loads(response.body.decode()) == {"server_uuid": server_uuid, "client_uuids": client_uuids}

    job_collection.find_one.assert_called_with({"_id": job_id})

    launch_kwargs = mock_launch_local_server.call_args_list[0][1]
    assert isinstance(launch_kwargs["model"], MnistNet)
    mock_launch_local_server.assert_called_once_with(
        model=ANY,
        n_clients=len(job_dict["clients_info"]),
        server_address=job_dict["server_address"],
        n_server_rounds=server_config["n_server_rounds"],
        batch_size=server_config["batch_size"],
        local_epochs=server_config["local_epochs"],
        redis_host=job_dict["redis_host"],
        redis_port=job_dict["redis_port"],
    )

    mock_redis.Redis.assert_called_once_with(host=job_dict["redis_host"], port=job_dict["redis_port"])
    redis_connection.get.assert_called_once_with(server_uuid)

    for client_dict in job_dict["clients_info"]:
        mock_requests.get.assert_any_call(
            url=f"http://{client_dict['service_address']}/api/client/start",
            params={
                "server_address": job_dict["server_address"],
                "client": client_dict["client"],
                "data_path": client_dict["data_path"],
                "redis_host": client_dict["redis_host"],
                "redis_port": client_dict["redis_port"],
            },
        )

    mock_set_status.assert_called_once_with(JobStatus.IN_PROGRESS, database)
    mock_set_uuids.assert_called_once_with(server_uuid, client_uuids, database)

    # Ids are not part of the fixture dict, so they are wildcarded before comparing.
    expected_job = Job(**job_dict)
    expected_job.id = ANY
    for expected_client in expected_job.clients_info:
        expected_client.id = ANY

    expected_tasks = [call(server_training_listener, expected_job, sync_database)]
    expected_tasks.extend(
        call(client_training_listener, expected_job, expected_client, sync_database)
        for expected_client in expected_job.clients_info
    )
    background_tasks.add_task.assert_has_calls(expected_tasks)
async def test_start_fail_unsupported_server_model() -> None:
    """``start`` answers 500 with an enumeration error when the job's model is not a known value."""
    # Arrange
    job_id = "test-job-id"
    setup = _setup_test_job_and_mocks()
    job_dict, fastapi_request = setup[1], setup[3]
    job_dict["model"] = "WRONG MODEL"

    # Act
    response = await start(job_id, fastapi_request, Mock())

    # Assert
    assert response.status_code == 500
    payload = json.loads(response.body.decode())
    assert payload == {"error": ANY}
    assert "value is not a valid enumeration member" in payload["error"]
async def test_start_fail_unsupported_client() -> None:
    """``start`` answers 500 with an enumeration error when a client entry names an unknown client."""
    # Arrange
    job_id = "test-job-id"
    setup = _setup_test_job_and_mocks()
    job_dict, fastapi_request = setup[1], setup[3]
    second_client = job_dict["clients_info"][1]
    second_client["client"] = "WRONG CLIENT"

    # Act
    response = await start(job_id, fastapi_request, Mock())

    # Assert
    assert response.status_code == 500
    payload = json.loads(response.body.decode())
    assert payload == {"error": ANY}
    assert "value is not a valid enumeration member" in payload["error"]
@patch("florist.api.db.entities.Job.set_status")
async def test_start_fail_missing_info(mock_set_status: Mock) -> None:
    """``start`` answers 400 naming the absent field, checked for each required job field in turn."""
    required_fields = ("model", "server_config", "clients_info", "server_address", "redis_host", "redis_port")

    for missing_field in required_fields:
        # Arrange: a fresh job for every iteration, minus the field under test
        job_id = "test-job-id"
        setup = _setup_test_job_and_mocks()
        job_dict, fastapi_request = setup[1], setup[3]
        job_dict.pop(missing_field)

        # Act
        response = await start(job_id, fastapi_request, Mock())

        # Assert
        assert response.status_code == 400
        payload = json.loads(response.body.decode())
        assert payload == {"error": ANY}
        assert f"Missing Job information: {missing_field}" in payload["error"]
@patch("florist.api.db.entities.Job.set_status")
async def test_start_fail_invalid_server_config(mock_set_status: Mock) -> None:
    """``start`` answers 400 when the job's ``server_config`` is not a JSON string."""
    # Arrange
    test_job_id = "test-job-id"
    _, test_job, _, mock_fastapi_request = _setup_test_job_and_mocks()
    test_job["server_config"] = "not json"

    # Act
    response = await start(test_job_id, mock_fastapi_request, Mock())

    # Assert
    assert response.status_code == 400
    json_body = json.loads(response.body.decode())
    assert json_body == {"error": ANY}
    # Plain literal: the message has no placeholders, so no f-prefix is needed.
    assert "server_config is not a valid json string." in json_body["error"]
@patch("florist.api.db.entities.Job.set_status")
async def test_start_fail_empty_clients_info(_: Mock) -> None:
    """``start`` answers 400 and reports ``clients_info`` as missing when the list is empty."""
    # Arrange
    test_job_id = "test-job-id"
    _, test_job, _, mock_fastapi_request = _setup_test_job_and_mocks()
    test_job["clients_info"] = []

    # Act
    response = await start(test_job_id, mock_fastapi_request, Mock())

    # Assert
    assert response.status_code == 400
    json_body = json.loads(response.body.decode())
    assert json_body == {"error": ANY}
    # Plain literal: the message has no placeholders, so no f-prefix is needed.
    assert "Missing Job information: clients_info" in json_body["error"]
@patch("florist.api.db.entities.Job.set_status")
@patch("florist.api.routes.server.training.launch_local_server")
async def test_start_launch_server_exception(mock_launch_local_server: Mock, _: Mock) -> None:
    """``start`` answers 500 carrying the exception text when launching the local server raises."""
    # Arrange
    job_id = "test-job-id"
    fastapi_request = _setup_test_job_and_mocks()[3]
    launch_error = Exception("test exception")
    mock_launch_local_server.side_effect = launch_error

    # Act
    response = await start(job_id, fastapi_request, Mock())

    # Assert
    assert response.status_code == 500
    assert json.loads(response.body.decode()) == {"error": str(launch_error)}
@patch("florist.api.db.entities.Job.set_status")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
async def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_server: Mock, _: Mock) -> None:
    """``start`` answers 500 carrying the exception text when creating the Redis connection raises."""
    # Arrange
    job_id = "test-job-id"
    fastapi_request = _setup_test_job_and_mocks()[3]

    server_uuid = "test-server-uuid"
    mock_launch_local_server.return_value = (server_uuid, None)

    redis_error = Exception("test exception")
    mock_redis.Redis.side_effect = redis_error

    # Act
    response = await start(job_id, fastapi_request, Mock())

    # Assert
    assert response.status_code == 500
    assert json.loads(response.body.decode()) == {"error": str(redis_error)}
@patch("florist.api.db.entities.Job.set_status")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.monitoring.metrics.time")  # just so time.sleep does not actually sleep
async def test_start_wait_for_metric_timeout(
    _: Mock,
    mock_redis: Mock,
    mock_launch_local_server: Mock,
    mock_set_status: Mock,
) -> None:
    """``start`` answers 500 with the retry-exhausted message when the metrics never contain ``fit_start``."""
    # Arrange
    job_id = "test-job-id"
    fastapi_request = _setup_test_job_and_mocks()[3]

    server_uuid = "test-server-uuid"
    mock_launch_local_server.return_value = (server_uuid, None)

    # The stored metrics hold only an unrelated key on every read.
    redis_connection = Mock()
    redis_connection.get.return_value = b'{"foo": null}'
    mock_redis.Redis.return_value = redis_connection

    # Act
    response = await start(job_id, fastapi_request, Mock())

    # Assert
    assert response.status_code == 500
    payload = json.loads(response.body.decode())
    assert payload == {"error": "Metric 'fit_start' not been found after 20 retries."}
@patch("florist.api.db.entities.Job.set_status")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.routes.server.training.requests")
async def test_start_fail_response(
    mock_requests: Mock,
    mock_redis: Mock,
    mock_launch_local_server: Mock,
    _: Mock,
) -> None:
    """``start`` answers 500 and echoes the client's status and body when a client replies with a non-200."""
    # Arrange
    test_job_id = "test-job-id"
    _, _, _, mock_fastapi_request = _setup_test_job_and_mocks()

    test_server_uuid = "test-server-uuid"
    mock_launch_local_server.return_value = (test_server_uuid, None)

    mock_redis_connection = Mock()
    mock_redis_connection.get.return_value = b"{\"fit_start\": null}"
    mock_redis.Redis.return_value = mock_redis_connection

    # Client start request is rejected.
    mock_response = Mock()
    mock_response.status_code = 403
    mock_response.json.return_value = "error"
    mock_requests.get.return_value = mock_response

    # Act
    response = await start(test_job_id, mock_fastapi_request, Mock())

    # Assert
    assert response.status_code == 500
    json_body = json.loads(response.body.decode())
    # Plain literal: the expected message has no placeholders, so no f-prefix is needed.
    assert json_body == {"error": "Client response returned 403. Response: error"}
@patch("florist.api.db.entities.Job.set_status")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.routes.server.training.requests")
async def test_start_no_uuid_in_response(
    mock_requests: Mock,
    mock_redis: Mock,
    mock_launch_local_server: Mock,
    _: Mock,
) -> None:
    """``start`` answers 500 when a client replies 200 but its JSON body has no ``uuid`` key."""
    # Arrange
    job_id = "test-job-id"
    fastapi_request = _setup_test_job_and_mocks()[3]

    server_uuid = "test-server-uuid"
    mock_launch_local_server.return_value = (server_uuid, None)

    redis_connection = Mock()
    redis_connection.get.return_value = b'{"fit_start": null}'
    mock_redis.Redis.return_value = redis_connection

    client_response = Mock()
    client_response.status_code = 200
    client_response.json.return_value = {"foo": "bar"}
    mock_requests.get.return_value = client_response

    # Act
    response = await start(job_id, fastapi_request, Mock())

    # Assert
    assert response.status_code == 500
    payload = json.loads(response.body.decode())
    assert payload == {"error": "Client response did not return a UUID. Response: {'foo': 'bar'}"}
@patch("florist.api.routes.server.training.get_from_redis")
@patch("florist.api.routes.server.training.get_subscriber")
def test_server_training_listener(mock_get_subscriber: Mock, mock_get_from_redis: Mock) -> None:
    """The server listener stores each metrics snapshot it reads and marks the job finished.

    Three snapshots are served from Redis, the last one containing ``fit_end``;
    the test expects three metric saves, three Redis reads and a single
    ``FINISHED_SUCCESSFULLY`` status update.
    """
    # Setup
    job = Job(
        server_uuid="test-server-uuid",
        redis_host="test-redis-host",
        redis_port="test-redis-port",
        clients_info=[
            {
                "service_address": "test-service-address",
                "uuid": "test-uuid",
                "redis_host": "test-client-redis-host",
                "redis_port": "test-client-redis-port",
                "client": "MNIST",
                "data_path": "test-data-path",
            }
        ],
    )
    metrics_snapshots = [
        {"fit_start": "2022-02-02 02:02:02"},
        {"fit_start": "2022-02-02 02:02:02", "rounds": []},
        {"fit_start": "2022-02-02 02:02:02", "rounds": [], "fit_end": "2022-02-02 03:03:03"},
    ]
    mock_get_from_redis.side_effect = metrics_snapshots

    # The subscriber emits more events than there are snapshots, one of them not a "message".
    event_types = ("message", "not message", "message", "message", "message")
    subscriber = Mock()
    subscriber.listen.return_value = [{"type": event_type} for event_type in event_types]
    mock_get_subscriber.return_value = subscriber

    database = Mock()

    # Act
    with patch.object(Job, "set_status_sync", Mock()) as status_sync_mock, \
            patch.object(Job, "set_server_metrics", Mock()) as server_metrics_mock:
        server_training_listener(job, database)

    # Assert
    status_sync_mock.assert_called_once_with(JobStatus.FINISHED_SUCCESSFULLY, database)
    assert server_metrics_mock.call_count == 3
    server_metrics_mock.assert_has_calls([call(snapshot, database) for snapshot in metrics_snapshots])
    assert mock_get_from_redis.call_count == 3
    mock_get_subscriber.assert_called_once_with(job.server_uuid, job.redis_host, job.redis_port)
@patch("florist.api.routes.server.training.get_from_redis")
def test_server_training_listener_already_finished(mock_get_from_redis: Mock) -> None:
    """When the first metrics read already contains ``fit_end``, one save and one status update are expected."""
    # Setup
    job = Job(
        server_uuid="test-server-uuid",
        redis_host="test-redis-host",
        redis_port="test-redis-port",
        clients_info=[
            {
                "service_address": "test-service-address",
                "uuid": "test-uuid",
                "redis_host": "test-client-redis-host",
                "redis_port": "test-client-redis-port",
                "client": "MNIST",
                "data_path": "test-data-path",
            }
        ],
    )
    final_metrics = {"fit_start": "2022-02-02 02:02:02", "rounds": [], "fit_end": "2022-02-02 03:03:03"}
    mock_get_from_redis.side_effect = [final_metrics]
    database = Mock()

    # Act
    with patch.object(Job, "set_status_sync", Mock()) as status_sync_mock, \
            patch.object(Job, "set_server_metrics", Mock()) as server_metrics_mock:
        server_training_listener(job, database)

    # Assert
    status_sync_mock.assert_called_once_with(JobStatus.FINISHED_SUCCESSFULLY, database)
    server_metrics_mock.assert_called_once_with(final_metrics, database)
    assert mock_get_from_redis.call_count == 1
def test_server_training_listener_fail_no_server_uuid() -> None:
    """The server listener raises ``AssertionError`` for a job built without ``server_uuid``."""
    job = Job(redis_host="test-redis-host", redis_port="test-redis-port")
    database = Mock()

    with raises(AssertionError, match="job.server_uuid is None."):
        server_training_listener(job, database)
def test_server_training_listener_fail_no_redis_host() -> None:
    """The server listener raises ``AssertionError`` for a job built without ``redis_host``."""
    job = Job(server_uuid="test-server-uuid", redis_port="test-redis-port")
    database = Mock()

    with raises(AssertionError, match="job.redis_host is None."):
        server_training_listener(job, database)
def test_server_training_listener_fail_no_redis_port() -> None:
    """The server listener raises ``AssertionError`` for a job built without ``redis_port``."""
    job = Job(server_uuid="test-server-uuid", redis_host="test-redis-host")
    database = Mock()

    with raises(AssertionError, match="job.redis_port is None."):
        server_training_listener(job, database)
@patch("florist.api.routes.server.training.get_from_redis")
@patch("florist.api.routes.server.training.get_subscriber")
def test_client_training_listener(mock_get_subscriber: Mock, mock_get_from_redis: Mock) -> None:
    """The client listener stores each metrics snapshot it reads for the client's UUID.

    Three snapshots are served from Redis, the last one containing
    ``shutdown``; the test expects three metric saves, three Redis reads and
    a subscription opened with the client's own UUID and Redis settings.
    """
    # Setup
    test_client_uuid = "test-client-uuid"
    test_job = Job(**{
        "clients_info": [
            {
                "service_address": "test-service-address",
                "uuid": test_client_uuid,
                "redis_host": "test-client-redis-host",
                "redis_port": "test-client-redis-port",
                "client": "MNIST",
                "data_path": "test-data-path",
            }
        ]
    })
    test_client_metrics = [
        {"initialized": "2022-02-02 02:02:02"},
        {"initialized": "2022-02-02 02:02:02", "rounds": []},
        {"initialized": "2022-02-02 02:02:02", "rounds": [], "shutdown": "2022-02-02 03:03:03"},
    ]
    mock_get_from_redis.side_effect = test_client_metrics
    mock_subscriber = Mock()
    mock_subscriber.listen.return_value = [
        {"type": "message"},
        {"type": "not message"},
        {"type": "message"},
        {"type": "message"},
        {"type": "message"},
    ]
    mock_get_subscriber.return_value = mock_subscriber
    mock_database = Mock()

    # set_status_sync stays patched for the duration of the call but is not
    # asserted on here, so the patch is not bound to a name.
    with patch.object(Job, "set_status_sync", Mock()):
        with patch.object(Job, "set_client_metrics", Mock()) as mock_set_client_metrics:
            # Act
            client_training_listener(test_job, test_job.clients_info[0], mock_database)

            # Assert
            assert mock_set_client_metrics.call_count == 3
            mock_set_client_metrics.assert_has_calls([
                call(test_client_uuid, test_client_metrics[0], mock_database),
                call(test_client_uuid, test_client_metrics[1], mock_database),
                call(test_client_uuid, test_client_metrics[2], mock_database),
            ])

    assert mock_get_from_redis.call_count == 3
    mock_get_subscriber.assert_called_once_with(
        test_job.clients_info[0].uuid,
        test_job.clients_info[0].redis_host,
        test_job.clients_info[0].redis_port,
    )
@patch("florist.api.routes.server.training.get_from_redis")
def test_client_training_listener_already_finished(mock_get_from_redis: Mock) -> None:
    """When the first metrics read already contains ``shutdown``, exactly one save and one Redis read are expected."""
    # Setup
    test_client_uuid = "test-client-uuid"
    test_job = Job(**{
        "clients_info": [
            {
                "service_address": "test-service-address",
                "uuid": test_client_uuid,
                "redis_host": "test-client-redis-host",
                "redis_port": "test-client-redis-port",
                "client": "MNIST",
                "data_path": "test-data-path",
            }
        ]
    })
    test_client_final_metrics = {"initialized": "2022-02-02 02:02:02", "rounds": [], "shutdown": "2022-02-02 03:03:03"}
    mock_get_from_redis.side_effect = [test_client_final_metrics]
    mock_database = Mock()

    # set_status_sync stays patched for the duration of the call but is not
    # asserted on here, so the patch is not bound to a name.
    with patch.object(Job, "set_status_sync", Mock()):
        with patch.object(Job, "set_client_metrics", Mock()) as mock_set_client_metrics:
            # Act
            client_training_listener(test_job, test_job.clients_info[0], mock_database)

            # Assert
            mock_set_client_metrics.assert_called_once_with(
                test_client_uuid,
                test_client_final_metrics,
                mock_database,
            )

    assert mock_get_from_redis.call_count == 1
def test_client_training_listener_fail_no_uuid() -> None:
    """The client listener raises ``AssertionError`` when the client info was built without ``uuid``."""
    client_without_uuid = {
        "redis_host": "test-redis-host",
        "redis_port": "test-redis-port",
        "service_address": "test-service-address",
        "client": "MNIST",
        "data_path": "test-data-path",
    }
    job = Job(clients_info=[client_without_uuid])
    first_client = job.clients_info[0]

    with raises(AssertionError, match="client_info.uuid is None."):
        client_training_listener(job, first_client, Mock())
def _setup_test_job_and_mocks() -> Tuple[Dict[str, Any], Dict[str, Any], Mock, Mock]:
    """Build the shared fixture used by the ``start`` tests.

    Returns:
        A 4-tuple ``(server_config, job, job_collection, fastapi_request)``:
        the server config dict, the raw job dict (its ``server_config`` is the
        JSON-encoded config), a mock collection whose ``find_one`` returns an
        already-resolved future holding the job dict, and a mock request whose
        ``app.database`` and ``app.synchronous_database`` both map
        ``JOB_COLLECTION_NAME`` to that collection.
    """
    server_config = {
        "n_server_rounds": 2,
        "batch_size": 8,
        "local_epochs": 1,
    }

    # Two clients that differ only by their numeric suffix.
    clients_info = [
        {
            "client": "MNIST",
            "service_address": f"test-service-address-{index}",
            "data_path": f"test-data-path-{index}",
            "redis_host": f"test-redis-host-{index}",
            "redis_port": f"test-redis-port-{index}",
            "uuid": f"test-client-uuids-{index}",
            "metrics": f"test-client-metrics-{index}",
        }
        for index in (1, 2)
    ]

    job = {
        "status": "NOT_STARTED",
        "model": "MNIST",
        "server_address": "test-server-address",
        "server_config": json.dumps(server_config),
        "config_parser": "BASIC",
        "redis_host": "test-redis-host",
        "redis_port": "test-redis-port",
        "server_uuid": "test-server-uuid",
        "server_metrics": "test-server-metrics",
        "clients_info": clients_info,
    }

    # find_one is awaited by the code under test, so it has to hand back an awaitable.
    find_one_result = asyncio.Future()
    find_one_result.set_result(job)

    job_collection = Mock()
    job_collection.find_one.return_value = find_one_result

    fastapi_request = Mock()
    fastapi_request.app.database = {JOB_COLLECTION_NAME: job_collection}
    fastapi_request.app.synchronous_database = {JOB_COLLECTION_NAME: job_collection}

    return server_config, job, job_collection, fastapi_request