# Source code for tno.quantum.ml.components._serialization

"""Class for serialization of sklearn compatible estimators."""

from __future__ import annotations

import sys
from typing import Any

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

from tno.quantum.utils.serialization import Serializable, check_serializable


class SerializableEstimator(Serializable):
    """Base class that makes sklearn-style estimators serializable.

    On top of what ``Serializable`` stores, every instance attribute whose
    name ends with an underscore is written to the serialized dictionary and
    set again on the instance when deserializing.
    """

    @override
    def _serialize(self) -> dict[str, Any]:
        """Return the dictionary representation of this estimator.

        Returns:
            The dictionary produced by ``Serializable._serialize`` extended
            with one serialized entry per instance attribute whose name ends
            with an underscore.
        """
        # The base implementation takes care of the constructor arguments.
        serialized = Serializable._serialize(self)  # noqa: SLF001

        # Add the derived attributes (trailing underscore) on top of that.
        serialized.update(
            {
                name: Serializable.serialize(attribute)
                for name, attribute in vars(self).items()
                if name.endswith("_")
            }
        )
        return serialized

    @override
    @classmethod
    def _deserialize(cls, data: dict[str, Any]) -> Any:
        """Rebuild an estimator from its dictionary representation.

        Args:
            data: Mapping from attribute name to serialized value. Keys
                without a trailing underscore are passed to the constructor
                as keyword arguments; keys with a trailing underscore are set
                as attributes on the constructed instance.

        Returns:
            The reconstructed instance of ``cls``.
        """
        init_kwargs: dict[str, Any] = {}
        derived_names: list[str] = []

        # Constructor arguments are deserialized right away; the names with a
        # trailing underscore are only collected, because they can be set
        # only once the instance exists.
        for name in data:
            if name.endswith("_"):
                derived_names.append(name)
            else:
                init_kwargs[name] = Serializable.deserialize(data[name])

        estimator = cls(**init_kwargs)

        for name in derived_names:
            setattr(estimator, name, Serializable.deserialize(data[name]))

        return estimator
def check_estimator_serializable(serializable_object: Any) -> None:
    """Test if object is serializable and can be reconstructed from its serialization.

    The object must be a ``SerializableEstimator``; the round-trip check
    itself is delegated to ``check_serializable``.

    Args:
        serializable_object: Object to be serialized and reconstructed.

    Raises:
        AssertionError: If the object is not SerializableEstimator, or if the
            reconstruction of the object is not equal to the original object.
    """
    # Raise explicitly instead of using an ``assert`` statement: assert
    # statements are removed when Python runs with ``-O``, which would
    # silently drop this documented type check.
    if not isinstance(serializable_object, SerializableEstimator):
        msg = "Object is not SerializableEstimator"
        raise AssertionError(msg)

    check_serializable(serializable_object)