from unittest.mock import Mock, patch, ANY
from florist.api.models.mnist import MnistNet
from florist.api.clients.clients import Client, LocalDataClient, FedProxLocalDataClient
from florist.api.clients.optimizers import Optimizer
from florist.api.servers.strategies import Strategy


def test_get_client_class():
    """Each Client enum member should resolve to its concrete client class."""
    expected_mapping = (
        (Client.FEDAVG, LocalDataClient),
        (Client.FEDPROX, FedProxLocalDataClient),
    )
    for client_member, client_class in expected_mapping:
        assert client_member.get_client_class() == client_class


def test_list():
    """Client.list() should return the values of FEDAVG and FEDPROX, in that order."""
    expected_values = [member.value for member in (Client.FEDAVG, Client.FEDPROX)]
    assert Client.list() == expected_values


def test_list_by_strategy():
    """Filtering clients by strategy should yield only the matching client's value."""
    strategy_to_client = (
        (Strategy.FEDAVG, Client.FEDAVG),
        (Strategy.FEDPROX, Client.FEDPROX),
    )
    for strategy, client in strategy_to_client:
        assert Client.list_by_strategy(strategy) == [client.value]


@patch("florist.api.models.mnist.load_mnist_data")
def test_local_data_model_get_data_loaders(mock_load_mnist_data: Mock):
    """LocalDataClient.get_data_loaders should pass through the patched MNIST loaders.

    ``load_mnist_data`` in ``florist.api.models.mnist`` is patched. The test checks
    that the train/val loaders it returns come back unchanged. It also checks that
    its most recent call received the data path, the configured batch size and
    ``None`` as the third positional argument.
    """
    data_path = "test-data-path"
    batch_size = 200
    expected_train_loader = "test-train-loader"
    expected_val_loader = "test-val-loader"

    client = LocalDataClient(data_path=data_path, metrics=[], device="cpu")
    client.set_model(MnistNet())
    # The patched loader returns (train_loader, val_loader, extra); only the first two are used here.
    mock_load_mnist_data.return_value = (expected_train_loader, expected_val_loader, {})

    train_loader, val_loader = client.get_data_loaders(config={"batch_size": batch_size})

    assert (train_loader, val_loader) == (expected_train_loader, expected_val_loader)
    mock_load_mnist_data.assert_called_with(data_path, batch_size, None)


@patch("florist.api.clients.optimizers.torch")
def test_local_data_model_get_optimizer_type(mock_torch: Mock):
    """With Optimizer.SGD selected, get_optimizer should build torch.optim.SGD.

    ``torch`` inside ``florist.api.clients.optimizers`` is patched, and the model's
    ``parameters`` method is stubbed. The test checks that SGD is built from the
    model parameters with lr=0.001 and momentum=0.9, and that its return value is
    handed back unchanged.
    """
    sentinel_optimizer = "test-optimizer"
    mock_torch.optim.SGD.return_value = sentinel_optimizer

    client = LocalDataClient(data_path="test-data-path", metrics=[], device="cpu")
    client.set_optimizer_type(Optimizer.SGD)

    sentinel_parameters = "test-parameters"
    model = MnistNet()
    # Stub out the real parameter iterator so the SGD call arguments can be compared.
    model.parameters = Mock(return_value=sentinel_parameters)
    client.set_model(model)

    assert client.get_optimizer(config={}) == sentinel_optimizer
    mock_torch.optim.SGD.assert_called_with(sentinel_parameters, lr=0.001, momentum=0.9)


@patch("florist.api.models.mnist.torch")
def test_local_data_model_get_criterion(mock_torch: Mock):
    """get_criterion should return what the mnist module's CrossEntropyLoss builds.

    ``torch`` in ``florist.api.models.mnist`` is patched, so the returned criterion
    can be compared against a sentinel value.
    """
    expected_criterion = "test-criterion"
    mock_torch.nn.CrossEntropyLoss.return_value = expected_criterion

    client = LocalDataClient(data_path="test-data-path", metrics=[], device="cpu")
    client.set_model(MnistNet())

    assert client.get_criterion(config={}) == expected_criterion


@patch("florist.api.models.mnist.load_mnist_data")
@patch("florist.api.clients.clients.DirichletLabelBasedSampler")
def test_fedprox_local_data_model_get_data_loaders(mock_sampler: Mock, mock_load_mnist_data: Mock):
    """FedProxLocalDataClient.get_data_loaders should build a Dirichlet sampler.

    Stacked ``patch`` decorators apply bottom-up, so the innermost patch (the
    sampler) is passed as the first mock argument. The test checks that:

    - the patched MNIST loaders come back unchanged;
    - the loader's last call received the data path and batch size (the third
      argument is deliberately not pinned);
    - the sampler was built from ``list(range(10))`` (presumably the ten MNIST
      class labels) with sample_percentage=0.75 and beta=1.
    """
    data_path = "test-data-path"
    batch_size = 200
    expected_loaders = ("test-train-loader", "test-val-loader")

    client = FedProxLocalDataClient(data_path=data_path, metrics=[], device="cpu")
    client.set_model(MnistNet())
    mock_load_mnist_data.return_value = (*expected_loaders, {})

    train_loader, val_loader = client.get_data_loaders(config={"batch_size": batch_size})

    assert (train_loader, val_loader) == expected_loaders
    mock_load_mnist_data.assert_called_with(data_path, batch_size, ANY)
    mock_sampler.assert_called_with(list(range(10)), sample_percentage=0.75, beta=1)