# Source code for florist.tests.unit.api.servers.test_launch

from unittest.mock import ANY, Mock, patch

from florist.api.monitoring.metrics import RedisMetricsReporter
from florist.api.servers.launch import launch_local_server
from florist.api.servers.models import Model, ServerFactory, get_fedavg_server


@patch("florist.api.servers.launch.launch_server")
@patch("florist.api.servers.launch.uuid")
def test_launch_local_server(mock_uuid: Mock, mock_launch_server: Mock) -> None:
    """Check that ``launch_local_server`` forwards the expected arguments to ``launch_server``.

    ``launch_server`` and ``uuid`` are patched inside ``florist.api.servers.launch``,
    so no real server process is started by this test and the generated run id is
    a fixed, known value.

    Args:
        mock_uuid: Mock replacing the ``uuid`` name in the launch module.
        mock_launch_server: Mock replacing ``launch_server`` in the launch module.
    """
    # --- Arrange -----------------------------------------------------------
    test_model = Model.MNIST_FEDAVG
    test_n_clients = 2
    test_server_address = "test-server-address"
    test_server_config = {
        "n_server_rounds": 5,
        "batch_size": 8,
        "local_epochs": 1,
    }
    test_server_factory = ServerFactory(get_server_function=get_fedavg_server, model=test_model)
    test_redis_host = "test-redis-host"
    test_redis_port = "test-redis-port"

    test_server_process = "test-server-process"
    mock_launch_server.return_value = test_server_process
    test_server_uuid = "test-server-uuid"
    mock_uuid.uuid4.return_value = test_server_uuid

    # --- Act ---------------------------------------------------------------
    server_uuid, server_process, log_file_path = launch_local_server(
        test_server_factory,
        test_server_config,
        test_server_address,
        test_n_clients,
        test_redis_host,
        test_redis_port,
    )

    # --- Assert ------------------------------------------------------------
    assert server_uuid is not None
    assert server_process == test_server_process

    mock_launch_server.assert_called_once()
    call_args = mock_launch_server.call_args_list[0][0]
    call_kwargs = mock_launch_server.call_args_list[0][1]

    # The first positional argument is the server constructor; it is matched
    # with ANY here and inspected piece by piece further down.
    assert call_args == (
        ANY,
        test_server_address,
        test_server_config["n_server_rounds"],
        log_file_path,
    )
    assert call_kwargs == {"seconds_to_sleep": 0}

    # Build the constructor the factory yields for the same inputs and compare
    # its function and bound arguments against what was handed to launch_server.
    expected_server_constructor = test_server_factory.get_server_constructor(
        test_n_clients,
        [RedisMetricsReporter(host=test_redis_host, port=test_redis_port, run_id=test_server_uuid)],
        test_server_config,
    )
    assert call_args[0].func == expected_server_constructor.func
    assert call_args[0].args == (
        ANY,  # model can't be compared with __eq__, so put ANY here and compare later
        expected_server_constructor.args[1],
        expected_server_constructor.args[2],
        expected_server_constructor.args[3],
    )
    assert isinstance(call_args[0].args[0], expected_server_constructor.args[0].__class__)