# Source code for tno.quantum.ml.datasets._utils

"""Some basic utility functions for datasets."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike, NDArray
from sklearn.model_selection import train_test_split
from tno.quantum.utils.validation import check_arraylike, check_int, check_path

if TYPE_CHECKING:
    from pathlib import Path


def _pre_process_data(
    X: NDArray[Any], y: NDArray[Any], n_features: int, n_classes: int
) -> tuple[NDArray[Any], NDArray[Any]]:
    """Restrict a dataset to its leading features and lowest class labels.

    The labels are first cast to the dtype of `X`. Only samples whose (cast)
    label is strictly smaller than `n_classes` are kept, and of those samples
    only the first `n_features` feature columns are retained.

    Args:
        X: Feature matrix, indexed as ``X[sample, feature]``.
        y: Target labels, one entry per sample of `X`.
        n_features: Number of leading feature columns to keep, at least 1.
        n_classes: Exclusive upper bound on the labels to keep, at least 1.

    Returns:
        ``Tuple`` of the reduced `X` and `y`, where `y` has the same dtype as
        `X`.
    """
    # Both counts are validated as integers with lower bound 1.
    n_features = check_int(n_features, name="n_features", l_bound=1)
    n_classes = check_int(n_classes, name="n_classes", l_bound=1)

    labels = y.astype(X.dtype)
    # Boolean mask over the samples; it is applied to both arrays so that the
    # rows of ``X`` and the entries of ``y`` stay aligned.
    keep = labels < n_classes
    return X[keep, :n_features], labels[keep]


def _safe_train_test_split(
    X: ArrayLike,
    y: ArrayLike,
    test_size: int | float | None = 0.25,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[NDArray[Any], NDArray[Any], NDArray[Any], NDArray[Any]]:
    """Wrapper around sklearn's `train_test_split` that allows `test_size=0`.

    If test_size == 0, returns all of (X, y) as training data and empty arrays for
    the validation data.

    Args:
        X: Dataset samples.
        y: Dataset labels.
        test_size: Size of the test dataset.
        random_state: random state for reproducibility.

    Returns:
        A tuple containing ``X_training``, ``y_training``, ``X_validation`` and
        ``y_validation`` of the dataset.
    """
    X = np.asarray(X)
    y = np.asarray(y)

    if test_size == 0:
        return (
            X,
            y,
            np.empty((0, *X.shape[1:]), dtype=X.dtype),
            np.empty((0, *y.shape[1:]), dtype=y.dtype),
        )

    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    return X_train, y_train, X_val, y_val


def show_images_dataset(
    images: ArrayLike,
    labels: ArrayLike | None = None,
    save_path: None | Path = None,
    max_per_row: int = 4,
    *,
    show: bool = True,
) -> None:
    r"""Display greyscale images in a grid of subplots.

    Every image is drawn on its own axis with ``cmap="gray"``, ``vmin=0`` and
    ``vmax=255``, overlaid with a red grid on the minor ticks. If labels are
    given, each image is titled with its label. Unused axes in the last row
    are switched off.

    Example usage:

    >>> from tno.quantum.ml.datasets import get_bars_and_stripes_dataset, show_images_dataset
    >>> X, y, _, _ = get_bars_and_stripes_dataset(6, (8, 8), noise_size=30, test_size=0)
    >>> show_images_dataset(X, y, max_per_row=3)  # doctest: +SKIP

    Args:
        images: Greyscale images to plot, as a 3-dimensional array-like
            indexed as ``images[image, row, column]``.
        labels: Corresponding labels of each image, as a 1-dimensional
            array-like. If ``None``, no titles are added.
        save_path: if provided, the path where figure is saved to.
        max_per_row: Maximum number of images in a single row, at least 1.
        show: whether to call plt.show() at the end.

    Raises:
        ValueError: If labels is provided but has not same length as images.
    """  # noqa: E501
    images = check_arraylike(images, "images", ndim=3)
    # ``labels`` is optional: only validate it when it was actually supplied,
    # so that the default ``None`` is never handed to the 1-D array check and
    # the ``labels is not None`` guards below keep their meaning.
    if labels is not None:
        labels = check_arraylike(labels, "labels", ndim=1)
    max_per_row = check_int(max_per_row, "max_per_row", l_bound=1)

    number_of_images = len(images)
    if labels is not None and len(labels) != number_of_images:
        error_msg = (
            f"labels length {len(labels)} does not match "
            f"number of images {number_of_images}."
        )
        raise ValueError(error_msg)

    # Compute rows and columns (ceiling division for the number of rows)
    ncols = min(number_of_images, max_per_row)
    nrows = number_of_images // max_per_row
    if number_of_images % max_per_row:
        nrows += 1

    # ``squeeze=False`` always yields a 2-D array of axes, so it can be
    # flattened uniformly regardless of the grid shape.
    _, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False)
    axes = axes.ravel()

    for idx, (ax, img) in enumerate(zip(axes, images, strict=False)):
        rows, cols = img.shape
        ax.imshow(img, cmap="gray", vmin=0, vmax=255)

        # Draw grid: minor ticks are placed at half-integer positions and the
        # grid is drawn on those minor ticks
        ax.set_xticks(np.arange(-0.5, cols, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, rows, 1), minor=True)
        ax.grid(which="minor", color="red", linestyle="-", linewidth=0.5)

        # Remove ticks
        ax.tick_params(
            which="both", bottom=False, left=False, labelbottom=False, labelleft=False
        )

        # Add label as title
        if labels is not None:
            ax.set_title(f"label={labels[idx]}")

    # Hide the axes of the grid that did not receive an image
    for ax in axes[number_of_images:]:
        ax.axis("off")

    if save_path is not None:
        save_path = check_path(save_path, "save_path")
        plt.savefig(save_path, bbox_inches="tight")

    if show:
        plt.show()