"""This module contains the :py:class:`Serializable` class.
The module provides the tools to give any class support for default serialization and
deserialization.
For example, when defining a class that inherits from :py:class:`Serializable`, methods
such as :py:meth:`~Serializable.to_json` and :py:meth:`~Serializable.from_json` are
automatically added to the class.
>>> from tno.quantum.utils.serialization import Serializable
>>>
>>> # Define a class that inherits from Serializable
>>> class Point(Serializable):
... def __init__(self, x, y):
... self.x = x
... self.y = y
...
... def __repr__(self):
... return f"Point(x={self.x}, y={self.y})"
>>>
>>> # Create an instance and serialize to JSON
>>> point = Point(1, 2)
>>> point_json = point.to_json()
>>> print(point_json) # doctest: +ELLIPSIS
{"x": 1, "y": 2, "__class__": "..."}
>>>
>>> # Deserialize JSON to instance of Point
>>> Point.from_json(point_json) # doctest: +SKIP
Point(x=1, y=2)
.. note::
By default, an instance is serialized by storing, for every parameter of the
:py:const:`__init__` of its class, the attribute with the same name. Therefore, make
sure your class has an attribute for each parameter of its constructor.
Information about the class is stored in the serialized data, so that it is possible to
deserialize without knowing the target class. For example:
>>> Serializable.from_json(point_json) # doctest: +SKIP
Point(x=1, y=2)
If your class contains attributes which are instances of third-party classes, it is
possible to register serialization and deserialization methods for these third-party
classes. For example, (de)serialization methods for NumPy arrays are registered (by
default) as follows:
>>> import numpy as np
>>> from numpy.typing import NDArray
>>>
>>> def serialize_ndarray(value: NDArray[Any]) -> dict[str, Any]:
... return {"dtype": str(value.dtype), "array": value.tolist()}
>>>
>>> def deserialize_ndarray(data: dict[str, Any]) -> NDArray[Any]:
... dtype = np.dtype(data["dtype"])
... array = data["array"]
... return np.array(array, dtype=dtype)
>>>
>>> Serializable.register(np.ndarray, serialize_ndarray, deserialize_ndarray)
NumPy arrays will now automatically be (de)serialized, as in the following example:
>>> class LinearSystem(Serializable):
... matrix: NDArray[np.float64]
... vector: NDArray[np.float64]
...
... def __init__(self, matrix, vector):
... self.matrix = matrix
... self.vector = vector
...
... def __repr__(self):
... return f"LinearSystem(matrix={self.matrix}, vector={self.vector})"
>>>
>>> system = LinearSystem(np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([1.0, 1.0]))
>>> system_json = system.to_json()
>>> print(system_json) # doctest: +ELLIPSIS
{"matrix": {"dtype": "float64", "array": [[1.0, 2.0], [3.0, 4.0]], "__class__": "numpy.ndarray"}, "vector": {"dtype": "float64", "array": [1.0, 1.0], "__class__": "numpy.ndarray"}, "__class__": "..."}
>>>
>>> Serializable.from_json(system_json) # doctest: +SKIP
LinearSystem(matrix=[[1. 2.]
[3. 4.]], vector=[1. 1.])
.. note::
Instances of classes are (de)serialized to and from :py:const:`dict` objects via the
private methods :py:const:`Serializable._serialize(self) -> dict[str, Any]` and
:py:const:`Serializable._deserialize(cls, data: dict[str, Any]) -> Any`. If your
class requires special (de)serialization, it is possible to override these
methods in your class. These must be implemented in such a way that they are
inverse to each other.
""" # noqa: E501
from __future__ import annotations
import importlib
import importlib.util
import inspect
import json
import tempfile
import warnings
from collections.abc import Callable
from datetime import timedelta
from pathlib import Path
from typing import Any
import yaml
from tno.quantum.utils.validation import check_path
# Registry of (serialize, deserialize) function pairs for classes that do not
# derive from ``Serializable``. Keys are dotted class paths as produced by
# ``Serializable._class_to_path``, i.e. "<module>.<class name>".
_externals: dict[
    str, tuple[Callable[[Any], dict[str, Any]], Callable[[dict[str, Any]], Any]]
] = {}


class Serializable:
    """Framework for serializable objects.

    By default an instance is serialized by reading, for every parameter of its
    ``__init__``, the attribute with the same name, and deserialized by passing the
    stored values back to the constructor as keyword arguments. Subclasses may
    override ``_serialize`` and ``_deserialize`` (together) to customise this.
    """

    def to_json(self, *, indent: int | None = None) -> str:
        """Serialize to JSON.

        Args:
            indent: If provided, JSON will be pretty-printed with given indent level.

        Returns:
            JSON string.
        """
        return json.dumps(self.serialize(), indent=indent)

    def to_yaml(self) -> str:
        """Serialize to YAML.

        Returns:
            YAML string.
        """
        return yaml.dump(self.serialize(), Dumper=yaml.Dumper)

    def to_json_file(self, path: str | Path, *, indent: int | None = None) -> None:
        """Serialize and write to JSON file.

        Args:
            path: Path of the file to write to.
            indent: If provided, JSON will be pretty-printed with given indent level.
        """
        path = check_path(
            path,
            "path",
            must_exist=False,
            required_suffix=".json",
            safe=True,
        )
        with path.open("w", encoding="utf-8") as file:
            json.dump(self.serialize(), file, indent=indent)

    def to_yaml_file(self, path: str | Path) -> None:
        """Serialize and write to YAML file.

        Args:
            path: Path of the file to write to.
        """
        path = check_path(
            path,
            "path",
            must_exist=False,
            required_suffix=".yaml",
            safe=True,
        )
        with path.open("w", encoding="utf-8") as file:
            yaml.dump(self.serialize(), file)

    def serialize(self: Any) -> Any:
        """Serialize self to dict, list or primitive.

        May also be called as ``Serializable.serialize(value)`` with an arbitrary
        value. Tuples are stored as ``{"__tuple__": [...]}`` and class instances get
        an extra ``"__class__"`` entry holding their dotted class path; these two
        keys are interpreted specially by :py:meth:`deserialize`.

        Returns:
            Representation of self as dict, list, string, number, boolean or
            ``None``.

        Raises:
            NotImplementedError: If a dict has a non-string key, or if the value is
                of a type for which no serialization is known.
        """
        value = self

        # Exact type comparisons throughout: instances of subclasses of the builtin
        # types are not matched by these branches.
        if type(value) in (bool, str, type(None), int, float):
            return value

        if type(value) is list:
            return [Serializable.serialize(item) for item in value]

        if type(value) is tuple:
            return {"__tuple__": [Serializable.serialize(item) for item in value]}

        if type(value) is dict:
            serialized: dict[str, Any] = {}
            for key, val in value.items():
                if not isinstance(key, str):
                    msg = f"Could not serialize dict with key of type {type(key)}"
                    raise NotImplementedError(msg)
                serialized[key] = Serializable.serialize(val)
            return serialized

        # Serializable: build a new dict instead of writing "__class__" into the
        # dict returned by `_serialize`, which this method does not own.
        if isinstance(value, Serializable):
            payload = value._serialize()  # noqa: SLF001
            return {
                **payload,
                "__class__": Serializable._class_to_path(value.__class__),
            }

        # External: same reasoning, the registered function's result is left intact.
        if external := Serializable._get_external(value.__class__):
            payload = external[0](value)
            return {
                **payload,
                "__class__": Serializable._class_to_path(value.__class__),
            }

        msg = f"Could not serialize value of type {type(value)}"
        raise NotImplementedError(msg)

    def _serialize(self) -> dict[str, Any]:
        """Serialize to dict.

        Classes derived from ``Serializable`` may override this method for custom
        serialization. In this case, override ``_deserialize`` accordingly.

        Returns:
            Mapping from each ``__init__`` parameter name to the serialized value of
            the attribute with that name.

        Raises:
            ValueError: If an attribute named after an ``__init__`` parameter is
                missing on the instance.
        """
        init_signature = inspect.signature(self.__class__.__init__)
        # The first parameter is `self`; every remaining name must be an attribute.
        init_args = list(init_signature.parameters)[1:]

        serialized: dict[str, Any] = {}
        for key in init_args:
            if not hasattr(self, key):
                msg = (
                    f"Failed to serialize value of type {type(self)}: missing "
                    f"attribute '{key}' which is expected by __init__ of {type(self)}"
                )
                raise ValueError(msg)
            serialized[key] = Serializable.serialize(getattr(self, key))
        return serialized

    @classmethod
    def from_json(cls, data: str) -> Any:
        """Deserialize from JSON.

        Args:
            data: JSON string to deserialize.

        Returns:
            Deserialized instance of `cls`.

        Raises:
            ValueError: If `data` is ill-formed.
            NotImplementedError: If no deserialization method exists to deserialize.
        """
        return Serializable._deserialize_class(cls, json.loads(data))

    @classmethod
    def from_yaml(cls, data: str) -> Any:
        """Deserialize from YAML.

        Args:
            data: YAML string to deserialize.

        Returns:
            Deserialized instance of `cls`.

        Raises:
            ValueError: If `data` is ill-formed.
            NotImplementedError: If no deserialization method exists to deserialize.
        """
        return Serializable._deserialize_class(cls, yaml.safe_load(data))

    @classmethod
    def from_json_file(cls, path: str | Path) -> Any:
        """Read and deserialize from JSON file.

        Args:
            path: Path to JSON file to deserialize.

        Returns:
            Deserialized instance of `cls`.

        Raises:
            ValueError: If `data` is ill-formed.
            NotImplementedError: If no deserialization method exists to deserialize.
            FileNotFoundError: If file at `path` not found.
        """
        path = check_path(
            path,
            "path",
            must_exist=True,
            must_be_file=True,
            required_suffix=".json",
            safe=True,
        )
        with path.open("r", encoding="utf-8") as file:
            return Serializable._deserialize_class(cls, json.load(file))

    @classmethod
    def from_yaml_file(cls, path: str | Path) -> Any:
        """Read and deserialize from YAML file.

        Args:
            path: Path to YAML file to deserialize.

        Returns:
            Deserialized instance of `cls`.

        Raises:
            ValueError: If `data` is ill-formed.
            NotImplementedError: If no deserialization method exists to deserialize.
            FileNotFoundError: If file at `path` not found.
        """
        path = check_path(
            path,
            "path",
            must_exist=True,
            must_be_file=True,
            required_suffix=".yaml",
            safe=True,
        )
        with path.open("r", encoding="utf-8") as file:
            return Serializable._deserialize_class(cls, yaml.safe_load(file))

    @staticmethod
    def _deserialize_class(class_obj: type[Serializable], data: Any) -> Any:
        """Deserialize data into an instance of `class_obj`.

        When `class_obj` is ``Serializable`` itself, any deserialized object is
        accepted.

        Returns:
            Deserialized instance of type `class_obj`.

        Raises:
            ValueError: If deserialized instance is not of type `class_obj`.
        """
        deserialized_obj = Serializable.deserialize(data)
        if class_obj is not Serializable and not isinstance(
            deserialized_obj, class_obj
        ):
            msg = (
                f"Deserialized object of type {type(deserialized_obj)},"
                f" but expected {class_obj}"
            )
            raise ValueError(msg)
        return deserialized_obj

    @staticmethod
    def deserialize(data: Any) -> Any:
        """Deserialize data.

        The input is not modified.

        Args:
            data: Primitive, list or dict as produced by :py:meth:`serialize`.

        Returns:
            Deserialized object.

        Raises:
            ValueError: If a ``"__tuple__"`` entry does not hold a list, or if a
                ``"__class__"`` entry cannot be resolved to a class.
            NotImplementedError: If `data` has an unsupported type, or names a class
                that is neither ``Serializable`` nor registered.
        """
        if type(data) in (bool, str, type(None), int, float):
            return data

        if type(data) is list:
            return [Serializable.deserialize(item) for item in data]

        if type(data) is dict:
            # The marker keys are popped below. Work on a shallow copy so that the
            # caller's dict is left intact; nested containers are copied in turn
            # when they are deserialized recursively.
            data = dict(data)

            # tuple
            tuple_data = data.pop("__tuple__", None)
            if tuple_data is not None:
                if type(tuple_data) is not list:
                    msg = f"Failed to deserialize tuple, got {type(tuple_data)}"
                    raise ValueError(msg)
                return tuple(Serializable.deserialize(value) for value in tuple_data)

            # plain dict
            cls_path = data.pop("__class__", None)
            if cls_path is None:
                return {
                    key: Serializable.deserialize(value) for key, value in data.items()
                }

            cls = Serializable._class_from_path(cls_path)

            # Serializable
            if issubclass(cls, Serializable):
                return cls._deserialize(data)

            # External classes
            if external := Serializable._get_external(cls):
                return external[1](data)

            msg = f"Could not deserialize class {cls_path}"
            raise NotImplementedError(msg)

        msg = f"Failed to deserialize type {type(data)}"
        raise NotImplementedError(msg)

    @classmethod
    def _deserialize(cls, data: dict[str, Any]) -> Any:
        """Deserialize data into an instance of `cls`.

        Classes derived from ``Serializable`` may override this method for custom
        deserialization. In this case, override ``_serialize`` accordingly.

        Args:
            data: Serialized attributes, without the ``"__class__"`` entry.

        Returns:
            Instance created by calling `cls` with the deserialized values as
            keyword arguments.
        """
        kwargs = {key: Serializable.deserialize(value) for key, value in data.items()}
        return cls(**kwargs)

    @staticmethod
    def register(
        class_obj: type,
        serialize: Callable[[Any], dict[str, Any]],
        deserialize: Callable[[dict[str, Any]], Any],
    ) -> None:
        """Register serialization and deserialization functions for external class.

        A warning is issued when functions were already registered for `class_obj`;
        the new functions then replace the previous ones.

        Args:
            class_obj: Class to be serialized and deserialized.
            serialize: Function that serializes instances of class `class_obj`.
            deserialize: Function that deserializes into instance of class
                `class_obj`.
        """
        cls_name = Serializable._class_to_path(class_obj)
        if cls_name in _externals:
            msg = f"Serialization functions for class {class_obj} already provided"
            warnings.warn(msg, stacklevel=2)
        _externals[cls_name] = (serialize, deserialize)

    @staticmethod
    def _get_external(
        class_obj: type,
    ) -> tuple[Callable[[Any], dict[str, Any]], Callable[[dict[str, Any]], Any]] | None:
        """Get external serialization and deserialization functions for `class_obj`.

        Returns ``None`` if they do not exist.
        """
        return _externals.get(Serializable._class_to_path(class_obj))

    @staticmethod
    def _class_to_path(class_obj: type) -> str:
        """Construct path of class as ``"<module>.<class name>"``."""
        return f"{class_obj.__module__}.{class_obj.__name__}"

    @staticmethod
    def _class_from_path(cls_path: str) -> Any:
        """Obtain class from its path.

        Args:
            cls_path: Dotted path of the form ``"<module>.<class name>"``.

        Returns:
            The class object found at `cls_path`.

        Raises:
            ValueError: If `cls_path` is not a string, has no module part, is
                relative, or names an attribute that does not exist in the module.
            ModuleNotFoundError: If the module can not be imported.
            TypeError: If the attribute found at `cls_path` is not a class.
        """
        if not isinstance(cls_path, str):
            msg = (
                "Failed to deserialize because class path is of type "
                f"{type(cls_path)} instead of str"
            )
            raise ValueError(msg)

        # `rpartition` yields an empty module name both for ".Name" and for a path
        # without any dot, so both cases get the same clear error below.
        module_name, _, class_name = cls_path.rpartition(".")
        if module_name == "":
            msg = "Failed to deserialize because module name is empty"
            raise ValueError(msg)
        if module_name.startswith("."):
            msg = f"Failed to deserialize because module name {module_name} is relative"
            raise ValueError(msg)

        # NOTE(security): the module name originates from the serialized data, and
        # importing a module executes its top-level code. Only deserialize data
        # from sources you trust.
        try:
            module = importlib.import_module(module_name)
        except ModuleNotFoundError as err:
            msg = f"Failed to deserialize because could not import module {module_name}"
            raise ModuleNotFoundError(msg) from err

        try:
            cls = getattr(module, class_name)
        except AttributeError:
            msg = (
                f"Failed to deserialize because class {class_name} "
                f"was not found in module {module_name}"
            )
            raise ValueError(msg) from None

        if not isinstance(cls, type):
            msg = f"Failed to deserialize because {class_name} is not a class"
            raise TypeError(msg)
        return cls
# NumPy support is optional: the registrations below only run when the package
# can be found.
if importlib.util.find_spec("numpy") is not None:
    import numpy as np
    from numpy.random import RandomState
    from numpy.typing import NDArray

    def _serialize_ndarray(value: NDArray[Any]) -> dict[str, str | list[Any]]:
        """Store an array as its dtype name plus its contents as nested lists."""
        dtype_name = str(value.dtype)
        return {"dtype": dtype_name, "array": value.tolist()}

    def _deserialize_ndarray(data: dict[str, Any]) -> NDArray[Any]:
        """Rebuild an array from the dict produced by ``_serialize_ndarray``."""
        target_dtype = np.dtype(data["dtype"])
        contents = data["array"]
        return np.array(contents, dtype=target_dtype)

    # Register `numpy.ndarray` as serializable
    Serializable.register(np.ndarray, _serialize_ndarray, _deserialize_ndarray)

    def _serialize_random_state(value: RandomState) -> dict[str, str | list[Any]]:
        """Store the generator state, serializing each component separately."""
        components = list(map(Serializable.serialize, value.get_state()))
        return {"state": components}

    def _deserialize_random_state(data: dict[str, Any]) -> RandomState:
        """Create a fresh ``RandomState`` and restore the stored state into it."""
        restored_state = tuple(map(Serializable.deserialize, data["state"]))
        generator = RandomState()
        generator.set_state(restored_state)
        return generator

    # Register `'numpy.random.mtrand.RandomState'` as serializable
    Serializable.register(
        RandomState,
        _serialize_random_state,
        _deserialize_random_state,
    )
def _serialize_timedelta(time: timedelta) -> dict[str, float]:
    """Store a ``timedelta`` as its total duration in seconds."""
    total = time.total_seconds()
    return {"seconds": total}


def _deserialize_timedelta(data: dict[str, float]) -> timedelta:
    """Rebuild a ``timedelta`` from the dict produced by ``_serialize_timedelta``."""
    seconds = data["seconds"]
    return timedelta(seconds=seconds)


# Register `timedelta` as serializable
Serializable.register(timedelta, _serialize_timedelta, _deserialize_timedelta)
def _serialize_complex(value: complex) -> dict[str, float]:
    """Store a complex number as its real and imaginary parts."""
    real_part = value.real
    imag_part = value.imag
    return {"real": real_part, "imag": imag_part}


def _deserialize_complex(data: dict[str, float]) -> complex:
    """Rebuild a complex number from the dict produced by ``_serialize_complex``."""
    real_part = data["real"]
    imag_part = data["imag"]
    return complex(real_part, imag_part)


# Register `complex` as serializable
Serializable.register(complex, _serialize_complex, _deserialize_complex)
def check_serializable(serializable_object: Any) -> None:
    """Test if object is serializable and can be reconstructed from its serialization.

    The object is round-tripped through a JSON string, a YAML string, a JSON file
    and a YAML file (the files are created in a temporary directory), and every
    reconstruction is compared to the original with ``==``.

    The checks raise ``AssertionError`` explicitly rather than via ``assert``
    statements, so they remain active when Python runs with ``-O``.

    Args:
        serializable_object: Object to be serialized and reconstructed.

    Raises:
        AssertionError: If the object is not Serializable, or if the reconstruction of
            the object is not equal to the original object.
    """
    if not isinstance(serializable_object, Serializable):
        msg = "Object is not Serializable"
        raise AssertionError(msg)

    def _require_round_trip(reconstructed: Any, medium: str) -> None:
        """Raise ``AssertionError`` if `reconstructed` differs from the original."""
        if not reconstructed == serializable_object:
            msg = f"Reconstruction from {medium} is not equal to the original object"
            raise AssertionError(msg)

    # Test to and from JSON
    _require_round_trip(Serializable.from_json(serializable_object.to_json()), "JSON")

    # Test to and from YAML
    _require_round_trip(Serializable.from_yaml(serializable_object.to_yaml()), "YAML")

    with tempfile.TemporaryDirectory() as temp_dir:
        # Test to and from JSON file
        json_path = Path(temp_dir) / "test_file.json"
        serializable_object.to_json_file(json_path)
        _require_round_trip(Serializable.from_json_file(json_path), "JSON file")

        # Test to and from YAML file
        yaml_path = Path(temp_dir) / "test_file.yaml"
        serializable_object.to_yaml_file(yaml_path)
        _require_round_trip(Serializable.from_yaml_file(yaml_path), "YAML file")