Source code for florist.tests.unit.api.clients.test_clients

from unittest.mock import Mock, patch, ANY

from florist.api.models.mnist import MnistNet
from florist.api.clients.clients import Client, LocalDataClient, FedProxLocalDataClient
from florist.api.clients.optimizers import Optimizer
from florist.api.servers.strategies import Strategy


def test_get_client_class():
    """Check that every ``Client`` member resolves to the expected concrete client class."""
    expected_pairs = (
        (Client.FEDAVG, LocalDataClient),
        (Client.FEDPROX, FedProxLocalDataClient),
    )
    for client, client_class in expected_pairs:
        assert client.get_client_class() == client_class
def test_list():
    """Check that ``Client.list()`` returns the FedAvg and FedProx values, in that order."""
    expected_values = [Client.FEDAVG.value, Client.FEDPROX.value]
    assert Client.list() == expected_values
def test_list_by_strategy():
    """Check that ``Client.list_by_strategy`` returns exactly the client value matching each strategy."""
    cases = (
        (Strategy.FEDAVG, [Client.FEDAVG.value]),
        (Strategy.FEDPROX, [Client.FEDPROX.value]),
    )
    for strategy, expected in cases:
        assert Client.list_by_strategy(strategy) == expected
@patch("florist.api.models.mnist.load_mnist_data")
def test_local_data_model_get_data_loaders(mock_load_mnist_data: Mock):
    """Check that ``LocalDataClient.get_data_loaders`` delegates to ``load_mnist_data``.

    The patched loader returns a 3-tuple. The client must hand back only the
    train and validation loaders. The loader must have been called with the
    data path, the configured batch size, and ``None`` as its third argument.
    """
    data_path = "test-data-path"
    config = {"batch_size": 200}
    expected_train_loader = "test-train-loader"
    expected_val_loader = "test-val-loader"

    client = LocalDataClient(data_path=data_path, metrics=[], device="cpu")
    client.set_model(MnistNet())
    # Configure the fake loader only after the client and model are set up.
    mock_load_mnist_data.return_value = (expected_train_loader, expected_val_loader, {})

    train_loader, val_loader = client.get_data_loaders(config=config)

    assert (train_loader, val_loader) == (expected_train_loader, expected_val_loader)
    mock_load_mnist_data.assert_called_with(data_path, config["batch_size"], None)
@patch("florist.api.clients.optimizers.torch")
def test_local_data_model_get_optimizer_type(mock_torch: Mock):
    """Check that ``get_optimizer`` builds an SGD optimizer once ``Optimizer.SGD`` is selected.

    ``torch`` is patched inside ``florist.api.clients.optimizers``, so the call to
    ``torch.optim.SGD`` can be observed. The model's ``parameters`` method is
    replaced so the exact object passed to the optimizer can be asserted.
    """
    sentinel_optimizer = "test-optimizer"
    sentinel_parameters = "test-parameters"
    mock_torch.optim.SGD.return_value = sentinel_optimizer

    client = LocalDataClient(data_path="test-data-path", metrics=[], device="cpu")
    client.set_optimizer_type(Optimizer.SGD)

    model = MnistNet()
    model.parameters = Mock(return_value=sentinel_parameters)
    client.set_model(model)

    assert client.get_optimizer(config={}) == sentinel_optimizer
    mock_torch.optim.SGD.assert_called_with(sentinel_parameters, lr=0.001, momentum=0.9)
@patch("florist.api.models.mnist.torch")
def test_local_data_model_get_criterion(mock_torch: Mock):
    """Check that ``get_criterion`` returns the object produced by the patched ``torch.nn.CrossEntropyLoss``."""
    expected_criterion = "test-criterion"
    mock_torch.nn.CrossEntropyLoss.return_value = expected_criterion

    client = LocalDataClient(data_path="test-data-path", metrics=[], device="cpu")
    client.set_model(MnistNet())

    assert client.get_criterion(config={}) == expected_criterion
@patch("florist.api.models.mnist.load_mnist_data")
@patch("florist.api.clients.clients.DirichletLabelBasedSampler")
def test_fedprox_local_data_model_get_data_loaders(mock_sampler: Mock, mock_load_mnist_data: Mock):
    """Check that ``FedProxLocalDataClient.get_data_loaders`` builds a Dirichlet sampler and loads data.

    Stacked ``@patch`` decorators inject mocks bottom-up, so the sampler mock
    arrives first and the loader mock second.
    """
    data_path = "test-data-path"
    config = {"batch_size": 200}
    expected_train_loader = "test-train-loader"
    expected_val_loader = "test-val-loader"

    client = FedProxLocalDataClient(data_path=data_path, metrics=[], device="cpu")
    client.set_model(MnistNet())
    # Configure the fake loader only after the client and model are set up.
    mock_load_mnist_data.return_value = (expected_train_loader, expected_val_loader, {})

    train_loader, val_loader = client.get_data_loaders(config=config)

    assert (train_loader, val_loader) == (expected_train_loader, expected_val_loader)
    # The third argument is only checked loosely (ANY); the sampler's construction is pinned below.
    mock_load_mnist_data.assert_called_with(data_path, config["batch_size"], ANY)
    # Labels 0-9 — presumably the ten MNIST classes; confirm against the client implementation.
    mock_sampler.assert_called_with(list(range(10)), sample_percentage=0.75, beta=1)