"""Tests for FLorist's client FastAPI endpoints."""
import json
import os
import signal
from unittest.mock import ANY, Mock, patch
import pytest
from fl4health.utils.metrics import Accuracy
from florist.api import client
from florist.api.clients.common import Client
from florist.api.clients.mnist import MnistClient
from florist.api.db.client_entities import ClientDAO
from florist.api.monitoring.logs import get_client_log_file_path
from florist.api.monitoring.metrics import RedisMetricsReporter
[docs]
@pytest.fixture(autouse=True)
async def mock_client_db() -> None:
test_sqlite_db_path = "florist/tests/unit/api/client.db"
print(f"Creating test detabase '{test_sqlite_db_path}'")
real_db_path = ClientDAO.db_path
ClientDAO.db_path = test_sqlite_db_path
yield
ClientDAO.db_path = real_db_path
if os.path.exists(test_sqlite_db_path):
print(f"Deleting test detabase '{test_sqlite_db_path}'")
os.remove(test_sqlite_db_path)
[docs]
def test_connect() -> None:
"""Tests the client's connect endpoint."""
response = client.connect()
assert response.status_code == 200
json_body = json.loads(response.body.decode())
assert json_body == {"status": "ok"}
[docs]
@patch("florist.api.client.launch_client")
def test_start_success(mock_launch_client: Mock) -> None:
test_server_address = "test-server-address"
test_client = Client.MNIST
test_data_path = "test/data/path"
test_redis_host = "test-redis-host"
test_redis_port = "test-redis-port"
test_client_pid = 1234
mock_client_process = Mock()
mock_client_process.pid = test_client_pid
mock_launch_client.return_value = mock_client_process
response = client.start(test_server_address, test_client, test_data_path, test_redis_host, test_redis_port)
assert response.status_code == 200
json_body = json.loads(response.body.decode())
log_file_path = str(get_client_log_file_path(json_body["uuid"]))
assert json_body == {"uuid": ANY}
mock_launch_client.assert_called_once_with(ANY, test_server_address, log_file_path)
client_obj = mock_launch_client.call_args_list[0][0][0]
assert isinstance(client_obj, MnistClient)
assert str(client_obj.data_path) == test_data_path
assert len(client_obj.metrics) == 1
assert isinstance(client_obj.metrics[0], Accuracy)
metrics_reporter = client_obj.reports_manager.reporters[0]
assert isinstance(metrics_reporter, RedisMetricsReporter)
assert metrics_reporter.host == test_redis_host
assert metrics_reporter.port == test_redis_port
assert metrics_reporter.run_id == json_body["uuid"]
client_dao = ClientDAO.find(uuid=json_body["uuid"])
assert client_dao.pid == test_client_pid
assert client_dao.log_file_path == log_file_path
[docs]
@patch("florist.api.client.launch_client", side_effect=Exception("test exception"))
def test_start_fail_exception(_: Mock) -> None:
test_server_address = "test-server-address"
test_client = Client.MNIST
test_data_path = "test/data/path"
test_redis_host = "test-redis-host"
test_redis_port = "test-redis-port"
response = client.start(test_server_address, test_client, test_data_path, test_redis_host, test_redis_port)
assert response.status_code == 500
json_body = json.loads(response.body.decode())
assert json_body == {"error": "test exception"}
[docs]
@patch("florist.api.monitoring.metrics.redis")
def test_check_status(mock_redis: Mock) -> None:
mock_redis_connection = Mock()
mock_redis_connection.get.return_value = b"{\"info\": \"test\"}"
test_uuid = "test_uuid"
test_redis_host = "localhost"
test_redis_port = "testport"
mock_redis.Redis.return_value = mock_redis_connection
response = client.check_status(test_uuid, test_redis_host, test_redis_port)
mock_redis.Redis.assert_called_with(host=test_redis_host, port=test_redis_port)
assert json.loads(response.body.decode()) == {"info": "test"}
[docs]
@patch("florist.api.monitoring.metrics.redis")
def test_check_status_not_found(mock_redis: Mock) -> None:
mock_redis_connection = Mock()
mock_redis_connection.get.return_value = None
test_uuid = "test_uuid"
test_redis_host = "localhost"
test_redis_port = "testport"
mock_redis.Redis.return_value = mock_redis_connection
response = client.check_status(test_uuid, test_redis_host, test_redis_port)
mock_redis.Redis.assert_called_with(host=test_redis_host, port=test_redis_port)
assert response.status_code == 404
assert json.loads(response.body.decode()) == {"error": f"Client {test_uuid} Not Found"}
[docs]
@patch("florist.api.monitoring.metrics.redis.Redis", side_effect=Exception("test exception"))
def test_check_status_fail_exception(_: Mock) -> None:
test_uuid = "test_uuid"
test_redis_host = "localhost"
test_redis_port = "testport"
response = client.check_status(test_uuid, test_redis_host, test_redis_port)
assert response.status_code == 500
assert json.loads(response.body.decode()) == {"error": "test exception"}
[docs]
def test_get_log() -> None:
test_client_uuid = "test-client-uuid"
test_log_file_content = "this is a test log file content"
test_log_file_path = str(get_client_log_file_path(test_client_uuid))
with open(test_log_file_path, "w") as f:
f.write(test_log_file_content)
client_dao = ClientDAO(uuid=test_client_uuid, log_file_path=test_log_file_path)
client_dao.save()
response = client.get_log(test_client_uuid)
assert response.status_code == 200
assert response.body.decode() == f"\"{test_log_file_content}\""
os.remove(test_log_file_path)
[docs]
def test_get_log_no_log_file_path() -> None:
test_client_uuid = "test-client-uuid"
client_dao = ClientDAO(uuid=test_client_uuid)
client_dao.save()
response = client.get_log(test_client_uuid)
assert response.status_code == 400
assert json.loads(response.body.decode()) == {"error": "Client log file path is None or empty"}
[docs]
@patch("florist.api.client.ClientDAO")
def test_get_log_exception(mock_client_dao) -> None:
test_client_uuid = "test-client-uuid"
test_exception_message = "test-exception-message"
mock_client_dao.find.side_effect = Exception(test_exception_message)
response = client.get_log(test_client_uuid)
assert response.status_code == 500
assert json.loads(response.body.decode()) == {"error": test_exception_message}
[docs]
@patch("florist.api.client.os.kill")
def test_stop_success(mock_kill: Mock) -> None:
test_client_uuid = "test-client-uuid"
test_pid = 1234
client_dao = ClientDAO(uuid=test_client_uuid, pid=test_pid)
client_dao.save()
response = client.stop(test_client_uuid)
assert response.status_code == 200
assert json.loads(response.body.decode()) == {"status": "success"}
mock_kill.assert_called_once_with(test_pid, signal.SIGTERM)
[docs]
def test_stop_fail_no_uuid() -> None:
response = client.stop("")
assert response.status_code == 400
assert json.loads(response.body.decode()) == {"error": "UUID is empty or None."}
[docs]
def test_stop_fail_not_found() -> None:
test_uuid = "inexistant-uuid"
client_dao = ClientDAO(uuid="test-client-uuid", pid=1234)
client_dao.save()
response = client.stop(test_uuid)
assert response.status_code == 500
assert json.loads(response.body.decode()) == {"error": f"Client with uuid '{test_uuid}' not found."}
[docs]
@patch("florist.api.client.os.kill")
def test_stop_fail_exception(mock_kill: Mock) -> None:
test_client_uuid = "test-client-uuid"
test_pid = 1234
test_exception_message = "test-exception-message"
mock_kill.side_effect = Exception(test_exception_message)
client_dao = ClientDAO(uuid=test_client_uuid, pid=test_pid)
client_dao.save()
response = client.stop(test_client_uuid)
assert response.status_code == 500
assert json.loads(response.body.decode()) == {"error": test_exception_message}
mock_kill.assert_called_once_with(test_pid, signal.SIGTERM)