Source code for florist.tests.unit.api.clients.test_optimizers
from unittest.mock import Mock, patch
from florist.api.clients.optimizers import Optimizer
[docs]
def test_optimizer_list():
    """Check that ``Optimizer.list()`` returns the member values in the order SGD, ADAM_W."""
    expected_values = [member.value for member in (Optimizer.SGD, Optimizer.ADAM_W)]
    assert Optimizer.list() == expected_values
[docs]
@patch("florist.api.clients.optimizers.torch")
def test_optimizer_get_sgd(mock_torch: Mock) -> None:
    """Check that ``Optimizer.get`` with ``Optimizer.SGD`` builds exactly one ``torch.optim.SGD``.

    ``torch`` is patched inside ``florist.api.clients.optimizers``, so no real optimizer is
    created. The test verifies that the mocked constructor receives the model parameters with
    ``lr=0.001, momentum=0.9`` and that its return value is handed back unchanged.
    """
    test_model_parameters = Mock()
    test_optimizer = "test-optimizer"
    mock_torch.optim.SGD.return_value = test_optimizer

    optimizer = Optimizer.get(Optimizer.SGD, test_model_parameters)

    assert optimizer == test_optimizer
    # assert_called_once_with (unlike assert_called_with, which inspects only the most
    # recent call) also fails if the optimizer is constructed more than once.
    mock_torch.optim.SGD.assert_called_once_with(test_model_parameters, lr=0.001, momentum=0.9)
    # Dispatch on Optimizer.SGD must not touch the AdamW branch.
    mock_torch.optim.AdamW.assert_not_called()
[docs]
@patch("florist.api.clients.optimizers.torch")
def test_optimizer_get_adam_w(mock_torch: Mock) -> None:
    """Check that ``Optimizer.get`` with ``Optimizer.ADAM_W`` builds exactly one ``torch.optim.AdamW``.

    ``torch`` is patched inside ``florist.api.clients.optimizers``, so no real optimizer is
    created. The test verifies that the mocked constructor receives the model parameters with
    ``lr=0.01`` and that its return value is handed back unchanged.
    """
    test_model_parameters = Mock()
    test_optimizer = "test-optimizer"
    mock_torch.optim.AdamW.return_value = test_optimizer

    optimizer = Optimizer.get(Optimizer.ADAM_W, test_model_parameters)

    assert optimizer == test_optimizer
    # assert_called_once_with (unlike assert_called_with, which inspects only the most
    # recent call) also fails if the optimizer is constructed more than once.
    mock_torch.optim.AdamW.assert_called_once_with(test_model_parameters, lr=0.01)
    # Dispatch on Optimizer.ADAM_W must not touch the SGD branch.
    mock_torch.optim.SGD.assert_not_called()