Source code for florist.api.servers.launch

"""Functions and definitions to launch local servers."""

import uuid
from functools import partial
from multiprocessing import Process
from typing import Tuple

from torch import nn

from florist.api.launchers.local import launch_server
from florist.api.monitoring.logs import get_server_log_file_path
from florist.api.monitoring.metrics import RedisMetricsReporter
from florist.api.servers.utils import get_server


[docs] def launch_local_server( model: nn.Module, n_clients: int, server_address: str, n_server_rounds: int, batch_size: int, local_epochs: int, redis_host: str, redis_port: str, ) -> Tuple[str, Process]: """ Launch a FL server locally. :param model: (torch.nn.Module) The model to be used by the server. Should match the clients' model. :param n_clients: (int) The number of clients that will report to this server. :param server_address: (str) The address the server should start at. :param n_server_rounds: (int) The number of rounds the training should run for. :param batch_size: (int) The size of the batch for training. :param local_epochs: (int) The number of epochs to run by the clients. :param redis_host: (str) the host name for the Redis instance for metrics reporting. :param redis_port: (str) the port for the Redis instance for metrics reporting. :return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull metrics from Redis, along with its local process object. """ server_uuid = str(uuid.uuid4()) metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=server_uuid) server_constructor = partial( get_server, reporters=[metrics_reporter], model=model, n_clients=n_clients, batch_size=batch_size, local_epochs=local_epochs, ) log_file_name = str(get_server_log_file_path(server_uuid)) server_process = launch_server( server_constructor, server_address, n_server_rounds, log_file_name, seconds_to_sleep=0, ) return server_uuid, server_process