# Source code for tno.quantum.communication.qkd_key_rate.classical._message

"""Base class for Message object."""

from __future__ import annotations

from collections.abc import Iterable, Iterator
from dataclasses import dataclass

from numpy.random import RandomState

from tno.quantum.utils.validation import check_binary, check_random_state, check_real


@dataclass(init=False)
class Message:
    """Message object containing binary bits.

    ``@dataclass(init=False)`` keeps the hand-written ``__init__`` below while
    still generating ``__repr__`` and ``__eq__`` from the ``message`` field.
    """

    message: list[int]
    """Bits of the message (binary values, validated with ``check_binary``)."""

    def __init__(self, message: Iterable[int | str]) -> None:
        """Init :py:class:`Message`.

        Args:
            message: The message, iterable object with binary items.

        Raises:
            TypeError: If `message` contains items that can't be converted to a binary.
            ValueError: If `message` contains items that can't be converted to a binary.
        """
        self.message = [
            check_binary(value, f"message[{i}]") for i, value in enumerate(message)
        ]

    @property
    def length(self) -> int:
        """Length of message."""
        return len(self.message)

    def __getitem__(self, key: int) -> int:
        """Return value of message for specific index.

        Args:
            key: The index at which the value should be returned.

        Returns:
            The bit stored at index `key`.
        """
        return self.message[key]

    def __setitem__(self, key: int, value: int) -> None:
        """Set key of message to specific value.

        The value goes through the same binary check as the constructor, so
        the message can only ever hold binary bits (e.g. ``"1"`` is stored as
        a converted bit rather than as a string).

        Args:
            key: The index at which the value should be set.
            value: The value to be inserted at the specified index.

        Raises:
            TypeError: If `value` can't be converted to a binary.
            ValueError: If `value` can't be converted to a binary.
        """
        self.message[key] = check_binary(value, f"message[{key}]")

    def __bytes__(self) -> bytes:
        """Bytes representation of message (UTF-8 encoded string of bits)."""
        # ``str.encode`` already returns ``bytes``; no extra wrapping needed.
        return "".join(str(bit) for bit in self.message).encode("utf-8")

    def __str__(self) -> str:
        """String representation of message.

        At most the first 50 bits are shown; ``"..."`` is appended when the
        message is longer than that.
        """
        # Slice before joining so long messages are not fully stringified
        # just to display a 50-character preview.
        res = "".join(str(bit) for bit in self.message[:50])
        if self.length > 50:
            res += "..."
        return res

    def pop(self, index: int = -1) -> int:
        """Remove bit at a specific index from message.

        Args:
            index: Index of the bit to remove. Defaults to the last bit.

        Returns:
            The removed bit.
        """
        return self.message.pop(index)

    def apply_permutation(self, permutation: list[int]) -> None:
        """Apply a permutation to the message in place.

        After the call, bit ``j`` of the message is the bit that was
        previously at index ``permutation[j]``.

        Args:
            permutation: The permutation that is applied; it must contain each
                index ``0, ..., length - 1`` exactly once.

        Raises:
            ValueError: If message is incompatible with permutation.
        """
        if self.length != len(permutation):
            error_msg = "Message is incompatible with permutation."
            raise ValueError(error_msg)
        # With equal lengths, set equality guarantees a bijection. Repeated or
        # out-of-range indices would otherwise silently duplicate/drop bits.
        if set(permutation) != set(range(self.length)):
            error_msg = (
                "Permutation must contain each index of the message exactly once."
            )
            raise ValueError(error_msg)
        self.message = [self.message[i] for i in permutation]

    @classmethod
    def random_message(
        cls, message_length: int, random_state: int | RandomState | None = None
    ) -> Message:
        """Generate a random message.

        Each bit is drawn from ``random_state.randint(2, ...)``, i.e. 0 or 1.

        Args:
            message_length: Length of random message
            random_state: Random state for reproducibility. Defaults to ``None``.

        Returns:
            random message
        """
        random_state = check_random_state(random_state, "random_state")
        return cls(list(random_state.randint(2, size=message_length)))

    def __iter__(self) -> Iterator[int]:
        """Create iterator for message bits."""
        return iter(self.message)

    def apply_errors(
        self, error_rate: float, random_state: int | RandomState | None = None
    ) -> Message:
        """Apply errors to message.

        Each bit is flipped independently with probability `error_rate`. The
        current message is not modified; a new :py:class:`Message` is
        returned.

        Args:
            error_rate: probability that an error occurs, in ``[0, 1]``.
            random_state: Random state for reproducibility. Defaults to ``None``

        Returns:
            Message to which errors are applied.

        Raises:
            ValueError: If `error_rate` is not within ``[0, 1]``.
        """
        error_rate = check_real(error_rate, "error_rate", l_bound=0, u_bound=1)
        random_state = check_random_state(random_state, "random_state")
        # ``rand()`` samples [0, 1), so ``rand() < error_rate`` holds with
        # probability exactly ``error_rate``: a rate of 0 never flips a bit and
        # a rate of 1 always does. One draw per bit, in order, as before.
        error_message = [
            1 - bit if random_state.rand() < error_rate else bit for bit in self
        ]
        return Message(error_message)