Source code for tno.quantum.utils._optimizer_config

"""This module contains the ``OptimizerConfig`` class."""

# ruff: noqa: PLC0415

from __future__ import annotations

import importlib
import importlib.util
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any

from tno.quantum.utils._base_config import BaseConfig

if importlib.util.find_spec("torch") is not None:
    from torch.optim.optimizer import Optimizer

# NOTE: ``Optimizer`` is written as a string forward reference on purpose. The
# name is only imported when torch is installed, and a base-class expression is
# evaluated eagerly (``from __future__ import annotations`` does not defer it),
# so the bare name would raise ``NameError`` at import time without torch.
@dataclass(init=False)
class OptimizerConfig(BaseConfig["Optimizer"]):
    """Configuration class for creating instances of a PyTorch optimizer.

    Currently only a selection of PyTorch optimizers are supported. See the
    documentation of :py:meth:`~OptimizerConfig.supported_items` for information
    on which optimizers are supported.

    Example:
        >>> import torch
        >>> from tno.quantum.utils import OptimizerConfig
        >>>
        >>> # List all supported optimizers
        >>> list(OptimizerConfig.supported_items())
        ['adagrad', 'adam', 'rprop', 'stochastic_gradient_descent']
        >>>
        >>> # Instantiate an optimizer
        >>> config = OptimizerConfig(name="adagrad", options={"lr": 0.5})
        >>> type(config.get_instance(params=[torch.rand(1)]))
        <class 'torch.optim.adagrad.Adagrad'>
    """

    def __init__(self, name: str, options: Mapping[str, Any] | None = None) -> None:
        """Init :py:class:`OptimizerConfig`.

        Args:
            name: Name of the optimizer, e.g. ``"adagrad"``. See
                :py:meth:`supported_items` for the supported names.
            options: Keyword arguments to be passed to the optimizer. Must be a
                mapping-like object whose keys are strings. Values can be
                anything, depending on the specific optimizer.

        Raises:
            TypeError: If `name` is not a string or `options` is not a mapping.
            KeyError: If `options` has a key that is not a string.
            KeyError: If `name` does not match any of the supported optimizers.
        """
        super().__init__(name=name, options=options)

    @staticmethod
    def supported_items() -> dict[str, type[Optimizer]]:
        """Obtain supported PyTorch optimizers.

        If PyTorch is installed then the following optimizers are supported:

        - Adagrad

          - name: ``"adagrad"``
          - options: see `Adagrad kwargs`__

        - Adam

          - name: ``"adam"``
          - options: see `Adam kwargs`__

        - Rprop

          - name: ``"rprop"``
          - options: see `Rprop kwargs`__

        - SGD

          - name: ``"stochastic_gradient_descent"``
          - options: see `SGD kwargs`__

        __ https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html
        __ https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
        __ https://pytorch.org/docs/stable/generated/torch.optim.Rprop.html
        __ https://pytorch.org/docs/stable/generated/torch.optim.SGD.html

        Raises:
            ModuleNotFoundError: If PyTorch can not be detected and no optimizers
                can be found.

        Returns:
            Dictionary with supported optimizers by their name.
        """
        # Import lazily so that a missing torch installation only surfaces when
        # optimizers are actually requested, with a clear error message.
        try:
            from torch.optim.adagrad import Adagrad
            from torch.optim.adam import Adam
            from torch.optim.rprop import Rprop
            from torch.optim.sgd import SGD
        except ModuleNotFoundError as exception:
            msg = "Torch can't be detected and hence no optimizers can be found."
            raise ModuleNotFoundError(msg) from exception
        else:
            return {
                "adagrad": Adagrad,
                "adam": Adam,
                "rprop": Rprop,
                "stochastic_gradient_descent": SGD,
            }