Source code for florist.tests.unit.api.servers.test_models

from unittest.mock import ANY

from fl4health.client_managers.base_sampling_manager import SimpleClientManager
from fl4health.server.adaptive_constraint_servers.fedprox_server import FedProxServer
from fl4health.server.base_server import FlServer
from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from flwr.server.strategy import FedAvg
from flwr.common.typing import Parameters

from florist.api.models.mnist import MnistNet
from florist.api.monitoring.metrics import RedisMetricsReporter
from florist.api.servers.config_parsers import ConfigParser
from florist.api.servers.models import Model, ServerFactory, fit_config_function, get_fedprox_server, get_fedavg_server


def test_class_for_model():
    """Both MNIST model variants should resolve to the ``MnistNet`` class."""
    for mnist_variant in (Model.MNIST_FEDAVG, Model.MNIST_FEDPROX):
        assert Model.class_for_model(mnist_variant) == MnistNet
def test_config_parser_for_model():
    """Each model should resolve to its expected ``ConfigParser`` member."""
    expected_parsers = [
        (Model.MNIST_FEDAVG, ConfigParser.BASIC),
        (Model.MNIST_FEDPROX, ConfigParser.FEDPROX),
    ]
    for model_entry, parser in expected_parsers:
        assert Model.config_parser_for_model(model_entry) == parser
def test_server_factory_for_model():
    """The factory returned for each model should carry that model and its server-building function."""
    cases = [
        (Model.MNIST_FEDAVG, get_fedavg_server),
        (Model.MNIST_FEDPROX, get_fedprox_server),
    ]
    for model_entry, expected_builder in cases:
        factory = Model.server_factory_for_model(model_entry)
        assert factory.get_server_function == expected_builder
        assert factory.model == model_entry
def test_list():
    """``Model.list()`` should return the ``value`` of each model: FedAvg first, then FedProx."""
    expected_values = [member.value for member in (Model.MNIST_FEDAVG, Model.MNIST_FEDPROX)]
    assert Model.list() == expected_values
def test_get_server_constructor():
    """``ServerFactory.get_server_constructor`` should bind a model instance plus the caller's arguments.

    The returned object is inspected through ``func`` and ``args``, i.e. it is
    expected to be ``functools.partial``-like.
    """
    n_clients = 2
    reporters = [RedisMetricsReporter(host="localhost", port="8080")]
    server_config = {"test": 123}
    model_entry = Model.MNIST_FEDAVG
    factory = ServerFactory(get_server_function=get_fedavg_server, model=model_entry)

    constructor = factory.get_server_constructor(
        n_clients,
        reporters,
        server_config,
    )

    assert constructor.func == get_fedavg_server
    # The first bound argument is a freshly constructed model; only its type can be checked.
    assert isinstance(constructor.args[0], Model.class_for_model(model_entry))
    assert constructor.args == (ANY, n_clients, reporters, server_config)
def test_fit_config_function():
    """``fit_config_function`` should return the server config extended with ``current_server_round``."""
    server_round = 2
    expected = {"test": 123, "current_server_round": server_round}
    assert fit_config_function({"test": 123}, server_round) == expected
def test_get_fedavg_server():
    """``get_fedavg_server`` should build an ``FlServer`` with a FedAvg strategy wired to the inputs."""
    test_n_clients = 2
    test_reporters = [RedisMetricsReporter(host="localhost", port="8080")]
    test_server_config = {"test": 123}
    test_model = MnistNet()

    result = get_fedavg_server(test_model, test_n_clients, test_reporters, test_server_config)

    assert isinstance(result, FlServer)
    assert isinstance(result.strategy, FedAvg)
    assert result.strategy.min_fit_clients == test_n_clients
    assert result.strategy.min_evaluate_clients == test_n_clients
    # Parity with test_get_fedprox_server, which also checks min_available_clients.
    # NOTE(review): Flower's FedAvg appears to default min_available_clients to 2, so with
    # test_n_clients == 2 this cannot tell an explicit setting from the default — consider
    # a different client count once get_fedavg_server is confirmed to pass it through.
    assert result.strategy.min_available_clients == test_n_clients
    # The fit and evaluate config functions are both expected to be fit_config_function
    # with the server config bound as its first argument.
    assert result.strategy.on_fit_config_fn.func == fit_config_function
    assert result.strategy.on_fit_config_fn.args[0] == test_server_config
    assert result.strategy.on_evaluate_config_fn.func == fit_config_function
    assert result.strategy.on_evaluate_config_fn.args[0] == test_server_config
    assert result.strategy.fit_metrics_aggregation_fn == fit_metrics_aggregation_fn
    assert result.strategy.evaluate_metrics_aggregation_fn == evaluate_metrics_aggregation_fn
    assert isinstance(result.strategy.initial_parameters, Parameters)
    assert isinstance(result._client_manager, SimpleClientManager)
    assert result.reports_manager.reporters == test_reporters
def test_get_fedprox_server():
    """``get_fedprox_server`` should build a ``FedProxServer`` whose adaptive strategy mirrors the config."""
    n_clients = 2
    reporters = [RedisMetricsReporter(host="localhost", port="8080")]
    server_config = {
        "adapt_proximal_weight": True,
        "initial_proximal_weight": 0.0,
        "proximal_weight_delta": 0.1,
        "proximal_weight_patience": 5,
    }

    server = get_fedprox_server(MnistNet(), n_clients, reporters, server_config)

    assert isinstance(server, FedProxServer)
    strategy = server.strategy
    assert isinstance(strategy, FedAvgWithAdaptiveConstraint)

    for client_count_attr in ("min_fit_clients", "min_evaluate_clients", "min_available_clients"):
        assert getattr(strategy, client_count_attr) == n_clients

    # Both config callbacks should be fit_config_function with the server config bound first.
    for config_fn_attr in ("on_fit_config_fn", "on_evaluate_config_fn"):
        config_fn = getattr(strategy, config_fn_attr)
        assert config_fn.func == fit_config_function
        assert config_fn.args[0] == server_config

    assert strategy.fit_metrics_aggregation_fn == fit_metrics_aggregation_fn
    assert strategy.evaluate_metrics_aggregation_fn == evaluate_metrics_aggregation_fn

    # Map each strategy loss-weight attribute to the config key it should be read from.
    loss_weight_sources = [
        ("adapt_loss_weight", "adapt_proximal_weight"),
        ("loss_weight", "initial_proximal_weight"),
        ("loss_weight_delta", "proximal_weight_delta"),
        ("loss_weight_patience", "proximal_weight_patience"),
    ]
    for strategy_attr, config_key in loss_weight_sources:
        assert getattr(strategy, strategy_attr) == server_config[config_key]

    assert isinstance(strategy.initial_parameters, Parameters)
    assert isinstance(server._client_manager, SimpleClientManager)
    assert server.reports_manager.reporters == reporters