Source code for florist.tests.unit.api.servers.test_launch

from unittest.mock import ANY, Mock, patch

from florist.api.clients.mnist import MnistNet
from florist.api.monitoring.logs import get_server_log_file_path
from florist.api.monitoring.metrics import RedisMetricsReporter
from florist.api.servers.launch import launch_local_server
from florist.api.servers.utils import get_server


@patch("florist.api.servers.launch.launch_server")
def test_launch_local_server(mock_launch_server: Mock) -> None:
    """Check that ``launch_local_server`` hands its configuration to ``launch_server``.

    ``launch_server`` is patched, so no real server is started. The test checks:

    * the returned ``(uuid, process)`` pair;
    * the positional and keyword arguments passed to ``launch_server``;
    * the server factory passed as the first positional argument, which wraps
      ``get_server`` and carries the model/client/training keywords;
    * the Redis metrics reporter bound into that factory.
    """
    model = MnistNet()
    n_clients = 2
    server_address = "test-server-address"
    n_server_rounds = 5
    batch_size = 8
    local_epochs = 1
    redis_host = "test-redis-host"
    redis_port = "test-redis-port"
    fake_process = "test-server-process"

    # Whatever the patched launch_server returns should come back as the process.
    mock_launch_server.return_value = fake_process

    server_uuid, server_process = launch_local_server(
        model,
        n_clients,
        server_address,
        n_server_rounds,
        batch_size,
        local_epochs,
        redis_host,
        redis_port,
    )

    assert server_uuid is not None
    assert server_process == fake_process

    mock_launch_server.assert_called_once()
    # Exactly one call was recorded, so call_args is that single call.
    recorded_call = mock_launch_server.call_args
    positional = recorded_call.args
    keyword = recorded_call.kwargs

    # The factory in slot 0 is matched loosely here and inspected in detail below.
    assert positional == (
        ANY,
        server_address,
        n_server_rounds,
        str(get_server_log_file_path(server_uuid)),
    )
    assert keyword == {"seconds_to_sleep": 0}

    # The factory exposes .func/.keywords, presumably a functools.partial of get_server.
    server_factory = positional[0]
    assert server_factory.func == get_server
    assert server_factory.keywords == {
        "model": model,
        "n_clients": n_clients,
        "batch_size": batch_size,
        "local_epochs": local_epochs,
        "reporters": ANY,
    }

    # The reporter must report to the given Redis instance under the server's UUID.
    reporter = server_factory.keywords["reporters"][0]
    assert isinstance(reporter, RedisMetricsReporter)
    assert reporter.host == redis_host
    assert reporter.port == redis_port
    assert reporter.run_id == server_uuid