"""Tests for FLorist's client FastAPI endpoints."""
import json
from unittest.mock import ANY, Mock, patch
from florist.api import client
from florist.api.clients.mnist import MnistClient
from florist.api.monitoring.logs import get_client_log_file_path
from florist.api.monitoring.metrics import RedisMetricsReporter
[docs]
def test_connect() -> None:
    """Check that the client's connect endpoint replies 200 with an ``{"status": "ok"}`` payload."""
    connect_response = client.connect()
    assert connect_response.status_code == 200
    assert json.loads(connect_response.body.decode()) == {"status": "ok"}
[docs]
@patch("florist.api.client.launch_client")
def test_start_success(mock_launch_client: Mock) -> None:
    """Starting an MNIST client launches it with a Redis metrics reporter tied to the returned uuid.

    ``launch_client`` is mocked, so the test only inspects what the endpoint would have launched:
    the client object, the server address and the per-run log file path.
    """
    server_address = "test-server-address"
    client_name = "MNIST"
    data_path = "test/data/path"
    redis_host = "test-redis-host"
    redis_port = "test-redis-port"

    response = client.start(server_address, client_name, data_path, redis_host, redis_port)

    assert response.status_code == 200
    body = json.loads(response.body.decode())
    assert body == {"uuid": ANY}
    run_uuid = body["uuid"]

    # Expected call shape: launch_client(<client object>, server_address, log_file_name).
    expected_log_file = str(get_client_log_file_path(run_uuid))
    mock_launch_client.assert_called_once_with(ANY, server_address, expected_log_file)

    # Exactly one call was made (asserted above), so call_args is that call;
    # its first positional argument is the constructed client.
    launched_client = mock_launch_client.call_args[0][0]
    assert isinstance(launched_client, MnistClient)
    assert str(launched_client.data_path) == data_path

    reporter = launched_client.reports_manager.reporters[0]
    assert isinstance(reporter, RedisMetricsReporter)
    assert reporter.host == redis_host
    assert reporter.port == redis_port
    assert reporter.run_id == run_uuid
[docs]
def test_start_fail_unsupported_client() -> None:
    """An unknown client name is rejected with a 400 and an error message naming that client."""
    unsupported_client = "WRONG"
    response = client.start(
        "test-server-address",
        unsupported_client,
        "test/data/path",
        "test-redis-host",
        "test-redis-port",
    )

    assert response.status_code == 400
    payload = json.loads(response.body.decode())
    assert payload == {"error": ANY}
    assert f"Client '{unsupported_client}' not supported" in payload["error"]
[docs]
@patch("florist.api.client.launch_client", side_effect=Exception("test exception"))
def test_start_fail_exception(mock_launch_client: Mock) -> None:
    """If launching the client raises, the start endpoint answers 500 with the exception text."""
    start_args = (
        "test-server-address",  # server address
        "MNIST",  # client name
        "test/data/path",  # data path
        "test-redis-host",  # redis host
        "test-redis-port",  # redis port
    )

    response = client.start(*start_args)

    assert response.status_code == 500
    assert json.loads(response.body.decode()) == {"error": "test exception"}
[docs]
@patch("florist.api.monitoring.metrics.redis")
def test_check_status(mock_redis: Mock) -> None:
    """Tests the client's check_status endpoint when Redis holds a status for the given uuid.

    The mocked Redis connection returns a JSON-encoded payload. The endpoint is expected to
    connect with the given host/port and echo that payload back as a successful (200) response.
    """
    test_uuid = "test_uuid"
    test_redis_host = "localhost"
    test_redis_port = "testport"

    # Redis connection mock whose ``get`` returns JSON bytes for the status entry.
    mock_redis_connection = Mock()
    mock_redis_connection.get.return_value = b'{"info": "test"}'
    mock_redis.Redis.return_value = mock_redis_connection

    response = client.check_status(test_uuid, test_redis_host, test_redis_port)

    mock_redis.Redis.assert_called_with(host=test_redis_host, port=test_redis_port)
    # Pin the success status code too; the not-found/exception tests pin 404/500, and without
    # this a regression that returned an error code with the right body would go unnoticed.
    assert response.status_code == 200
    assert json.loads(response.body.decode()) == {"info": "test"}
[docs]
@patch("florist.api.monitoring.metrics.redis")
def test_check_status_not_found(mock_redis: Mock) -> None:
    """A uuid with no entry in Redis yields a 404 with a "Client <uuid> Not Found" error payload."""
    uuid, host, port = "test_uuid", "localhost", "testport"

    connection = Mock()
    connection.get.return_value = None  # nothing stored under this uuid
    mock_redis.Redis.return_value = connection

    response = client.check_status(uuid, host, port)

    mock_redis.Redis.assert_called_with(host=host, port=port)
    assert response.status_code == 404
    expected_payload = {"error": f"Client {uuid} Not Found"}
    assert json.loads(response.body.decode()) == expected_payload
[docs]
@patch("florist.api.monitoring.metrics.redis.Redis", side_effect=Exception("test exception"))
def test_check_status_fail_exception(mock_redis: Mock) -> None:
    """If constructing the Redis client raises, check_status answers 500 with the exception text."""
    response = client.check_status("test_uuid", "localhost", "testport")

    assert response.status_code == 500
    payload = json.loads(response.body.decode())
    assert payload == {"error": "test exception"}