Source code for qkd_key_rate.protocols.classical.winnow

"""Classes to perform a Winnow error correction protocol.

The Winnow error correction protocol is based on Hamming codes. An advantage of the protocol
is that it requires less communication than other error correction protocols. The protocol
however might introduce errors in specific cases. With every communication, the Winnow protocol
leaks information to potential eavesdroppers. This can be overcome by discarding message bits
equal to the amount of information leaks, thereby achieving privacy maintenance.

Typical usage example:

    .. code-block:: python

        import numpy as np

        from tno.quantum.communication.qkd_key_rate.base import Message, Permutations, Schedule
        from tno.quantum.communication.qkd_key_rate.protocols.classical.winnow import (
            WinnowCorrector,
            WinnowReceiver,
            WinnowSender,
        )

        error_rate = 0.05
        message_length = 10000
        input_message = Message.random_message(message_length=message_length)
        error_message = Message(
            [x if np.random.rand() > error_rate else 1 - x for x in input_message]
        )
        schedule = Schedule.schedule_from_error_rate(error_rate=error_rate)
        number_of_passes = np.sum(schedule.schedule)
        permutations = Permutations.random_permutation(
            number_of_passes=number_of_passes, message_size=message_length
        )

        alice = WinnowSender(
            message=input_message, permutations=permutations, schedule=schedule
        )
        bob = WinnowReceiver(
            message=error_message, permutations=permutations, schedule=schedule
        )
        corrector = WinnowCorrector(alice=alice, bob=bob)
        summary = corrector.correct_errors()
"""
from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional

import numpy as np

from tno.quantum.communication.qkd_key_rate.base import (
    Corrector,
    CorrectorOutputBase,
    Message,
    Permutations,
    ReceiverBase,
    Schedule,
    SenderBase,
)


[docs]class WinnowSender(SenderBase): """This class encodes all functions available to both sender and receiver. It keeps track of the number of exposed bits and can compute syndromes and parities. Furthermore, it keeps track of the blocks with errors """
[docs] def __init__( self, message: Message, permutations: Permutations, schedule: Schedule, name: Optional[str] = None, ) -> None: """ Args: message: Input message of the sender party permutations: List containing permutations for each pass name: Name of the sender party """ super().__init__(message=message, permutations=permutations, name=name) self.schedule = schedule self.number_of_bad_blocks = 0 self.block_size = 8 self.number_of_blocks = 0 self.maximum_number_of_communication_rounds = 0 self.number_of_exposed_bits = 0 self.number_of_passes = permutations.number_of_passes # In Winnow, some bits are discarded. The net_exposed_bits is # number of exposed bits, minus number of discarded bits self.net_exposed_bits = 0 self.syndrome_length = 3 self.removed_bits: List[int] = [] self.syndrome_array: List[int] = [ 0 for _ in range(int(np.floor(self.message.length / 8))) ] # Stores the indices of blocks which contain errors self.bad_blocks_array: List[int] = [ 0 for _ in range(int(np.floor(self.message.length / 8))) ] self.parity_string: List[List[int]] = [[]] * self.number_of_passes self.parity_check_matrix = np.zeros([10, 1023], dtype=int) self.transcript = "" self.first_pass()
[docs] def create_parity_check_matrix(self) -> None: """Creates a parity check matrix. This matrix is used to encode the bit strings. """ size = 1 << self.syndrome_length for i in range(self.syndrome_length): for j in range(1, size): self.parity_check_matrix[i, j - 1] = int(j / (1 << i)) & 0x1
[docs] def get_parity(self, index_start: int, index_end: int) -> int: """Get the parity of a specific message part between two indices. Args: index_start: Start index of the message index_end: End index of the message Returns: Parity of substring """ number_of_ones = 0 for i in range(index_start, index_end): if self.message[i] == 1: number_of_ones += 1 return number_of_ones % 2
[docs] def build_parity_string(self) -> None: """Builds a parity string for all blocks.""" for i in range(self.number_of_blocks): index_start = i * self.block_size index_end = index_start + self.block_size parity = self.get_parity(index_start, index_end) self.parity_string[i] = parity self.number_of_exposed_bits += 1 self.net_exposed_bits += 1
[docs] def discard_parity_bits(self) -> None: """The first bit of every parity block is discarded.""" old_index = 0 counter = -1 while old_index < self.message.length: if ((old_index % self.block_size) == 0) and ( old_index != (self.number_of_blocks * self.block_size) ): # Is it the first bit of a block and is it not the last block counter += 1 self.net_exposed_bits -= 1 self.message.pop(old_index - counter) old_index += 1 self.block_size -= 1 self.transcript += ( f"\tBoth discard the parity bits for pass {self.schedule.pass_number}.\n" )
[docs] def get_syndrome(self, index_block: int) -> int: """Computes the syndrome of a block. Both parties compute their syndrome individually, hence, no communication is needed here. """ if index_block > self.number_of_blocks: print("Illegal block number. Returning block_size + 1 for new syndrome.\n") return self.block_size + 1 placeholder = 0 new_syndrome = 0 # Computer he highest order bit of the syndrome first and then work down for i in range(self.syndrome_length - 1, -1, -1): new_syndrome <<= 1 # Multiply the block by the (i-1)-th row of the parity check matrix # and add it to the syndrome for j in range(self.block_size): placeholder += ( self.parity_check_matrix[i, j] * self.message[index_block * self.block_size + j] ) placeholder &= 0x1 new_syndrome += placeholder placeholder = 0 self.number_of_exposed_bits += self.syndrome_length self.net_exposed_bits += self.syndrome_length return new_syndrome
[docs] def disagreeing_block_parities(self, alice: WinnowSender) -> None: """Finds the disagreeing block parities. The found parities of both parties are compared. This can be done with two communication rounds (one both ways). Afterwards, both separately process the results. Args: alice: The sending party """ self.maximum_number_of_communication_rounds += 1 self.transcript += ( f"Bob sends his parity string to alice for pass {self.schedule.pass_number}" ". Alice compares the string with hers and keeps track of the disagreeing" " blocks.\n" ) counter_for_bad_blocks = 0 for i in range(self.number_of_blocks): if alice.parity_string[i] != self.parity_string[i]: # The parities disagree, save the block-index as bad block. self.bad_blocks_array[counter_for_bad_blocks] = i alice.bad_blocks_array[counter_for_bad_blocks] = i counter_for_bad_blocks += 1 self.number_of_bad_blocks = counter_for_bad_blocks alice.number_of_bad_blocks = counter_for_bad_blocks
[docs] def discard_syndrome_bits(self) -> None: """Discards syndrome bits. Bits at indices $2^j-1$ are removed. These correspond to the linearly independent columns of the parity check matrix. In this function the number of bad blocks is known. No communication is needed to discard syndrome bits. """ counter_for_error_blocks = 0 removed_bits = [] for index_block in range(self.number_of_blocks): # If we have not hit all bad blocks, and the counter is at a bad block if ( counter_for_error_blocks < self.number_of_bad_blocks and self.bad_blocks_array[counter_for_error_blocks] == index_block ): power = 0 counter_for_error_blocks += 1 offset_counter = -1 for index_bit in range(self.block_size): if (index_bit + 1) == (1 << power): # Discard bits if they are at a location with index a # power of 2 - 1 power += 1 offset_counter += 1 self.net_exposed_bits -= 1 removed_bits.append(index_bit - offset_counter) self.message.pop(index_bit - offset_counter) self.transcript += ( f"\tBoth discard the syndrome bits for pass {self.schedule.pass_number}.\n" )
[docs] def build_syndrome_string(self, alice: WinnowSender) -> None: """Create a syndrome string for all disagreeing blocks. Computes the syndrome for blocks with disagreeing parity. Args: alice: The sending party """ self.disagreeing_block_parities(alice) for i in range(self.number_of_bad_blocks): alice.syndrome_array[i] = self.get_syndrome(alice.bad_blocks_array[i])
[docs] def first_pass(self) -> None: """First pass with initializations and parity determination.""" # Note that the incomplete last block is not included in each pass self.number_of_blocks = int(np.floor(self.message.length / self.block_size)) self.parity_string = [0] * self.number_of_blocks self.create_parity_check_matrix() self.permute_buffer() self.build_parity_string() self.discard_parity_bits()
[docs] def next_pass(self) -> None: """Performs the necessary computations to prepare for the next pass, including permuting the message and creating a new parity string.""" i = self.schedule.next_pass() self.syndrome_length = i + 3 self.block_size = 1 << self.syndrome_length # Note that the incomplete last block is not included in each pass self.number_of_blocks = int(np.floor(self.message.length / self.block_size)) self.create_parity_check_matrix() self.permute_buffer() self.build_parity_string() self.discard_parity_bits()
[docs] def permute_buffer(self) -> None: """Permutes the message string.""" self.permutations.shorten_pass(self.schedule.pass_number, self.message.length) self.message.apply_permutation(self.permutations[self.schedule.pass_number])
[docs]class WinnowReceiver(WinnowSender, ReceiverBase): """This class encodes all functions only available to the receiver. The receiver is assumed to have a string with errors and is thus assumed to correct the errors. """
[docs] def __init__( self, message: Message, permutations: Permutations, schedule: Schedule, name: Optional[str] = None, ) -> None: """ Args: message: Input message of the sender party permutations: List containing permutations for each pass name: Name of the receiver party """ super().__init__( message, name=name, permutations=permutations, schedule=schedule )
[docs] def fix_errors_with_syndrome(self, alice: WinnowSender) -> None: """Corrects errors using the syndrome strings of alice and bob Args: alice: The sending party """ syndrome_alice = alice.syndrome_array for i in range(alice.number_of_bad_blocks): syndrome_bob = alice.get_syndrome(alice.bad_blocks_array[i]) disagreeing_syndrome_bit = ( syndrome_alice[i] ^ syndrome_bob ) # XOR the two syndromes if disagreeing_syndrome_bit == 0: # Erroneous bit was already discarded continue self.correct_individual_error( alice.bad_blocks_array[i] * self.block_size + disagreeing_syndrome_bit - 1 ) self.transcript += "bob computes the syndromes for his bit string and \ accordingly corrects bits in his own bit string based on the difference with \ the syndromes of alice.\n" self.discard_syndrome_bits() alice.discard_syndrome_bits()
[docs] def correct_errors(self, alice: WinnowSender) -> None: """The main routine, finds all errors and corrects them. It is assumed that Alice and Bob use one communication round to agree on the used permutations. Afterwards, they use two communication rounds per iteration to communicate the syndromes. Args: alice: The sending party """ assert isinstance(alice, WinnowSender) alice.schedule = deepcopy(self.schedule) self.maximum_number_of_communication_rounds += 1 if (len(self.schedule) - len(self.permutations)) >= 0: # Add permutations, if there are not enough for the whole schedule self.permutations += Permutations.random_permutation( number_of_passes=len(self.schedule) - len(self.permutations) + 1, message_size=self.message.length, ) self.transcript += ( "Permutations for bit string shared between alice and bob.\n" ) alice.permutations = deepcopy(self.permutations) number_of_remaining_passes = 1 while number_of_remaining_passes > 0: self.maximum_number_of_communication_rounds += 1 self.transcript += ( "alice sends the syndromes for the disagreeing blocks of pass" f"{self.schedule.pass_number} to bob, as well as the block indices." "If there are no disagreeing blocks, both enter the next pass.\n" ) self.build_syndrome_string(alice) # If we have blocks with disagreeing syndromes, correct bits accordingly if self.number_of_bad_blocks != 0: self.fix_errors_with_syndrome(alice) else: self.transcript += ( f"\tNo disagreeing parities in pass {self.schedule.pass_number}" ", continue.\n" ) self.next_pass() alice.next_pass() number_of_remaining_passes = self.schedule.remaining_passes
[docs]@dataclass class WinnowCorrectorOutput(CorrectorOutputBase): """Data class for Winnow Corrector output""" schedule: List[int]
[docs]class WinnowCorrector(Corrector): """ Winnow corrector """ def __init__(self, alice: WinnowSender, bob: WinnowReceiver): super().__init__(alice=alice, bob=bob) assert self.alice.permutations == self.bob.permutations
[docs] def summary(self) -> WinnowCorrectorOutput: """ Calculate a summary object for the error correction - original message - corrected message - error rate (before and after correction) - number_of_exposed_bits - key_reconciliation_rate - protocol specific parameters """ corrector_output = WinnowCorrectorOutput( input_alice=self.alice.original_message, output_alice=self.alice.message, input_bob=self.bob.original_message, output_bob=self.bob.message, input_error=self.calculate_error_rate( self.alice.original_message, self.bob.original_message ), output_error=self.calculate_error_rate( self.alice.message, self.bob.message ), output_length=self.alice.message.length, number_of_exposed_bits=self.bob.net_exposed_bits, key_reconciliation_rate=self.calculate_key_reconciliation_rate( exposed_bits=True ), number_of_communication_rounds=self.bob.maximum_number_of_communication_rounds, schedule=self.bob.schedule, ) return corrector_output