# Source code for florist.api.clients.optimizers
"""Definitions for the optimizers that can be used."""
from enum import Enum
from typing import Iterator
import torch
from typing_extensions import Self
# [docs]  (documentation-viewer link marker; not part of the module code)
class Optimizer(Enum):
    """
    Enumeration of pre-defined optimizers.

    Each member's value is the optimizer's display name (as returned by ``Optimizer.list()``).
    """

    SGD = "SGD"
    ADAM_W = "AdamW"

    @classmethod
    def get(cls, optimizer: Self, model_parameters: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:  # type: ignore
        """
        Return an instance for the given optimizer with model parameters.

        The hyper-parameters are fixed per optimizer type: SGD is created with
        ``lr=0.001`` and ``momentum=0.9``; AdamW is created with ``lr=0.01``.

        :param optimizer: (Optimizer) The optimizer type to get an instance of.
        :param model_parameters: (Iterator[torch.nn.parameter.Parameter]) The parameters of the model
            that will be set to the optimizer.
        :return: (torch.optim.Optimizer) An instance of the optimizer.
        :raises ValueError: If ``optimizer`` is not one of the members handled here.
        """
        if optimizer == cls.SGD:
            return torch.optim.SGD(model_parameters, lr=0.001, momentum=0.9)  # type: ignore
        if optimizer == cls.ADAM_W:
            return torch.optim.AdamW(model_parameters, lr=0.01)  # type: ignore
        # Anything else (including a member added later without a branch above) is rejected explicitly.
        raise ValueError(f"Optimizer {optimizer} not supported.")

    # NOTE: this method shadows the builtin ``list`` inside the class namespace once it is defined.
    # Its own ``list[str]`` annotation still resolves to the builtin because the annotation is
    # evaluated before the name is bound; keep this method last so no later annotation is affected.
    @classmethod
    def list(cls) -> list[str]:
        """
        List all the supported optimizers.

        :return: (list[str]) a list of supported optimizers, i.e. the value of every member
            in definition order.
        """
        return [member.value for member in cls]