"""Integration test for ``florist.api.launchers.local.launch`` (module florist.tests.integration.api.launchers.test_launch)."""
import os
import re
import tempfile
from functools import partial
from pathlib import Path
import torch
from florist.api.launchers.local import launch
from florist.api.clients.mnist import MnistClient, MnistNet
from florist.api.servers.utils import get_server


def assert_string_in_file(file_path: str, search_string: str) -> None:
    """Assert that ``file_path`` contains ``search_string`` verbatim.

    The search is a plain substring check, not a regular expression. Callers
    pass literal markers such as ``"[SUMMARY]"``; as a regex that would be a
    character class matching any single one of its letters, so almost any
    file would pass.

    Args:
        file_path: Path of the text file to inspect.
        search_string: Literal text that must appear somewhere in the file.

    Raises:
        AssertionError: If ``search_string`` does not occur in the file.
    """
    with open(file_path, "r") as f:
        file_contents = f.read()
    assert search_string in file_contents, f"{search_string!r} not found in {file_path}"


def test_launch() -> None:
    """End-to-end check of ``launch`` with two MNIST clients on the CPU.

    Everything the run touches (per-client data folders, the server output
    base path, the client output base path) lives under a temporary
    directory. After ``launch`` returns, the server's ``<server_path>.out``
    file must contain the literal marker ``[SUMMARY]``.
    """
    client_count = 2
    round_count = 2
    # NOTE(review): fixed port; another process already listening on 8080
    # (e.g. a parallel test run) could make this test flaky.
    address = "0.0.0.0:8080"

    with tempfile.TemporaryDirectory() as workdir:
        # One data folder per client: <workdir>/0, <workdir>/1, ...
        data_dirs = [Path(f"{workdir}/{idx}") for idx in range(client_count)]
        for data_dir in data_dirs:
            os.mkdir(data_dir)

        # All folders exist before any client is constructed.
        clients = []
        for data_dir in data_dirs:
            clients.append(MnistClient(data_dir, [], torch.device("cpu")))

        # Pre-bind get_server's first two positional arguments; launch() is
        # presumably responsible for supplying any remaining ones — TODO confirm.
        make_server = partial(get_server, MnistNet(), [])
        server_out = os.path.join(workdir, "server")
        client_out = f"{workdir}/client"

        launch(
            make_server,
            address,
            round_count,
            clients,
            server_out,
            client_out,
        )

        # Check while still inside the with-block: the temp dir (and the log
        # file) are removed as soon as the context exits.
        assert_string_in_file(f"{server_out}.out", "[SUMMARY]")