Source code for florist.tests.unit.api.routes.server.test_training

import asyncio
import json
import re
from typing import Dict, Any, Tuple
from unittest.mock import Mock, AsyncMock, patch, ANY, call

from pytest import raises

from florist.api.clients.clients import Client
from florist.api.clients.optimizers import Optimizer
from florist.api.db.config import DATABASE_NAME
from florist.api.db.server_entities import Job, JobStatus, JOB_COLLECTION_NAME
from florist.api.monitoring.metrics import get_host_and_port_from_address
from florist.api.models.models import Model
from florist.api.models.mnist import MnistNet
from florist.api.routes.server.training import (
    client_training_listener,
    start,
    server_training_listener
)
from florist.api.servers.strategies import Strategy

@patch("florist.api.routes.server.training.client_training_listener")
@patch("florist.api.routes.server.training.server_training_listener")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.routes.server.training.requests")
@patch("florist.api.db.server_entities.Job.set_status")
@patch("florist.api.db.server_entities.Job.set_uuids")
@patch("florist.api.db.server_entities.Job.set_server_log_file_path")
@patch("florist.api.db.server_entities.Job.set_server_pid")
async def test_start_success(
    mock_set_server_pid: Mock,
    mock_server_log_file_path: Mock,
    mock_set_uuids: Mock,
    mock_set_status: Mock,
    mock_requests: Mock,
    mock_redis: Mock,
    mock_launch_local_server: Mock,
    mock_server_training_listener: Mock,
    mock_client_training_listener: Mock,
) -> None:
    """Run the happy path of ``start`` once for every ``Strategy`` member.

    Each pass builds a job document tailored to the strategy and fakes the server
    launch, the redis ``fit_start`` metric and two client ``/api/client/start``
    responses. It then checks every collaborator call. The patched mocks are shared
    by all passes, so their call histories are reset at the end of each one.
    """
    job_id = "test-job-id"
    server_uuid = "test-server-uuid"
    log_file_path = "test-log-file-path"
    server_pid = 12345
    client_uuids = ["test-client-1-uuid", "test-client-2-uuid"]
    patched_collaborators = (
        mock_set_server_pid,
        mock_server_log_file_path,
        mock_set_uuids,
        mock_set_status,
        mock_requests,
        mock_redis,
        mock_launch_local_server,
        mock_server_training_listener,
        mock_client_training_listener,
    )

    for strategy in Strategy:
        # Arrange: a job document describing the current strategy.
        server_config = _get_test_server_config(strategy)
        _, job_data, job_collection, fastapi_request = _setup_test_job_and_mocks()
        job_data.update(
            strategy=strategy.value,
            server_config=json.dumps(server_config),
            config_parser=strategy.get_config_parser().value,
            client=Client.list_by_strategy(strategy)[0],
        )

        server_process = Mock()
        server_process.pid = server_pid
        mock_launch_local_server.return_value = (server_uuid, server_process, log_file_path)

        redis_connection = Mock()
        redis_connection.get.return_value = b"{\"fit_start\": null}"
        mock_redis.Redis.return_value = redis_connection

        # One JSON payload per client, consumed in order by successive response.json() calls.
        client_response = Mock()
        client_response.status_code = 200
        client_response.json.side_effect = [{"uuid": uuid} for uuid in client_uuids]
        mock_requests.get.return_value = client_response

        mock_client_training_listener.return_value = AsyncMock()
        mock_server_training_listener.return_value = AsyncMock()

        # Act
        response = await start(job_id, fastapi_request)

        # Assert
        database = fastapi_request.app.database
        assert response.status_code == 200
        assert json.loads(response.body.decode()) == {"server_uuid": server_uuid, "client_uuids": client_uuids}

        job_collection.find_one.assert_called_with({"_id": job_id})
        mock_set_status.assert_called_once_with(JobStatus.IN_PROGRESS, database)
        mock_launch_local_server.assert_called_once_with(
            model=ANY,
            server_factory=Strategy(job_data["strategy"]).get_server_factory(),
            server_config=server_config,
            server_address=job_data["server_address"],
            n_clients=len(job_data["clients_info"]),
            redis_address=job_data["redis_address"],
        )
        launched_model = mock_launch_local_server.call_args_list[0].kwargs["model"]
        assert isinstance(launched_model, MnistNet)

        redis_host, redis_port = get_host_and_port_from_address(job_data["redis_address"])
        mock_redis.Redis.assert_called_once_with(host=redis_host, port=redis_port)
        redis_connection.get.assert_called_once_with(server_uuid)
        mock_server_log_file_path.assert_called_once_with(log_file_path, database)

        for client_info in job_data["clients_info"]:
            mock_requests.get.assert_any_call(
                url=f"http://{client_info['service_address']}/api/client/start",
                params={
                    "server_address": job_data["server_address"],
                    "client": job_data["client"],
                    "model": job_data["model"],
                    "optimizer": job_data["optimizer"],
                    "data_path": client_info["data_path"],
                    "redis_address": client_info["redis_address"],
                },
            )

        mock_set_uuids.assert_called_once_with(server_uuid, client_uuids, database)
        mock_set_server_pid.assert_called_once_with(str(server_pid), database)

        # Database ids are generated, so they are wildcarded before comparing jobs.
        expected_job = Job(**job_data)
        expected_job.id = ANY
        for expected_client_info in expected_job.clients_info:
            expected_client_info.id = ANY
        mock_server_training_listener.assert_called_with(expected_job)
        mock_client_training_listener.assert_has_calls(
            [call(expected_job, expected_client_info) for expected_client_info in expected_job.clients_info]
        )

        # Clean call histories before exercising the next strategy.
        for patched in patched_collaborators:
            patched.reset_mock()
async def test_start_fail_unsupported_server_model() -> None:
    """``start`` answers 500 when the job's ``model`` is not a known ``Model`` value."""
    _, job_data, _, fastapi_request = _setup_test_job_and_mocks()
    job_data["model"] = "WRONG MODEL"

    response = await start("test-job-id", fastapi_request)

    assert response.status_code == 500
    body = json.loads(response.body.decode())
    assert body == {"error": ANY}
    assert "value is not a valid enumeration member" in body["error"]
async def test_start_fail_unsupported_client() -> None:
    """``start`` answers 500 when the job's ``client`` is not a known ``Client`` value."""
    _, job_data, _, fastapi_request = _setup_test_job_and_mocks()
    job_data["client"] = "WRONG CLIENT"

    response = await start("test-job-id", fastapi_request)

    assert response.status_code == 500
    body = json.loads(response.body.decode())
    assert body == {"error": ANY}
    assert "value is not a valid enumeration member" in body["error"]
async def test_start_fail_missing_info() -> None:
    """Dropping any one required job field makes ``start`` answer 400 and record the error."""
    required_fields = (
        "model",
        "strategy",
        "optimizer",
        "client",
        "server_config",
        "clients_info",
        "server_address",
        "redis_address",
    )
    for missing_field in required_fields:
        # Fresh patches per field so each iteration's call assertions start from zero.
        status_patcher = patch("florist.api.db.server_entities.Job.set_status")
        error_patcher = patch("florist.api.db.server_entities.Job.set_error_message")
        with status_patcher as mock_set_status, error_patcher as mock_set_error_message:
            _, job_data, _, fastapi_request = _setup_test_job_and_mocks()
            del job_data[missing_field]

            response = await start("test-job-id", fastapi_request)

            assert response.status_code == 400
            body = json.loads(response.body.decode())
            assert body == {"error": ANY}
            expected_error = f"Missing Job information: {missing_field}"
            assert expected_error in body["error"]
            database = fastapi_request.app.database
            mock_set_status.assert_has_calls([
                call(JobStatus.IN_PROGRESS, database),
                call(JobStatus.FINISHED_WITH_ERROR, database),
            ])
            mock_set_error_message.assert_called_once_with(expected_error, database)
@patch("florist.api.db.server_entities.Job.set_status")
@patch("florist.api.db.server_entities.Job.set_error_message")
async def test_start_fail_invalid_server_config(mock_set_error_message: Mock, mock_set_status: Mock) -> None:
    """``start`` answers 400 and marks the job as failed when ``server_config`` is not JSON.

    Checks that the job status goes IN_PROGRESS -> FINISHED_WITH_ERROR and that the same
    error message is both returned to the caller and stored on the job.
    """
    # Arrange
    test_job_id = "test-job-id"
    _, test_job, _, mock_fastapi_request = _setup_test_job_and_mocks()
    test_job["server_config"] = "not json"

    # Act
    response = await start(test_job_id, mock_fastapi_request)

    # Assert
    assert response.status_code == 400
    json_body = json.loads(response.body.decode())
    assert json_body == {"error": ANY}
    # Plain literal: the message has no placeholders, so an f-prefix is redundant (F541).
    error_message = "server_config is not a valid json string."
    assert error_message in json_body["error"]
    mock_set_status.assert_has_calls([
        call(JobStatus.IN_PROGRESS, mock_fastapi_request.app.database),
        call(JobStatus.FINISHED_WITH_ERROR, mock_fastapi_request.app.database),
    ])
    mock_set_error_message.assert_called_once_with(error_message, mock_fastapi_request.app.database)
@patch("florist.api.db.server_entities.Job.set_status")
@patch("florist.api.db.server_entities.Job.set_error_message")
async def test_start_fail_invalid_client_for_strategy(mock_set_error_message: Mock, mock_set_status: Mock) -> None:
    """A FedProx client paired with the fixture's FedAvg strategy is rejected with a 400."""
    _, job_data, _, fastapi_request = _setup_test_job_and_mocks()
    job_data["client"] = Client.FEDPROX.value

    response = await start("test-job-id", fastapi_request)

    assert response.status_code == 400
    body = json.loads(response.body.decode())
    assert body == {"error": ANY}
    expected_error = f"Client {Client.FEDPROX} not valid for strategy {Strategy.FEDAVG}."
    assert expected_error in body["error"]
    database = fastapi_request.app.database
    mock_set_status.assert_has_calls([
        call(JobStatus.IN_PROGRESS, database),
        call(JobStatus.FINISHED_WITH_ERROR, database),
    ])
    mock_set_error_message.assert_called_once_with(expected_error, database)
@patch("florist.api.db.server_entities.Job.set_status")
@patch("florist.api.db.server_entities.Job.set_error_message")
async def test_start_fail_empty_clients_info(mock_set_error_message: Mock, mock_set_status: Mock) -> None:
    """An empty ``clients_info`` list is treated like a missing one: 400 and a failed job."""
    # Arrange
    test_job_id = "test-job-id"
    _, test_job, _, mock_fastapi_request = _setup_test_job_and_mocks()
    test_job["clients_info"] = []

    # Act
    response = await start(test_job_id, mock_fastapi_request)

    # Assert
    assert response.status_code == 400
    json_body = json.loads(response.body.decode())
    assert json_body == {"error": ANY}
    # Plain literal: the message has no placeholders, so an f-prefix is redundant (F541).
    error_message = "Missing Job information: clients_info"
    assert error_message in json_body["error"]
    mock_set_status.assert_has_calls([
        call(JobStatus.IN_PROGRESS, mock_fastapi_request.app.database),
        call(JobStatus.FINISHED_WITH_ERROR, mock_fastapi_request.app.database),
    ])
    mock_set_error_message.assert_called_once_with(error_message, mock_fastapi_request.app.database)
@patch("florist.api.db.server_entities.Job.set_status")
@patch("florist.api.db.server_entities.Job.set_error_message")
@patch("florist.api.routes.server.training.launch_local_server")
async def test_start_launch_server_exception(
    mock_launch_local_server: Mock,
    mock_set_error_message: Mock,
    mock_set_status: Mock,
) -> None:
    """An exception from ``launch_local_server`` becomes a 500 and is recorded on the job."""
    _, _, _, fastapi_request = _setup_test_job_and_mocks()
    launch_error = Exception("test exception")
    mock_launch_local_server.side_effect = launch_error

    response = await start("test-job-id", fastapi_request)

    expected_error = str(launch_error)
    database = fastapi_request.app.database
    assert response.status_code == 500
    assert json.loads(response.body.decode()) == {"error": expected_error}
    mock_set_status.assert_has_calls(
        [call(JobStatus.IN_PROGRESS, database), call(JobStatus.FINISHED_WITH_ERROR, database)]
    )
    mock_set_error_message.assert_called_once_with(expected_error, database)
@patch("florist.api.db.server_entities.Job.set_status")
@patch("florist.api.db.server_entities.Job.set_error_message")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.db.server_entities.Job.set_server_log_file_path")
async def test_start_wait_for_metric_exception(
    mock_set_server_log_file_path: Mock,
    mock_redis: Mock,
    mock_launch_local_server: Mock,
    mock_set_error_message: Mock,
    mock_set_status: Mock,
) -> None:
    """A redis failure while waiting for the server's first metric yields a 500.

    The server log path must still have been saved, because the server did launch.
    """
    _, _, _, fastapi_request = _setup_test_job_and_mocks()
    log_file_path = "test-log-file-path"
    mock_launch_local_server.return_value = ("test-server-uuid", None, log_file_path)
    redis_error = Exception("test exception")
    mock_redis.Redis.side_effect = redis_error

    response = await start("test-job-id", fastapi_request)

    expected_error = str(redis_error)
    database = fastapi_request.app.database
    assert response.status_code == 500
    assert json.loads(response.body.decode()) == {"error": expected_error}
    mock_set_server_log_file_path.assert_called_once_with(log_file_path, database)
    mock_set_status.assert_has_calls(
        [call(JobStatus.IN_PROGRESS, database), call(JobStatus.FINISHED_WITH_ERROR, database)]
    )
    mock_set_error_message.assert_called_once_with(expected_error, database)
@patch("florist.api.db.server_entities.Job.set_status")
@patch("florist.api.db.server_entities.Job.set_error_message")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.monitoring.metrics.time")  # just so time.sleep does not actually sleep
@patch("florist.api.db.server_entities.Job.set_server_log_file_path")
async def test_start_wait_for_metric_timeout(
    mock_set_server_log_file_path: Mock,
    _: Mock,
    mock_redis: Mock,
    mock_launch_local_server: Mock,
    mock_set_error_message: Mock,
    mock_set_status: Mock,
) -> None:
    """``start`` gives up with a 500 when ``fit_start`` never shows up in redis."""
    _, _, _, fastapi_request = _setup_test_job_and_mocks()
    log_file_path = "test-log-file-path"
    mock_launch_local_server.return_value = ("test-server-uuid", None, log_file_path)

    # Redis always answers with metrics that lack the awaited "fit_start" key.
    redis_connection = Mock()
    redis_connection.get.return_value = b"{\"foo\": null}"
    mock_redis.Redis.return_value = redis_connection

    response = await start("test-job-id", fastapi_request)

    expected_error = "Metric 'fit_start' not been found after 20 retries."
    database = fastapi_request.app.database
    assert response.status_code == 500
    assert json.loads(response.body.decode()) == {"error": expected_error}
    mock_set_server_log_file_path.assert_called_once_with(log_file_path, database)
    mock_set_status.assert_has_calls(
        [call(JobStatus.IN_PROGRESS, database), call(JobStatus.FINISHED_WITH_ERROR, database)]
    )
    mock_set_error_message.assert_called_once_with(expected_error, database)
@patch("florist.api.db.server_entities.Job.set_status")
@patch("florist.api.db.server_entities.Job.set_error_message")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.routes.server.training.requests")
@patch("florist.api.db.server_entities.Job.set_server_log_file_path")
async def test_start_fail_response(
    mock_set_server_log_file_path: Mock,
    mock_requests: Mock,
    mock_redis: Mock,
    mock_launch_local_server: Mock,
    mock_set_error_message: Mock,
    mock_set_status: Mock,
) -> None:
    """A non-200 answer from a client's start endpoint makes ``start`` fail with a 500."""
    # Arrange
    test_job_id = "test-job-id"
    _, _, _, mock_fastapi_request = _setup_test_job_and_mocks()
    test_server_uuid = "test-server-uuid"
    test_log_file_path = "test-log-file-path"
    mock_launch_local_server.return_value = (test_server_uuid, None, test_log_file_path)
    mock_redis_connection = Mock()
    mock_redis_connection.get.return_value = b"{\"fit_start\": null}"
    mock_redis.Redis.return_value = mock_redis_connection
    mock_response = Mock()
    mock_response.status_code = 403
    mock_response.json.return_value = "error"
    mock_requests.get.return_value = mock_response

    # Act
    response = await start(test_job_id, mock_fastapi_request)

    # Assert
    assert response.status_code == 500
    json_body = json.loads(response.body.decode())
    # Plain literal: the message has no placeholders, so an f-prefix is redundant (F541).
    error_message = "Client response returned 403. Response: error"
    assert json_body == {"error": error_message}
    mock_set_server_log_file_path.assert_called_once_with(test_log_file_path, mock_fastapi_request.app.database)
    mock_set_status.assert_has_calls([
        call(JobStatus.IN_PROGRESS, mock_fastapi_request.app.database),
        call(JobStatus.FINISHED_WITH_ERROR, mock_fastapi_request.app.database),
    ])
    mock_set_error_message.assert_called_once_with(error_message, mock_fastapi_request.app.database)
@patch("florist.api.db.server_entities.Job.set_status")
@patch("florist.api.db.server_entities.Job.set_error_message")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.routes.server.training.requests")
@patch("florist.api.db.server_entities.Job.set_server_log_file_path")
async def test_start_no_client_uuid_in_response(
    mock_set_server_log_file_path: Mock,
    mock_requests: Mock,
    mock_redis: Mock,
    mock_launch_local_server: Mock,
    mock_set_error_message: Mock,
    mock_set_status: Mock,
) -> None:
    """A 200 client response without a ``uuid`` key makes ``start`` fail with a 500."""
    _, _, _, fastapi_request = _setup_test_job_and_mocks()
    log_file_path = "test-log-file-path"
    mock_launch_local_server.return_value = ("test-server-uuid", None, log_file_path)

    redis_connection = Mock()
    redis_connection.get.return_value = b"{\"fit_start\": null}"
    mock_redis.Redis.return_value = redis_connection

    client_response = Mock()
    client_response.status_code = 200
    client_response.json.return_value = {"foo": "bar"}
    mock_requests.get.return_value = client_response

    response = await start("test-job-id", fastapi_request)

    expected_error = "Client response did not return a UUID. Response: {'foo': 'bar'}"
    database = fastapi_request.app.database
    assert response.status_code == 500
    assert json.loads(response.body.decode()) == {"error": expected_error}
    mock_set_server_log_file_path.assert_called_once_with(log_file_path, database)
    mock_set_status.assert_has_calls(
        [call(JobStatus.IN_PROGRESS, database), call(JobStatus.FINISHED_WITH_ERROR, database)]
    )
    mock_set_error_message.assert_called_once_with(expected_error, database)
@patch("florist.api.db.server_entities.Job.set_status")
@patch("florist.api.db.server_entities.Job.set_error_message")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.routes.server.training.requests")
@patch("florist.api.db.server_entities.Job.set_server_log_file_path")
async def test_start_client_uuid_in_response_is_not_a_string(
    mock_set_server_log_file_path: Mock,
    mock_requests: Mock,
    mock_redis: Mock,
    mock_launch_local_server: Mock,
    mock_set_error_message: Mock,
    mock_set_status: Mock,
) -> None:
    """A client response whose ``uuid`` is not a ``str`` makes ``start`` fail with a 500."""
    _, _, _, fastapi_request = _setup_test_job_and_mocks()
    log_file_path = "test-log-file-path"
    mock_launch_local_server.return_value = ("test-server-uuid", None, log_file_path)

    redis_connection = Mock()
    redis_connection.get.return_value = b"{\"fit_start\": null}"
    mock_redis.Redis.return_value = redis_connection

    client_response = Mock()
    client_response.status_code = 200
    client_response.json.return_value = {"uuid": 1234}
    mock_requests.get.return_value = client_response

    response = await start("test-job-id", fastapi_request)

    expected_error = "Client UUID is not a string: 1234"
    database = fastapi_request.app.database
    assert response.status_code == 500
    assert json.loads(response.body.decode()) == {"error": expected_error}
    mock_set_server_log_file_path.assert_called_once_with(log_file_path, database)
    mock_set_status.assert_has_calls(
        [call(JobStatus.IN_PROGRESS, database), call(JobStatus.FINISHED_WITH_ERROR, database)]
    )
    mock_set_error_message.assert_called_once_with(expected_error, database)
@patch("florist.api.routes.server.training.AsyncIOMotorClient")
@patch("florist.api.routes.server.training.get_from_redis")
@patch("florist.api.routes.server.training.get_subscriber")
async def test_server_training_listener(
    mock_get_subscriber: Mock,
    mock_get_from_redis: Mock,
    mock_motor_client: Mock,
) -> None:
    """The server listener stores each new metrics snapshot and finishes the job on ``fit_end``.

    The subscriber yields four "message" events and one other event. The listener should
    stop after the third snapshot, the one containing ``fit_end``.
    """
    job = Job(
        server_uuid="test-server-uuid",
        redis_address="test-redis-host:1234",
        clients_info=[
            {
                "service_address": "test-service-address",
                "uuid": "test-uuid",
                "redis_address": "test-client-redis-host:1234",
                "data_path": "test-data-path",
            }
        ],
    )
    metrics_snapshots = [
        {"fit_start": "2022-02-02 02:02:02"},
        {"fit_start": "2022-02-02 02:02:02", "rounds": []},
        {"fit_start": "2022-02-02 02:02:02", "rounds": [], "fit_end": "2022-02-02 03:03:03"},
    ]
    mock_get_from_redis.side_effect = metrics_snapshots

    subscriber = Mock()
    subscriber.listen.return_value = [
        {"type": "message"},
        {"type": "not message"},
        {"type": "message"},
        {"type": "message"},
        {"type": "message"},
    ]
    mock_get_subscriber.return_value = subscriber

    db_client = make_mock_db_client()
    mock_motor_client.return_value = db_client

    status_patcher = patch.object(Job, "set_status", AsyncMock())
    metrics_patcher = patch.object(Job, "set_server_metrics", AsyncMock())
    with status_patcher as mock_set_status, metrics_patcher as mock_set_server_metrics:
        await server_training_listener(job)

    database = db_client[DATABASE_NAME]
    mock_set_status.assert_called_once_with(JobStatus.FINISHED_SUCCESSFULLY, database)
    assert mock_set_server_metrics.call_count == 3
    mock_set_server_metrics.assert_has_calls([call(snapshot, database) for snapshot in metrics_snapshots])
    assert mock_get_from_redis.call_count == 3
    mock_get_subscriber.assert_called_once_with(job.server_uuid, job.redis_address)
    db_client.close.assert_called()
@patch("florist.api.routes.server.training.AsyncIOMotorClient")
@patch("florist.api.routes.server.training.get_from_redis")
async def test_server_training_listener_already_finished(mock_get_from_redis: Mock, mock_motor_client: Mock) -> None:
    """If the first redis read already holds ``fit_end``, the listener stores it once and finishes."""
    job = Job(
        server_uuid="test-server-uuid",
        redis_address="test-redis-host:1234",
        clients_info=[
            {
                "service_address": "test-service-address",
                "uuid": "test-uuid",
                "redis_address": "test-client-redis-host:1234",
                "data_path": "test-data-path",
            }
        ],
    )
    final_metrics = {"fit_start": "2022-02-02 02:02:02", "rounds": [], "fit_end": "2022-02-02 03:03:03"}
    mock_get_from_redis.side_effect = [final_metrics]

    db_client = make_mock_db_client()
    mock_motor_client.return_value = db_client

    status_patcher = patch.object(Job, "set_status", AsyncMock())
    metrics_patcher = patch.object(Job, "set_server_metrics", AsyncMock())
    with status_patcher as mock_set_status, metrics_patcher as mock_set_server_metrics:
        await server_training_listener(job)

    database = db_client[DATABASE_NAME]
    mock_set_status.assert_called_once_with(JobStatus.FINISHED_SUCCESSFULLY, database)
    mock_set_server_metrics.assert_called_once_with(final_metrics, database)
    assert mock_get_from_redis.call_count == 1
    db_client.close.assert_called()
async def test_server_training_listener_fail_no_server_uuid() -> None:
    """``server_training_listener`` raises ``AssertionError`` for a job without ``server_uuid``."""
    test_job = Job(**{
        "redis_address": "test-redis-host:1234",
    })

    # ``match`` is a regex applied with re.search; escape it so the dots match literally.
    with raises(AssertionError, match=re.escape("job.server_uuid is None.")):
        await server_training_listener(test_job)
async def test_server_training_listener_fail_no_redis_address() -> None:
    """``server_training_listener`` raises ``AssertionError`` for a job without ``redis_address``."""
    test_job = Job(**{
        "server_uuid": "test-server-uuid",
    })

    # ``match`` is a regex applied with re.search; escape it so the dots match literally.
    with raises(AssertionError, match=re.escape("job.redis_address is None.")):
        await server_training_listener(test_job)
@patch("florist.api.routes.server.training.AsyncIOMotorClient")
@patch("florist.api.routes.server.training.get_from_redis")
@patch("florist.api.routes.server.training.get_subscriber")
async def test_client_training_listener(
    mock_get_subscriber: Mock,
    mock_get_from_redis: Mock,
    mock_motor_client: Mock,
) -> None:
    """The client listener stores each new metrics snapshot until one contains ``shutdown``."""
    client_uuid = "test-client-uuid"
    job = Job(
        clients_info=[
            {
                "service_address": "test-service-address",
                "uuid": client_uuid,
                "redis_address": "test-client-redis-host:1234",
                "data_path": "test-data-path",
            }
        ],
    )
    metrics_snapshots = [
        {"initialized": "2022-02-02 02:02:02"},
        {"initialized": "2022-02-02 02:02:02", "rounds": []},
        {"initialized": "2022-02-02 02:02:02", "rounds": [], "shutdown": "2022-02-02 03:03:03"},
    ]
    mock_get_from_redis.side_effect = metrics_snapshots

    subscriber = Mock()
    subscriber.listen.return_value = [
        {"type": "message"},
        {"type": "not message"},
        {"type": "message"},
        {"type": "message"},
        {"type": "message"},
    ]
    mock_get_subscriber.return_value = subscriber

    db_client = make_mock_db_client()
    mock_motor_client.return_value = db_client
    listened_client = job.clients_info[0]

    with patch.object(Job, "set_client_metrics", AsyncMock()) as mock_set_client_metrics:
        await client_training_listener(job, listened_client)

    database = db_client[DATABASE_NAME]
    assert mock_set_client_metrics.call_count == 3
    mock_set_client_metrics.assert_has_calls(
        [call(client_uuid, snapshot, database) for snapshot in metrics_snapshots]
    )
    assert mock_get_from_redis.call_count == 3
    mock_get_subscriber.assert_called_once_with(listened_client.uuid, listened_client.redis_address)
    db_client.close.assert_called()
@patch("florist.api.routes.server.training.AsyncIOMotorClient")
@patch("florist.api.routes.server.training.get_from_redis")
async def test_client_training_listener_already_finished(mock_get_from_redis: Mock, mock_motor_client: Mock) -> None:
    """If the first redis read already holds ``shutdown``, the client listener stores it once."""
    client_uuid = "test-client-uuid"
    job = Job(
        clients_info=[
            {
                "service_address": "test-service-address",
                "uuid": client_uuid,
                "redis_address": "test-client-redis-host:1234",
                "data_path": "test-data-path",
            }
        ],
    )
    final_metrics = {"initialized": "2022-02-02 02:02:02", "rounds": [], "shutdown": "2022-02-02 03:03:03"}
    mock_get_from_redis.side_effect = [final_metrics]

    db_client = make_mock_db_client()
    mock_motor_client.return_value = db_client

    with patch.object(Job, "set_client_metrics", AsyncMock()) as mock_set_client_metrics:
        await client_training_listener(job, job.clients_info[0])

    mock_set_client_metrics.assert_called_once_with(client_uuid, final_metrics, db_client[DATABASE_NAME])
    assert mock_get_from_redis.call_count == 1
    db_client.close.assert_called()
async def test_client_training_listener_fail_no_uuid() -> None:
    """``client_training_listener`` raises ``AssertionError`` for a client without ``uuid``."""
    test_job = Job(**{
        "clients_info": [
            {
                "redis_address": "test-client-redis-host:1234",
                "service_address": "test-service-address",
                "data_path": "test-data-path",
            },
        ],
    })

    # ``match`` is a regex applied with re.search; escape it so the dots match literally.
    with raises(AssertionError, match=re.escape("client_info.uuid is None.")):
        await client_training_listener(test_job, test_job.clients_info[0])
def _setup_test_job_and_mocks() -> Tuple[Dict[str, Any], Dict[str, Any], Mock, Mock]:
    """Build a valid FedAvg job document plus the mocks ``start`` reads it through.

    Returns:
        ``(server_config, job, job_collection, fastapi_request)``:
        - ``server_config``: the FedAvg config, also serialized into ``job["server_config"]``.
        - ``job``: the raw document that ``job_collection.find_one`` resolves to. It is the
          same dict object, so tests mutate it in place to build failure scenarios.
        - ``job_collection``: a ``Mock`` standing in for the jobs collection.
        - ``fastapi_request``: a ``Mock`` whose ``app.database`` and
          ``app.synchronous_database`` both map ``JOB_COLLECTION_NAME`` to ``job_collection``.
    """
    test_strategy = Strategy.FEDAVG
    test_server_config = _get_test_server_config(test_strategy)
    test_job = {
        "status": "NOT_STARTED",
        "model": Model.MNIST.value,
        "strategy": test_strategy.value,
        "optimizer": Optimizer.SGD.value,
        "server_address": "test-server-address",
        "server_config": json.dumps(test_server_config),
        "config_parser": "BASIC",
        "redis_address": "test-redis-host:1234",
        "server_uuid": "test-server-uuid",
        "server_metrics": "test-server-metrics",
        "client": Client.FEDAVG.value,
        "clients_info": [
            {
                "service_address": "test-service-address-1",
                "data_path": "test-data-path-1",
                "redis_address": "test-redis-host-1:12341",
                "uuid": "test-client-uuids-1",
                "metrics": "test-client-metrics-1",
            },
            {
                "service_address": "test-service-address-2",
                "data_path": "test-data-path-2",
                "redis_address": "test-redis-host-2:12342",
                "uuid": "test-client-uuids-2",
                "metrics": "test-client-metrics-2",
            },
        ],
    }

    # ``find_one`` hands back an already-resolved future so the result can be awaited.
    # NOTE(review): asyncio.Future() binds to the current event loop; every caller in this
    # module is an ``async def`` test, so a loop is running when this executes.
    mock_find_one = asyncio.Future()
    mock_find_one.set_result(test_job)
    mock_job_collection = Mock()
    mock_job_collection.find_one.return_value = mock_find_one

    mock_fastapi_request = Mock()
    mock_fastapi_request.app.database = {JOB_COLLECTION_NAME: mock_job_collection}
    mock_fastapi_request.app.synchronous_database = {JOB_COLLECTION_NAME: mock_job_collection}

    return test_server_config, test_job, mock_job_collection, mock_fastapi_request


def _get_test_server_config(strategy: Strategy) -> Dict[str, Any]:
    """Return a server config for ``strategy`` suitable for the tests in this module.

    Raises:
        ValueError: if no test config has been defined for ``strategy`` yet.
    """
    if strategy == Strategy.FEDAVG:
        return {
            "n_server_rounds": 2,
            "batch_size": 8,
            "local_epochs": 1,
        }
    if strategy == Strategy.FEDPROX:
        return {
            "n_server_rounds": 123,
            "batch_size": 456,
            "local_epochs": 789,
            "adapt_proximal_weight": True,
            "initial_proximal_weight": 0.0,
            "proximal_weight_delta": 0.1,
            "proximal_weight_patience": 5,
        }
    # The two sentences previously ran together ("tests.Please"); keep the separating space.
    raise ValueError(
        f"Strategy {strategy.value} not yet supported in tests. "
        "Please add the strategy's server config to _get_test_server_config function."
    )
def make_mock_db_client() -> Mock:
    """Create a stand-in for the Motor client used by the training listeners.

    Indexing the returned mock with ``DATABASE_NAME`` always yields one shared database
    mock, so tests can compare against ``client[DATABASE_NAME]``. Any other key yields
    ``None``.
    """
    shared_database = Mock()

    def _select_database(name):
        # Only the configured database name resolves to the shared mock.
        return shared_database if name == DATABASE_NAME else None

    client = Mock()
    client.__getitem__ = Mock(side_effect=_select_database)
    return client