Source code for florist.tests.unit.api.models.test_mnist
from unittest.mock import Mock, patch
import torch
from fl4health.utils.sampler import DirichletLabelBasedSampler
from florist.api.models.mnist import MnistNet
[docs]
@patch("florist.api.models.mnist.load_mnist_data")
def test_get_data_loaders(mock_load_mnist_data: Mock):
    test_train_loader = "test-train-loader"
    test_val_loader = "test-val-loader"
    mock_load_mnist_data.return_value = (test_train_loader, test_val_loader, None)
    test_data_path = "test-data-path"
    test_batch_size = 100
    test_sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1)
    test_model = MnistNet()
    train_loader, val_loader = test_model.get_data_loaders(test_data_path, test_batch_size, test_sampler)
    assert train_loader == test_train_loader
    assert val_loader == val_loader
    mock_load_mnist_data.assert_called_once_with(test_data_path, test_batch_size, test_sampler)
[docs]
def test_get_criterion():
    test_model = MnistNet()
    criterion = test_model.get_criterion()
    assert isinstance(criterion, torch.nn.CrossEntropyLoss)