"""This module contains utility functions."""
import importlib.util
import inspect
import re
from typing import Any
from tno.quantum.utils.validation import check_string
if importlib.util.find_spec("numpy") is not None:
import numpy as np
[docs]
def convert_to_snake_case(x: str, *, path: bool = False) -> str:
    """Convert string to snake case.

    The input is split on spaces and hyphens, a separator is inserted before every
    run of capital letters and before every capitalised word, and the resulting
    parts are joined by single underscores and lower-cased.

    Args:
        x: String to convert.
        path: If ``True``, treats the variable as a path variable with periods. Each
            substring separated by a period will be converted to a valid snake case
            convention. Defaults to ``False``.

    Raises:
        TypeError: If `x` is not an instance of :py:const:`str`.
        ValueError: If the input cannot be converted to snake case because it starts
            with an invalid character (anything other than a letter), or because it
            contains a character other than letters, digits, spaces, hyphens and
            underscores.

    Returns:
        Snake case variant of `x`.
    """
    x = check_string(x, "x")

    if path:
        # Convert every period-separated component on its own, then glue the
        # components back together with periods.
        substrings = [
            convert_to_snake_case(substring, path=False) for substring in x.split(".")
        ]
        return ".".join(substrings).lower()

    if not re.match(r"^[a-zA-Z]", x):
        error_msg = "Input cannot start with a number or any special symbol."
        raise ValueError(error_msg)
    if re.search(r"[^a-zA-Z0-9 \-_]", x):
        error_msg = "Input cannot contain special characters."
        raise ValueError(error_msg)

    # Convert x to snake_case convention
    words = x.replace("-", " ").split()
    # Put a space in front of every run of capitals ("fooBAR" -> "foo BAR") ...
    words = [re.sub(r"([A-Z]+)", r" \1", word) for word in words]
    # ... and in front of every capitalised word ("HTTPServer" -> "HTTP Server").
    words = [re.sub(r"([A-Z][a-z]+)", r" \1", word) for word in words]
    words = [word.strip() for word in words]
    joined = "_".join(words).replace(" ", "_")

    # The two substitutions above can stack several separators next to an
    # underscore that was already present (e.g. "hello_World" becomes
    # "hello___World"). A single ``replace("__", "_")`` pass would still leave a
    # double underscore behind, so collapse every run of underscores at once.
    return re.sub(r"_+", "_", joined).lower()
[docs]
def get_installed_subclasses(module_name: str, subclass: Any) -> dict[str, type[Any]]:
    """Collect the members of a module that derive from a given class.

    Every name listed by ``dir`` on the imported module is inspected. A member is
    kept when its class (or the member itself, if it is a class) has an entry in
    its method resolution order that is a subclass of `subclass`.

    Args:
        module_name: Name of the module to search.
        subclass: The class that members are matched against.

    Returns:
        Dictionary with the matching members, keyed by the snake-case version of
        their name in the module.
    """
    found = {}
    module = importlib.import_module(module_name)

    for attr_name in dir(module):
        member = getattr(module, attr_name)

        # Classes are checked directly; for any other object its class is used.
        klass = member if inspect.isclass(member) else member.__class__

        for ancestor in inspect.getmro(klass):
            if issubclass(ancestor, subclass):
                found[convert_to_snake_case(attr_name)] = member
                break

    return found
def get_init_arguments_info(cls: type[Any]) -> dict[str, Any]:
    """Collect the ``__init__`` argument names of a class together with their defaults.

    Args:
        cls: The class to inspect.

    Returns:
        A dictionary mapping every ``__init__`` parameter name, except the one called
        ``self``, to its default value. Parameters without a default map to
        ``Parameter.empty``.
    """
    parameters = inspect.signature(cls.__init__).parameters
    return {
        name: parameter.default
        for name, parameter in parameters.items()
        if name != "self"
    }
[docs]
def check_equal(first: Any, second: Any) -> bool: # noqa: PLR0911
"""Check if two objects are equal.
Equality check if applied recursively on lists, tuples, dictionaries and NumPy
arrays. That is, such objects are considered equal if all their elements are equal.
Args:
first: First object to compare.
second: Second object to compare.
Returns:
True if objects are equal, otherwise false.
"""
if type(first) is not type(second):
return False
if isinstance(first, dict):
if len(first) != len(second):
return False
for key in first:
if key not in second:
return False
if not check_equal(first[key], second[key]):
return False
return True
if isinstance(first, (list, tuple)):
if len(first) != len(second):
return False
return all(check_equal(x, y) for x, y in zip(first, second, strict=True))
if importlib.util.find_spec("numpy") is not None and isinstance(first, np.ndarray):
return np.array_equal(first, second)
return bool(first == second)