Source code for tno.quantum.utils._base_arguments

"""This module contains the ``BaseArguments`` class."""

from __future__ import annotations

import warnings
from collections.abc import Iterator, Mapping
from dataclasses import dataclass, fields
from inspect import Parameter
from typing import TYPE_CHECKING, Any

from tno.quantum.utils._utils import get_init_arguments_info
from tno.quantum.utils.serialization import Serializable

if TYPE_CHECKING:
    from typing import Self


@dataclass
class BaseArguments(Mapping[str, Any], Serializable):
    r'''Base class for argument classes.

    Subclasses are expected to be dataclasses. In addition to regular
    attribute access, instances expose the :class:`~collections.abc.Mapping`
    interface: iteration and ``len`` cover the dataclass fields, while item
    lookup resolves the key through ``getattr``.

    Example:
        >>> from dataclasses import dataclass
        >>> from tno.quantum.utils import BaseArguments
        >>>
        >>> @dataclass
        ... class ExampleArguments(BaseArguments):
        ...     """
        ...     Attributes:
        ...         name: attribute description name
        ...         size: attribute description size
        ...     """
        ...     name: str = "test-name"
        ...     size: int = 5
        >>>
        >>> args = ExampleArguments.from_mapping({ "name": "example", "size": 42 })
        >>> args.name
        'example'
    '''

    def __getitem__(self, key: str) -> Any:
        """Retrieve attribute item by key.

        The lookup is delegated to ``getattr``, so any attribute reachable on
        the instance (not only dataclass fields) can be retrieved this way.

        Args:
            key: Key to retrieve.

        Returns:
            The value associated with the key.

        Raises:
            KeyError: If the key is not a string or no attribute with that
                name exists.
        """
        # ``hasattr`` raises ``TypeError`` for non-string names; the Mapping
        # mixins (``in``, ``get``) rely on ``KeyError`` to signal absence, so
        # non-string keys are treated as simply missing.
        if isinstance(key, str) and hasattr(self, key):
            return getattr(self, key)
        error_msg = f"`{key}` not found in Arguments."
        raise KeyError(error_msg)

    def __iter__(self) -> Iterator[str]:
        """Iterate over the keys in the instance corresponding to known attributes."""
        return (field.name for field in fields(self))

    def __len__(self) -> int:
        """Return the number of known attribute keys."""
        return len(fields(self))

    @classmethod
    def from_mapping(cls, data: Mapping[str, Any]) -> Self:
        """Create an instance from a mapping.

        Args:
            data: Mapping containing key-value pairs to store in an arguments
                object. Keys that are not recognized as attributes will be
                ignored. If a known argument is missing but has a default
                value, the default value will be used. If ``data`` is already
                an instance of ``cls`` it is returned unchanged (no copy).

        Returns:
            Instance of arguments.

        Raises:
            KeyError: If the data does not contain all the attribute keys for
                which no default value is known.

        Warns:
            UserWarning: If data contains keys that are not recognized as
                attributes.
        """
        if isinstance(data, cls):
            return data

        init_args_info = get_init_arguments_info(cls)

        # Stringify and sort so the message is deterministic and building it
        # cannot fail on non-string keys.
        extra_args = sorted(str(key) for key in data if key not in init_args_info)
        if extra_args:
            warnings.warn(
                f"Ignoring unknown keys: {', '.join(extra_args)}",
                UserWarning,
                stacklevel=2,
            )

        args: dict[str, Any] = {
            arg_name: arg_value
            for arg_name, arg_value in data.items()
            if arg_name in init_args_info
        }

        # Check missing required arguments: ``Parameter.empty`` marks an
        # ``__init__`` argument without a default value.
        for arg_name, default_value in init_args_info.items():
            if arg_name not in data and default_value is Parameter.empty:
                error_msg = f"Missing required key: {arg_name}"
                raise KeyError(error_msg)

        return cls(**args)