# Source code for florist.tests.unit.api.models.test_mnist

from unittest.mock import Mock, patch

import torch
from fl4health.utils.sampler import DirichletLabelBasedSampler

from florist.api.models.mnist import MnistNet


@patch("florist.api.models.mnist.load_mnist_data")
def test_get_data_loaders(mock_load_mnist_data: Mock) -> None:
    """Check that MnistNet.get_data_loaders delegates to load_mnist_data.

    ``load_mnist_data`` is patched to return placeholder loaders (plus a
    third ``None`` element, which get_data_loaders is expected to drop).
    The test verifies that both returned loaders are the ones produced by
    the patched loader and that it was called once with the exact
    arguments passed to get_data_loaders.
    """
    test_train_loader = "test-train-loader"
    test_val_loader = "test-val-loader"
    mock_load_mnist_data.return_value = (test_train_loader, test_val_loader, None)

    test_data_path = "test-data-path"
    test_batch_size = 100
    test_sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1)
    test_model = MnistNet()

    train_loader, val_loader = test_model.get_data_loaders(test_data_path, test_batch_size, test_sampler)

    assert train_loader == test_train_loader
    # Compare against the expected placeholder; the previous
    # `val_loader == val_loader` was a tautology and checked nothing.
    assert val_loader == test_val_loader
    mock_load_mnist_data.assert_called_once_with(test_data_path, test_batch_size, test_sampler)
def test_get_criterion() -> None:
    """Check that MnistNet reports cross-entropy as its loss criterion."""
    loss_fn = MnistNet().get_criterion()
    assert isinstance(loss_fn, torch.nn.CrossEntropyLoss)