Source code for tno.quantum.utils.validation

"""This module contains generic validation methods."""

from __future__ import annotations

import importlib.util
import os
import re
import warnings
from collections.abc import Mapping
from copy import deepcopy
from datetime import timedelta
from numbers import Integral, Real
from pathlib import Path
from typing import TYPE_CHECKING, Any, SupportsFloat, TypeVar

import numpy as np
from numpy.random import RandomState
from numpy.typing import NDArray

TYPE_BOUNDS = TypeVar("TYPE_BOUNDS", float, int, timedelta)
TYPE_INSTANCE = TypeVar("TYPE_INSTANCE")

if TYPE_CHECKING:
    from numpy.typing import NDArray

# ruff: noqa:  PLR0913


[docs]def check_real( x: Any, name: str, *, l_bound: SupportsFloat | None = None, u_bound: SupportsFloat | None = None, l_inclusive: bool = True, u_inclusive: bool = True, ) -> float: """Check if the variable `x` with name `name` is a real number. Optionally, lower and upper bounds can also be checked. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. l_bound: Lower bound of `x`. u_bound: Upper bound of `x`. l_inclusive: If ``True`` the lower bound is inclusive, otherwise the lower bound is exclusive. u_inclusive: If ``True`` the upper bound is inclusive, otherwise the upper bound is exclusive. Raises: TypeError: If `x` is not a real number. ValueError: If `x` is outside the give bounds. Returns: Floating point representation of `x`. """ if not isinstance(x, Real): error_msg = f"'{name}' should be a real number, but was of type {type(x)}." raise TypeError(error_msg) x_float = float(x) if l_bound is not None: check_lower_bound(x_float, name, float(l_bound), inclusive=l_inclusive) if u_bound is not None: check_upper_bound(x_float, name, float(u_bound), inclusive=u_inclusive) return x_float
[docs]def check_int( x: Any, name: str, *, l_bound: SupportsFloat | None = None, u_bound: SupportsFloat | None = None, l_inclusive: bool = True, u_inclusive: bool = True, ) -> int: """Check if the variable `x` with name `name` is an integer. Optionally, lower and upper bounds can also be checked. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. l_bound: Lower bound of `x`. u_bound: Upper bound of `x`. l_inclusive: If ``True`` the lower bound is inclusive, otherwise the lower bound is exclusive. u_inclusive: If ``True`` the upper bound is inclusive, otherwise the upper bound is exclusive. Raises: TypeError: If `x` is not an integer. ValueError: If `x` is outside the give bounds. Returns: Integer representation of `x`. """ if not isinstance(x, Real): error_msg = f"'{name}' should be an integer, but was of type {type(x)}" raise TypeError(error_msg) int_x: int = int(x) # type: ignore[call-overload] if not isinstance(x, Integral) and x - int_x != 0: msg = f"'{name}' with value {x} could not be safely converted to an integer" raise ValueError(msg) if l_bound is not None: check_lower_bound(int_x, name, float(l_bound), inclusive=l_inclusive) if u_bound is not None: check_upper_bound(int_x, name, float(u_bound), inclusive=u_inclusive) return int_x
[docs]def check_lower_bound( x: TYPE_BOUNDS, name: str, l_bound: TYPE_BOUNDS, *, inclusive: bool ) -> None: """Check if the variable `x` with name `name` satisfies a lower bound. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. l_bound: Lower bound of `x`. inclusive: If ``True`` the lower bound is inclusive, otherwise the lower bound is exclusive. Raises: ValueError: If `x` is outside the give bounds. """ error_msg = f"'{name}" + "' has an {} lower bound of {}" + f", but was {x!s}." if inclusive and x < l_bound: raise ValueError(error_msg.format("inclusive", l_bound)) if not inclusive and x <= l_bound: raise ValueError(error_msg.format("exclusive", l_bound))
[docs]def check_upper_bound( x: TYPE_BOUNDS, name: str, u_bound: TYPE_BOUNDS, *, inclusive: bool ) -> None: """Check if the variable `x` with name `name` satisfies an upper bound. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. u_bound: Upper bound of `x`. inclusive: If ``True`` the lower bound is inclusive, otherwise the lower bound is exclusive. Raises: ValueError: If `x` is outside the give bounds. """ error_msg = f"'{name}" + "' has an {} upper bound of {}" + f", but was {x!s}." if inclusive and x > u_bound: raise ValueError(error_msg.format("inclusive", u_bound)) if not inclusive and x >= u_bound: raise ValueError(error_msg.format("exclusive", u_bound))
[docs]def check_string( x: Any, name: str, *, lower: bool = False, upper: bool = False, ) -> str: """Check if the variable `x` with name `name` is a string. Optionally, the string can be converted to all lowercase or all uppercase letters. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. lower: If ``True``, `x` will be returned with lowercase letters. Defaults to ``False``. upper: If ``True``, `x` will be returned with uppercase letters. Default to ``False``. Raises: TypeError: If `x` is not an instance of :py:const:`str`. Returns: Input string. Optionally, in lowercase or uppercase letters. """ if not isinstance(x, str): error_msg = f"'{name}' must be a string, but was of type {type(x)}" raise TypeError(error_msg) if lower: x = x.lower() if upper: x = x.upper() return x
[docs]def check_snake_case( x: Any, name: str, *, path: bool = False, warn: bool = False ) -> str: """Check if the variable `x` with name `name` is a string in snake case convention. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. path: If ``True``, treats the name as a path variable with periods. Each substring separated by a period must be in valid snake case convention. Defaults to ``False``. warn: If ``True``, issue a warning instead of raising an ValueError. Defaults to ``False`` Raises: TypeError: If `x` is not an instance of :py:const:`str`. ValueError: If `x` is not in snake case and `warn` is ``False``. Returns: Input string. """ y = check_string(x, name) if path: # allowing periods snake_case_pattern = re.compile( r"^[a-z][a-z0-9]*(_[a-z0-9]+|\.[a-z][a-z0-9]*)*$" ) else: snake_case_pattern = re.compile(r"^[a-z][a-z0-9]*(_[a-z0-9]+)*$") if not snake_case_pattern.match(y): if warn: warn_msg = f"'{name}' is not in snake case convention, but was {y}." warnings.warn(warn_msg, stacklevel=2) else: error_msg = f"'{name}' must be in snake case convention, but was '{y}'" raise ValueError(error_msg) return y
[docs]def check_python_variable(x: Any, name: str, *, warn: bool = False) -> str: """Check if variable `x` with name `name` is a string in Python variable convention. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. warn: If ``True``, issue a warning instead of raising an ValueError. Defaults to ``False`` Raises: TypeError: If `x` is not an instance of :py:const:`str`. ValueError: If `x` is not in Python variable convention and `warn` is ``False``. Returns: Input string. """ y = check_string(x, name) python_variable_pattern = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") if not python_variable_pattern.match(y): if warn: warn_msg = f"'{name}' is not in Python variable convention, but was {y}." warnings.warn(warn_msg, stacklevel=2) else: error_msg = f"'{name}' must be in Python variable convention, but was '{y}'" raise ValueError(error_msg) return y
[docs]def check_bool(x: Any, name: str, *, safe: bool = False) -> bool: """Check if the variable `x` with name `name` is a boolean value. Optionally, cast to boolean value if it is not. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. safe: If ``True`` raise a `TypeError` when `x` is not a bool. If ``False`` cast to bool. Raises: TypeError: If `safe` is ``True`` and `x` is not a boolean value. Returns: Boolean representation of the input. """ if not isinstance(x, bool) and safe: error_msg = f"'{name}' must be a boolean value, but was of type {type(x)}" raise TypeError(error_msg) return bool(x)
[docs]def check_binary(x: Any, name: str) -> int: """Check if the variable `x` with name `name` is a binary variable. Will casts the variable to int representation. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. Raises: TypeError: If `x` is a string but does not have a binary value. ValueError: If `x` could not safely be converted to the integer 0 or 1. Returns: Binary int representation of the input. """ if isinstance(x, str): if x in ("0", "1"): return int(x) msg = f"'{name}' must be a Binary variable, but was of type {type(x)}" raise TypeError(msg) return check_int(x, name, l_bound=0, u_bound=1)
[docs]def check_kwarglike(x: Any, name: str, *, safe: bool = False) -> dict[str, Any]: """Check if the variable `x` with name `name` is a kwarglike object. A object is kwarglike if it is a mapping where all keys are strings. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. safe: If ``True``, makes a deep copy of all values of `x`. Raises: TypeError: If `x` is not a Mapping. KeyError: If `x` has at least one key that is not a string. Returns: Dictionary with the key-value pairs from `x`. """ if not isinstance(x, Mapping): error_msg = ( f"'{name}' must be an instance of <class 'Mapping'>," f" but was of type {type(x)}" ) raise TypeError(error_msg) x_dict = {} for key, value in x.items(): if not isinstance(key, str): error_msg = f"At least one key in '{name}' is not a string" raise KeyError(error_msg) x_dict[key] = deepcopy(value) if safe else value return x_dict
[docs]def check_arraylike( x: Any, name: str, *, ndim: int | None = None, shape: tuple[int, ...] | None = None ) -> NDArray[Any]: """Check if the variable `x` with name `name` is an ArrayLike. Optionally, check if the result has the specified number of dimensions. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. ndim: Number of dimensions `x` should have. When `ndim = 1`, arrays with more than one dimension will be squeezed to remove dimensions of size one. shape: Shape `x` should have. Raises: ValueError: When the provided input is not ArrayLike or does not have the correct shape or number of dimensions. Returns: NDArray representation of `x`. """ array: NDArray[Any] array = x.toarray() if hasattr(x, "toarray") else np.asarray(x) if ndim is not None: if ndim == 1 and array.ndim != 1: array = array.squeeze() if array.ndim != ndim: msg = f"'{name}' must be an ArrayLike with {ndim} dimension(s), but had " msg += f"{array.ndim} dimension(s)." raise ValueError(msg) if shape is not None and array.shape != shape: msg = f"'{name}' must be an ArrayLike of shape {shape}, but had " msg += f"shape {array.shape}." raise ValueError(msg) return array
[docs]def check_instance(x: TYPE_INSTANCE, name: str, dtype: type) -> TYPE_INSTANCE: """Check if `x` with name `name` is an instance of dtype. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. dtype: the type of variable to validate against. Raises: TypeError: If `x` is not an instance of `dtype`. Returns: The input `x` if it is an instance of `dtype`. """ if not isinstance(_ := x, dtype): # `_ := x` instead of `x` to fix mypy msg = f"'{name}' must be an instance of {dtype}, but was of type {type(x)}" raise TypeError(msg) return x
[docs]def check_path( x: Any, name: str, *, must_exist: bool = False, must_be_file: bool = False, must_be_dir: bool = False, required_suffix: str | None = None, safe: bool = False, ) -> Path: """Check if the variable `path` with name `name` is a valid path. Optionally, existence, file, and directory checks can also be performed. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. required_suffix: If specified, the path must have this suffix. must_exist: If ``True``, the path must exist. must_be_file: If ``True``, the path must be a file. must_be_dir: If ``True``, the path must be a directory. safe: If ``True`` and the path does not have the required suffix a ``ValueError`` is raised. Otherwise, if ``False``, the suffix will be replaced to match the required_suffix. Raises: TypeError: If `path` is not a string or Path object. ValueError: If `path` does not have the correct required_suffix. OSError: If `path` does not exist while must_exist is ``True`` FileNotFoundError: If `path` is not a file while must_be_file is ``True`` NotADirectoryError: If `path` is not a directory while must_be_dir is ``True`` Returns: Path object representing `path`. """ if not isinstance(x, (str, os.PathLike)): error_msg = ( f"'{name}' should be a string or os.PathLike object, " f"but was of type {type(x)}." ) raise TypeError(error_msg) path_obj = Path(x) if required_suffix: if path_obj.suffix != required_suffix and safe: error_msg = ( f"The path `{path_obj}` does not have the required suffix " f"`{required_suffix}`." ) raise ValueError(error_msg) path_obj = path_obj.with_suffix(required_suffix) if must_exist and not path_obj.exists(): error_msg = f"The path `{path_obj}` does not exist." raise OSError(error_msg) if must_be_file and not path_obj.is_file(): error_msg = f"The path `{path_obj}` is not a file." raise FileNotFoundError(error_msg) if must_be_dir and not path_obj.is_dir(): error_msg = f"The path `{path_obj}` is not a directory." raise NotADirectoryError(error_msg) return path_obj
[docs]def check_timedelta( x: Any, name: str, *, l_bound: SupportsFloat | timedelta | None = None, u_bound: SupportsFloat | timedelta | None = None, l_inclusive: bool = True, u_inclusive: bool = True, ) -> timedelta: """Check if the variable `x` with name `name` can be converted to a timedelta. Optionally, lower and upper bounds can also be checked. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. l_bound: Lower bound of `x`. u_bound: Upper bound of `x`. l_inclusive: If ``True`` the lower bound is inclusive, otherwise the lower bound is exclusive. u_inclusive: If ``True`` the upper bound is inclusive, otherwise the upper bound is exclusive. Raises: TypeError: If `x` cannot be converted to a timedelta. ValueError: If `x` is outside the given bounds. Returns: Timedelta representation of `x`. """ if not isinstance(x, (Real, timedelta)): error_msg = ( f"'{name}' should be a real or a timedelta, but was of type {type(x)}." ) raise TypeError(error_msg) if isinstance(x, timedelta): td = x elif isinstance(x, Real): td = timedelta(seconds=float(x)) if l_bound is not None: l_bound_ = check_timedelta(l_bound, name="l_bound") check_lower_bound(td, name, l_bound_, inclusive=l_inclusive) if u_bound is not None: u_bound_ = check_timedelta(u_bound, name="u_bound") check_upper_bound(td, name, u_bound_, inclusive=u_inclusive) return td
[docs]def check_random_state( x: Any, name: str, ) -> RandomState: """Check if the variable `x` with name `name` can be converted to a :py:class:`~numpy.random.RandomState`. If `x` is already a :py:class:`~numpy.random.RandomState` instance, return it. If `x` is an integer, return a new :py:class:`~numpy.random.RandomState` instance seeded with `x`. If `x` is ``None``, return a new unseeded :py:class:`~numpy.random.RandomState` instance. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. Raises: TypeError: If `x` is not an instance of ``None``, :py:class:`~numbers.Integral` or :py:class:`~numpy.random.RandomState`. """ # noqa: E501 if x is None: return RandomState() if isinstance(x, Integral): return RandomState(int(x)) if isinstance(x, RandomState): return x error_msg = ( f"'{name}' should be a ``RandomState``, integer or ``None``, " f"but was of type {type(x)}." ) raise TypeError(error_msg)
if importlib.util.find_spec("matplotlib") is not None: import matplotlib.pyplot as plt from matplotlib.axes import Axes
[docs] def check_ax( x: Any, name: str, ) -> Axes: """Check if the variable `x` with name `name` can be converted to an :py:class:`~matplotlib.axes.Axes` object. If `x` is already a :py:class:`~matplotlib.axes.Axes` instance, return it. If `x` is ``None``, return a new :py:class:`~matplotlib.axes.Axes` instance. Args: x: Variable to check. name: Name of the variable. This name will be displayed in possible error messages. Returns: Parsed ax. Raises: TypeError: If `x` is not ``None`` or an instance of :py:class:`~matplotlib.axes.Axes`. """ # noqa: E501 if isinstance(x, Axes): return x if x is None: _, ax = plt.subplots() return ax error_msg = f"'{name}' should be a `Axes` or `None`, but was of type {type(x)}." raise TypeError(error_msg)
[docs]def warn_if_positive(x: SupportsFloat, name: str) -> None: """Give a warning when `x` is positive. Args: x: Variable to check. name: Name of the variable. This name will be displayed in the warning. """ if float(x) > 0: warn_msg = f"'{name}' was positive" warnings.warn(warn_msg, stacklevel=2)
[docs]def warn_if_negative(x: SupportsFloat, name: str) -> None: """Give a warning when `x` is negative. Args: x: Variable to check. name: Name of the variable. This name will be displayed in the warning. """ if float(x) < 0: warn_message = f"'{name}' was negative" warnings.warn(warn_message, stacklevel=2)