Source code for florist.tests.unit.api.clients.test_optimizers

from unittest.mock import Mock, patch

from florist.api.clients.optimizers import Optimizer


def test_optimizer_list():
    """Check that ``Optimizer.list()`` yields the SGD and AdamW values, in that order."""
    listed_values = Optimizer.list()
    expected_values = [Optimizer.SGD.value, Optimizer.ADAM_W.value]

    assert listed_values == expected_values
@patch("florist.api.clients.optimizers.torch")
def test_optimizer_get_sgd(mock_torch: Mock):
    """Check that ``Optimizer.get`` with ``Optimizer.SGD`` goes through ``torch.optim.SGD``.

    The ``torch`` name inside ``florist.api.clients.optimizers`` is patched with a
    mock for the duration of the test.

    Args:
        mock_torch: Mock injected by ``patch`` in place of the ``torch`` module.
    """
    # Arrange: the mocked SGD constructor hands back a sentinel value.
    model_params = Mock()
    expected_optimizer = "test-optimizer"
    sgd_constructor = mock_torch.optim.SGD
    sgd_constructor.return_value = expected_optimizer

    # Act
    result = Optimizer.get(Optimizer.SGD, model_params)

    # Assert: the sentinel is returned and the constructor received the
    # model parameters together with the expected hyperparameters.
    assert result == expected_optimizer
    sgd_constructor.assert_called_with(model_params, lr=0.001, momentum=0.9)
@patch("florist.api.clients.optimizers.torch")
def test_optimizer_get_adam_w(mock_torch: Mock):
    """Check that ``Optimizer.get`` with ``Optimizer.ADAM_W`` goes through ``torch.optim.AdamW``.

    The ``torch`` name inside ``florist.api.clients.optimizers`` is patched with a
    mock for the duration of the test.

    Args:
        mock_torch: Mock injected by ``patch`` in place of the ``torch`` module.
    """
    # Arrange: the mocked AdamW constructor hands back a sentinel value.
    model_params = Mock()
    expected_optimizer = "test-optimizer"
    adam_w_constructor = mock_torch.optim.AdamW
    adam_w_constructor.return_value = expected_optimizer

    # Act
    result = Optimizer.get(Optimizer.ADAM_W, model_params)

    # Assert: the sentinel is returned and the constructor received the
    # model parameters together with the expected learning rate.
    assert result == expected_optimizer
    adam_w_constructor.assert_called_with(model_params, lr=0.01)