# Source code for tno.quantum.ml.regression.linear_regression._estimator

"""Module with definitions of quantum-inspired estimators."""

from __future__ import annotations

import logging
import warnings
from typing import Any, Callable, SupportsFloat, SupportsInt

import numpy as np
from numpy import linalg as la
from numpy.typing import ArrayLike, NDArray

from tno.quantum.ml.components import SerializableEstimator
from tno.quantum.ml.regression.linear_regression._quantum_inspired import (
    compute_ls_probs,
    estimate_lambdas,
    sample_from_b,
    sample_from_x,
)
from tno.quantum.ml.regression.linear_regression._sketching import FKV, Halko, Sketcher
from tno.quantum.utils.serialization import Serializable
from tno.quantum.utils.validation import (
    check_arraylike,
    check_int,
    check_random_state,
    check_real,
)

logger = logging.getLogger(__name__)


def _serialize_sketcher(sketcher: Sketcher) -> dict[str, Any]:
    """Serialize every instance attribute of ``sketcher``.

    Each value found in ``sketcher.__dict__`` is passed through
    :py:meth:`Serializable.serialize`; attribute names are kept as keys.

    Args:
        sketcher: sketcher instance to serialize.

    Returns:
        Mapping from attribute name to its serialized value.
    """
    attributes = sketcher.__dict__
    return {name: Serializable.serialize(attr) for name, attr in attributes.items()}


def _deserialize_sketcher(
    data: dict[str, Any], sketcher_name: str = "fkv"
) -> FKV | Halko:
    """Rebuild a sketcher from a mapping of serialized attributes.

    Every value in ``data`` is first passed through
    :py:meth:`Serializable.deserialize`. The sketcher instance is then created
    with ``__new__`` (its ``__init__`` is not run) and its ``__dict__`` is
    updated with the deserialized attributes.

    Args:
        data: mapping from attribute name to serialized value.
        sketcher_name: ``"fkv"`` to build an :py:class:`FKV` instance or
            ``"halko"`` to build a :py:class:`Halko` instance.

    Returns:
        The reconstructed sketcher.

    Raises:
        ValueError: if `sketcher_name` is neither ``"fkv"`` nor ``"halko"``.
    """
    restored = {name: Serializable.deserialize(raw) for name, raw in data.items()}

    sketcher_class: type[FKV | Halko]
    for candidate_name, candidate_class in (("fkv", FKV), ("halko", Halko)):
        if sketcher_name == candidate_name:
            sketcher_class = candidate_class
            break
    else:
        message = '`sketcher_name` should be either "fkv" or "halko"'
        raise ValueError(message)

    # Bypass __init__: the attributes come straight from the serialized data.
    instance = sketcher_class.__new__(sketcher_class)
    instance.__dict__.update(restored)
    return instance


def _deserialize_fkv(data: dict[str, Any]) -> FKV | Halko:
    """Deserialize an :py:class:`FKV` sketcher (registry callback)."""
    return _deserialize_sketcher(data, sketcher_name="fkv")


def _deserialize_halko(data: dict[str, Any]) -> FKV | Halko:
    """Deserialize a :py:class:`Halko` sketcher (registry callback)."""
    return _deserialize_sketcher(data, sketcher_name="halko")


# Register (serializer, deserializer) pairs so both sketcher types can be
# round-tripped by the generic `Serializable` machinery.
Serializable.register(FKV, _serialize_sketcher, _deserialize_fkv)
Serializable.register(Halko, _serialize_sketcher, _deserialize_halko)


class EstimatorError(Exception):
    """Error raised by the estimators defined in this module.

    For instance, it signals that a prediction method was called before ``fit``.
    """

    def __init__(self, message: str) -> None:
        """Create the error with a human-readable ``message``.

        Args:
            message: description of what went wrong; becomes ``str(error)``.
        """
        super().__init__(message)


class QILinearEstimator(SerializableEstimator):  # noqa: PLW1641
    """Quantum-inspired linear estimator.

    :py:meth:`fit` sketches the coefficient matrix `A` (with the ``"fkv"`` or
    ``"halko"`` sketcher), computes the SVD of the sketched matrix and estimates
    lambda coefficients by sampling. Afterwards, entries of the solution vector
    `x` and of the predicted `b` can be sampled with
    :py:meth:`sample_prediction_x` and :py:meth:`sample_prediction_b`.

    Attributes set by :py:meth:`fit`:
        sketcher_: the sketcher (:py:class:`FKV` or :py:class:`Halko`) used.
        w_left_, sigma_, w_right_T_: factors returned by
            ``numpy.linalg.svd(C, full_matrices=False)`` for the sketched
            matrix `C`.
        rank_: effective rank (may be lower than `rank`, see `sigma_threshold`).
        lambdas_: estimated lambda coefficients.
    """

    def __init__(  # noqa: PLR0913
        self,
        r: SupportsInt,
        c: SupportsInt,
        rank: SupportsInt,
        n_samples: SupportsInt,
        random_state: SupportsInt | None = None,
        sigma_threshold: SupportsFloat = 1e-15,
        sketcher_name: str = "fkv",
        func: Callable[[SupportsFloat], SupportsFloat] | None = None,
    ) -> None:
        """Init :py:class:`QILinearEstimator`.

        The arguments are stored as given; they are validated (and normalized)
        when :py:meth:`fit` is called.

        Args:
            r: number of rows to sample from `A`. Must be greater than `rank`.
            c: number of columns to sample from `A`. Must be greater than
                `rank`.
            rank: rank used to approximate matrix `A`. Must be at least 1.
            n_samples: number of samples to estimate inner products. Must be at
                least 2. Note: the sampling is performed from entries of `A`,
                so there are ``A.shape[0] * A.shape[1]`` possible entries.
            random_state: random seed.
            sigma_threshold: positive threshold on the singular values of the
                sketched matrix. If fewer than `rank` singular values are above
                this threshold, the effective rank ``rank_`` is lowered to that
                count and a ``RuntimeWarning`` is issued.
            sketcher_name: name of sketching method: ``"fkv"`` or ``"halko"``.
            func: function to transform singular values when estimating lambda
                coefficients. This can be used for Tikhonov regularization
                purposes. Defaults to the identity.
        """
        self.r = r
        self.c = c
        self.rank = rank
        self.n_samples = n_samples
        self.random_state = random_state
        self.sigma_threshold = sigma_threshold
        self.sketcher_name = sketcher_name
        self.func = func

    def fit(
        self,
        A: ArrayLike,
        b: ArrayLike,
    ) -> QILinearEstimator:
        """Fit data using quantum-inspired algorithm.

        Args:
            A: coefficient matrix `A` (2-D).
            b: vector `b` (1-D), with one entry per row of `A`.

        Returns:
            The fitted estimator itself.

        Raises:
            ValueError: if the length of `b` differs from the number of rows of
                `A`, or if `sketcher_name` is neither ``"fkv"`` nor ``"halko"``.
        """
        # Validate input
        A = check_arraylike(A, "A", ndim=2)
        A = np.asarray(A, dtype=np.float64)
        b = check_arraylike(b, "b", ndim=1)
        b = np.asarray(b, dtype=np.float64)
        # `b` pairs with the rows of `A`; reject mismatches up front with a clear
        # message instead of failing (or misbehaving) inside the sampling helpers.
        if A.shape[0] != b.shape[0]:
            message = (
                "`b` must have one entry per row of `A`; got "
                f"A.shape={A.shape} and b.shape={b.shape}"
            )
            raise ValueError(message)

        # Validate and normalize hyperparameters; `r` and `c` must exceed `rank`.
        self.rank = check_int(self.rank, "rank", l_bound=1, l_inclusive=True)
        self.r = check_int(self.r, "r", l_bound=self.rank, l_inclusive=False)
        self.c = check_int(self.c, "c", l_bound=self.rank, l_inclusive=False)
        self.n_samples = check_int(
            self.n_samples, "n_samples", l_bound=2, l_inclusive=True
        )
        self.sigma_threshold = check_real(
            self.sigma_threshold, "sigma_threshold", l_bound=0, l_inclusive=False
        )

        # Get random state
        rng = check_random_state(self.random_state, "random_state")

        # 1. Generate length-square probability distributions to sample from `A`
        logger.info(
            "1. Generate length-square probability distributions to sample "
            "from matrix `A`"
        )
        (
            A_ls_prob_rows,
            A_ls_prob_columns_2d,
            A_ls_prob_columns,
            _,
            A_frobenius,
        ) = compute_ls_probs(A)

        # 2. Build matrix `C`
        logger.info("2. Build matrix `C`")
        self.sketcher_: Sketcher
        if self.sketcher_name == "fkv":
            self.sketcher_ = FKV(
                A,
                self.r,
                self.c,
                A_ls_prob_rows,
                A_ls_prob_columns_2d,
                A_frobenius,
                rng,
            )
        elif self.sketcher_name == "halko":
            self.sketcher_ = Halko(
                A,
                self.r,
                self.c,
                A_ls_prob_columns,
                rng,
            )
        else:
            message = '`sketcher_name` should be either "fkv" or "halko"'
            raise ValueError(message)
        C = self.sketcher_.right_project(self.sketcher_.left_project(A))

        # 3. Compute the SVD of `C`
        logger.info("3. Compute the SVD of `C`")
        self.w_left_, self.sigma_, self.w_right_T_ = la.svd(C, full_matrices=False)

        # Recompute rank: numpy returns singular values in descending order, so
        # the first `rank_` of them are exactly those above `sigma_threshold`.
        self.rank_ = self.rank
        rank_recomputed = int(np.count_nonzero(self.sigma_ > self.sigma_threshold))
        if rank_recomputed < self.rank:
            message = f"Desired rank: {self.rank}; recomputed: {rank_recomputed}"
            warnings.warn(message, RuntimeWarning, stacklevel=2)
            logger.warning(message)
            self.rank_ = rank_recomputed

        # 4. Estimate lambda coefficients (identity transform when `func` is None)
        logger.info("4. Estimate lambda coefficients")
        func: Callable[[SupportsFloat], SupportsFloat] = (
            self.func if self.func is not None else (lambda arg: arg)
        )
        self.lambdas_ = estimate_lambdas(
            A,
            b,
            self.n_samples,
            self.rank_,
            self.w_left_,
            self.sigma_,
            self.sketcher_,
            A_ls_prob_rows,
            A_ls_prob_columns_2d,
            A_frobenius,
            rng,
            func,
        )

        return self

    def _check_is_fitted(self) -> None:
        """Check if the `fit` method has been called.

        Raises:
            EstimatorError: if any attribute set by :py:meth:`fit` is missing.
        """
        for attribute_name in (
            "sketcher_",
            "w_left_",
            "sigma_",
            "w_right_T_",
            "rank_",
            "lambdas_",
        ):
            if not hasattr(self, attribute_name):
                message = "Please call `fit` before making predictions"
                raise EstimatorError(message)

    def sample_prediction_x(
        self,
        A: ArrayLike,
        n_entries_x: SupportsInt,
    ) -> tuple[NDArray[np.uint32], NDArray[np.float64]]:
        """Samples predictions of `x` using quantum-inspired model.

        The random generator is re-derived from `random_state` on every call.

        Args:
            A: coefficient matrix `A`.
            n_entries_x: number of entries to be sampled from the solution
                vector `x`. Must be at least 1.

        Returns:
            Tuple ``(indices, values)``: the sampled indices of `x` and the
            corresponding predicted values.

        Raises:
            EstimatorError: if :py:meth:`fit` has not been called.
        """
        self._check_is_fitted()
        rng = check_random_state(self.random_state, "random_state")
        A = check_arraylike(A, "A", ndim=2)
        A = np.asarray(A, dtype=np.float64)
        n_entries_x = check_int(n_entries_x, "n_entries_x", l_bound=1, l_inclusive=True)

        logger.info("Sample predicted `x`")

        # Compute `omega` from the first `rank_` left singular vectors
        omega = self.w_left_[:, : self.rank_] @ (
            self.lambdas_ / self.sigma_[: self.rank_]
        )
        omega_norm = float(la.norm(omega))

        # Sample entries of solution vector `x`
        sampled_indices_x = np.zeros(n_entries_x, dtype=np.uint32)
        sampled_x = np.zeros(n_entries_x)
        for t in range(n_entries_x):
            sampled_indices_x[t], sampled_x[t] = sample_from_x(
                A,
                self.sketcher_,
                omega,
                omega_norm,
                rng,
            )
            if (t + 1) % 100 == 0:
                logger.info("---%s entries sampled out of %s", t + 1, n_entries_x)

        return sampled_indices_x, sampled_x

    def sample_prediction_b(
        self,
        A: ArrayLike,
        n_entries_b: SupportsInt,
    ) -> tuple[NDArray[np.uint32], NDArray[np.float64]]:
        """Sample predictions of `b` using quantum-inspired model.

        The random generator is re-derived from `random_state` on every call.

        Args:
            A: coefficient matrix `A`.
            n_entries_b: number of entries to be sampled from the predicted
                `b`. Must be at least 1.

        Returns:
            Tuple ``(indices, values)``: the sampled indices of `b` and the
            corresponding predicted values.

        Raises:
            EstimatorError: if :py:meth:`fit` has not been called.
        """
        self._check_is_fitted()
        rng = check_random_state(self.random_state, "random_state")
        A = check_arraylike(A, "A", ndim=2)
        A = np.asarray(A, dtype=np.float64)
        n_entries_b = check_int(n_entries_b, "n_entries_b", l_bound=1, l_inclusive=True)

        logger.info("Sample predicted `b`")

        # Compute `phi` from the first `rank_` right singular vectors
        phi = self.w_right_T_.T[:, : self.rank_] @ self.lambdas_
        phi_norm = float(la.norm(phi))

        # Sample entries of `b`
        sampled_indices_b = np.zeros(n_entries_b, dtype=np.uint32)
        sampled_b = np.zeros(n_entries_b)
        for t in range(n_entries_b):
            sampled_indices_b[t], sampled_b[t] = sample_from_b(
                A,
                self.sketcher_,
                phi,
                phi_norm,
                rng,
            )
            if (t + 1) % 100 == 0:
                logger.info("---%s entries sampled out of %s", t + 1, n_entries_b)

        return sampled_indices_b, sampled_b

    def __eq__(self, other: object) -> bool:
        """Check for equality for serialization purposes.

        Two estimators are equal when they have the same attribute names and
        every attribute compares equal; NumPy arrays are compared with
        :py:func:`numpy.array_equal`, other values with ``!=``.
        """
        if not isinstance(other, QILinearEstimator):
            return NotImplemented

        self_attrs = self.__dict__
        other_attrs = other.__dict__

        if set(self_attrs.keys()) != set(other_attrs.keys()):
            return False

        for key, self_value in self_attrs.items():
            other_value = other_attrs[key]
            is_self_np = isinstance(self_value, np.ndarray)
            is_other_np = isinstance(other_value, np.ndarray)
            if is_self_np and is_other_np:
                if not np.array_equal(self_value, other_value):
                    return False
            # Exactly one side is an array, or plain values differ.
            elif is_self_np != is_other_np or self_value != other_value:
                return False

        return True