Source code for qkd_key_rate.base.permutations

"""Base class for Permutations object."""
from __future__ import annotations

from typing import List

import numpy as np


[docs]class Permutations: """Permutations object""" def __init__(self, permutations: List[List[int]]): self.permutations = permutations self.inverted_permutations = self.calculate_inverted_permutations(permutations) self.number_of_passes = len(permutations)
[docs] def __add__(self, other: Permutations) -> Permutations: """Combine two permutations""" return Permutations(self.permutations + other.permutations)
[docs] def __len__(self) -> int: "Return number of passes" return self.number_of_passes
def __eq__(self, other: object) -> bool: if not isinstance(other, Permutations): return NotImplemented return self.permutations == other.permutations def __getitem__(self, pass_number: int) -> List[int]: return self.permutations[pass_number]
[docs] def shorten_pass(self, pass_idx: int, max_length: int) -> None: """In some protocols, the message length decreases, this function adjusts the permutations accordingly, by discarding large indices. Args: pass_idx: Index of the permutation to shorten max_length: New message length, and maximum of the entries in the permutation (value itself not included) """ self.permutations[pass_idx] = [ i for i in self.permutations[pass_idx] if i < max_length ] self.inverted_permutations[pass_idx] = list( np.argsort(self.permutations[pass_idx]) )
[docs] @staticmethod def calculate_inverted_permutations( permutations: List[List[int]], ) -> List[List[int]]: """Invert every permutation in a list of permutations""" return [list(np.argsort(permutation)) for permutation in permutations]
[docs] @classmethod def random_permutation( cls, number_of_passes: int, message_size: int ) -> Permutations: """Generate a random Permutations object Args: number_of_passes: Total number of iterations message_size: Size of the message for which to calculate the parity strategy """ permutations = [ list(np.random.permutation(message_size)) for _ in range(number_of_passes) ] return cls(permutations)