"""This module contains the ``BaseArguments`` class."""
from __future__ import annotations
import warnings
from collections.abc import Iterator, Mapping
from dataclasses import dataclass, fields
from inspect import Parameter
from typing import TYPE_CHECKING, Any
from tno.quantum.utils._utils import get_init_arguments_info
from tno.quantum.utils.serialization import Serializable
if TYPE_CHECKING:
    from typing import Self
@dataclass
class BaseArguments(Mapping[str, Any], Serializable):
    r'''Base class for argument classes.

    Subclasses are dataclasses whose fields are the known arguments. Instances
    behave as a read-only ``Mapping`` from field name to field value, so they can
    be unpacked with ``**`` or converted with ``dict``.

    Example:
        >>> from dataclasses import dataclass
        >>> from tno.quantum.utils import BaseArguments
        >>>
        >>> @dataclass
        ... class ExampleArguments(BaseArguments):
        ...     """
        ...     Attributes:
        ...         name: attribute description name
        ...         size: attribute description size
        ...     """
        ...     name: str = "test-name"
        ...     size: int = 5
        >>>
        >>> args = ExampleArguments.from_mapping({ "name": "example", "size": 42 })
        >>> args.name
        'example'
        >>> args["size"]
        42
        >>> dict(args)
        {'name': 'example', 'size': 42}
    '''

    def __getitem__(self, key: str) -> Any:
        """Retrieve the value of a dataclass field by name.

        Only dataclass field names are valid keys. Lookups therefore agree with
        ``__iter__``/``__len__`` and with the ``Mapping`` mixin methods
        (``in``, ``get``, ``keys``, ...). Methods and other attributes are not
        exposed as keys.

        Args:
            key: Name of the field to retrieve.

        Returns:
            The value associated with the key.

        Raises:
            KeyError: If ``key`` is not the name of a (set) dataclass field.
        """
        # A list is used instead of a set so that unhashable keys compare with
        # ``==`` and simply miss, instead of raising ``TypeError``. A bare
        # ``hasattr`` check would also accept method names like ``keys`` and
        # raise ``TypeError`` for non-string keys, breaking ``Mapping.get``/``in``.
        field_names = [field.name for field in fields(self)]
        if key in field_names and hasattr(self, key):
            return getattr(self, key)
        error_msg = f"`{key}` not found in Arguments."
        raise KeyError(error_msg)

    def __iter__(self) -> Iterator[str]:
        """Iterate over the dataclass field names, in declaration order."""
        return iter(field.name for field in fields(self))

    def __len__(self) -> int:
        """Return the number of dataclass fields."""
        return len(fields(self))

    @classmethod
    def from_mapping(cls, data: Mapping[str, Any]) -> Self:
        """Create an instance from a mapping.

        If ``data`` already is an instance of ``cls``, it is returned as-is
        (not copied).

        Args:
            data: Mapping containing key-value pairs to store in an arguments object.
                Keys that are not recognized as init arguments of ``cls`` are
                ignored. If a known argument is missing but has a default value, the
                default value is used.

        Returns:
            Instance of arguments.

        Raises:
            KeyError: If the data does not contain all the attribute keys for which no
                default value is known.

        Warns:
            UserWarning: If data contains keys that are not recognized as attributes.
        """
        if isinstance(data, cls):
            return data

        # NOTE(review): presumably maps each ``__init__`` argument name of ``cls``
        # to its default value, with ``Parameter.empty`` marking required
        # arguments (as the check below assumes) -- confirm in ``_utils``.
        init_args_info = get_init_arguments_info(cls)

        extra_args = set(data.keys()) - set(init_args_info.keys())
        if extra_args:
            # Sort so the message does not depend on set iteration order, and
            # stringify so unexpected non-str keys cannot break ``join``.
            unknown_keys = ", ".join(sorted(str(arg) for arg in extra_args))
            warnings.warn(
                f"Ignoring unknown keys: {unknown_keys}",
                UserWarning,
                stacklevel=2,
            )

        # Reject the first required argument (no default) absent from ``data``.
        for arg_name, default_value in init_args_info.items():
            if arg_name not in data and default_value is Parameter.empty:
                error_msg = f"Missing required key: {arg_name}"
                raise KeyError(error_msg)

        args: dict[str, Any] = {
            arg_name: arg_value
            for arg_name, arg_value in data.items()
            if arg_name in init_args_info
        }
        return cls(**args)