# Source code for florist.api.clients.optimizers

"""Definitions for the optimizers that can be used."""

from enum import Enum
from typing import Any, Iterator

import torch
from typing_extensions import Self


[docs] class Optimizer(Enum): """Enumeration of pre-defined optimizers.""" SGD = "SGD" ADAM_W = "AdamW"
[docs] @classmethod def get(cls, optimizer: Self, model_parameters: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer: # type: ignore """ Return an instance for the given optimizer with model parameters. :param optimizer: (Optimizer) The optimizer type to get and instance of. :param model_parameters: (Iterator[torch.nn.parameter.Parameter]) The parameters of the model that will be set to the optimizer. :return: (torch.optim.Optimizer) An instance of the optimizer. """ if optimizer == Optimizer.SGD: return torch.optim.SGD(model_parameters, lr=0.001, momentum=0.9) # type: ignore if optimizer == Optimizer.ADAM_W: return torch.optim.AdamW(model_parameters, lr=0.01) # type: ignore raise ValueError(f"Optimizer {optimizer} not supported.")
[docs] @classmethod def list(cls) -> list[str]: """ List all the supported optimizers. :return: (list[str]) a list of supported optimizers. """ return [optimizer.value for optimizer in Optimizer]