# Source code for tno.quantum.communication.qkd_key_rate.classical._permutations

"""Base class for Permutations object."""

from __future__ import annotations

import numpy as np
from numpy.random import RandomState

from tno.quantum.utils.validation import check_random_state


class Permutations:
    """Collection of index permutations, one per pass, with their inverses.

    Attributes are kept in sync by the methods of this class:

    * ``permutations``: one list of indices per pass.
    * ``inverted_permutations``: the ``argsort`` of every entry of
      ``permutations`` (for a true permutation this is its inverse).
    * ``number_of_passes``: number of permutations given at construction.
    """

    def __init__(self, permutations: list[list[int]]) -> None:
        """Init Permutations.

        Args:
            permutations: list of permutations, one per pass.
        """
        # Shallow-copy the outer list: ``shorten_pass`` replaces entries of
        # ``self.permutations`` and must not rewrite the caller's list.
        self.permutations = list(permutations)
        self.inverted_permutations = self.calculate_inverted_permutations(
            self.permutations
        )
        self.number_of_passes = len(self.permutations)

    def __add__(self, other: Permutations) -> Permutations:
        """Combine two permutations.

        Args:
            other: Permutations whose passes are appended after those of
                ``self``.

        Returns:
            A new object holding the passes of ``self`` followed by the passes
            of ``other``.
        """
        return Permutations(self.permutations + other.permutations)

    def __len__(self) -> int:
        """Return number of passes."""
        return self.number_of_passes

    def __eq__(self, other: object) -> bool:
        """Permutation equality.

        Two objects are equal when their lists of permutations are equal.
        Comparison with any other type is left to the other operand.
        """
        if not isinstance(other, Permutations):
            return NotImplemented
        return self.permutations == other.permutations

    def __getitem__(self, pass_number: int) -> list[int]:
        """Return permutation for specific pass.

        Args:
            pass_number: index of the pass number

        Returns:
            The permutation corresponding to pass number.
        """
        return self.permutations[pass_number]

    def shorten_pass(self, pass_idx: int, max_length: int) -> None:
        """Shorten message length.

        In some protocols, the message length decreases, this function adjusts
        the permutations accordingly, by discarding large indices. The relative
        order of the remaining indices is preserved and the stored inverse for
        this pass is recomputed.

        Args:
            pass_idx: Index of the permutation to shorten
            max_length: New message length, and maximum of the entries in the
                permutation (value itself not included)
        """
        shortened = [
            index for index in self.permutations[pass_idx] if index < max_length
        ]
        self.permutations[pass_idx] = shortened
        # ``tolist`` converts NumPy integers to built-in ``int``, matching the
        # element type produced by ``calculate_inverted_permutations``.
        self.inverted_permutations[pass_idx] = np.argsort(shortened).tolist()

    @staticmethod
    def calculate_inverted_permutations(
        permutations: list[list[int]],
    ) -> list[list[int]]:
        """Invert every permutation in a list of permutations.

        Args:
            permutations: list of permutations to invert.

        Returns:
            For every permutation, the list of positions that would sort it,
            as built-in ``int`` values.
        """
        # ``tolist`` (rather than ``list``) yields built-in ``int`` elements
        # instead of NumPy integer scalars, as the return annotation promises.
        return [np.argsort(permutation).tolist() for permutation in permutations]

    @classmethod
    def random_permutation(
        cls,
        number_of_passes: int,
        message_size: int,
        random_state: int | RandomState | None = None,
    ) -> Permutations:
        """Generate a random Permutations object.

        Args:
            number_of_passes: Total number of iterations
            message_size: Size of the message for which to calculate the parity
                strategy
            random_state: Random state for reproducibility. Defaults to
                ``None``.

        Returns:
            Object with ``number_of_passes`` random permutations of
            ``range(message_size)``.
        """
        rng = check_random_state(random_state, "random_state")
        permutations = [
            rng.permutation(message_size).tolist() for _ in range(number_of_passes)
        ]
        return cls(permutations)