# Source code for tno.quantum.problems.mot.dataset._visualization

from typing import Any, cast

import cv2
import imageio
import matplotlib as mpl
import numpy as np
from cv2.typing import MatLike
from numpy.typing import NDArray

from tno.quantum.problems.mot import DetectionSequence

# Select matplotlib's "Agg" backend when this module is imported (module-level
# side effect that applies to the whole process).
mpl.use("Agg")
# Default number of palette colors; detection IDs are mapped onto the palette
# with `_id % MAX_COLORS`, so IDs that differ by a multiple of it share a color.
MAX_COLORS = 20


def get_color_palettes(
    cmap_name: str = "hsv", num: int = MAX_COLORS
) -> NDArray[np.float64]:
    """Samples colors from a color map in a fixed, shuffled order.

    The color map is sampled at ``num`` evenly spaced positions in ``[0, 1]``.
    The positions are shuffled with a fixed seed, so neighbouring indices get
    dissimilar colors while the result stays identical between calls.

    Args:
        cmap_name: color map name based on `matplotlib.colormaps`
        num: number of colors

    Returns:
        The result of calling the color map on the shuffled sample positions,
        with one entry per requested color.

    """
    generator = np.random.default_rng(seed=0)
    colormap = mpl.colormaps.get_cmap(cmap_name)

    # Evenly spaced sample positions, both end points included.
    positions = np.linspace(0, 1.0, num)

    # In-place shuffle; the constant seed makes the order reproducible.
    generator.shuffle(positions)

    return colormap(positions)


def draw_detection(
    frame: MatLike,
    box: tuple[float, float, float, float],
    label: str,
    color: tuple[float, ...] | NDArray[np.float64],
) -> MatLike:
    """Draws the detection box and its label in the specified color.

    The box corners are given in relative coordinates and are scaled by the
    frame width and height to obtain pixel positions. The label is written in
    a filled tab of fixed size (50 x 40 pixels) anchored at the first corner
    ``(x1, y1)`` and extending towards smaller ``y`` values.

    Args:
        frame: the frame to draw on. It is handed directly to the OpenCV
            drawing functions, so the caller should assume it may be modified.
        box: box corners ``(x1, y1, x2, y2)``, where the ``x`` values are
            fractions of the frame width and the ``y`` values are fractions of
            the frame height
        label: label text. A label that is not exactly two characters long is
            prefixed with one space.
        color: color of the box and of the label tab, given as channel values
            in the channel order of ``frame``. Tuples and NumPy rows are both
            accepted.

    Returns:
        The frame with the drawn box and label
    """
    frame_height, frame_width = frame.shape[0], frame.shape[1]

    # Scale the relative corner coordinates to (truncated) pixel positions.
    rel_x1, rel_y1, rel_x2, rel_y2 = box
    x1 = int(rel_x1 * frame_width)
    y1 = int(rel_y1 * frame_height)
    x2 = int(rel_x2 * frame_width)
    y2 = int(rel_y2 * frame_height)

    # Pass the color on as plain Python floats so that any numeric channel
    # type (including NumPy integer scalars) is accepted by OpenCV.
    scalar = tuple(float(channel) for channel in color)

    # draw detection box
    frame = cv2.rectangle(
        frame,
        (x1, y1),
        (x2, y2),
        scalar,
        thickness=7,
    )

    # draw label: a filled tab (thickness=-1) with the text on top of it
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    scale = 1
    thickness = 4
    frame = cv2.rectangle(frame, (x1, y1), (x1 + 50, y1 - 40), scalar, thickness=-1)

    # Pad labels that are not two characters long with a leading space.
    text = label if len(label) == 2 else " " + label
    return cv2.putText(
        frame,
        text,
        (x1, y1 - 10),
        font_face,
        scale,
        (244, 244, 244),
        thickness,
        cv2.LINE_AA,
    )


def make_video_with_detections(
    detections: DetectionSequence,
    input_filename: str,
    output_filename: str,
    fps: int = 50,
    first_frame: int = 0,
) -> None:
    """Draws the detection boxes over the frames of the original video.

    Each ID is shown in a different color; IDs that differ by a multiple of
    ``MAX_COLORS`` share a color. Detections with ID 0 are not drawn.

    If the detection sequence does not start at the beginning of the input
    video, ``first_frame`` needs to be specified. Frames before
    ``first_frame`` are skipped, and processing stops after
    ``len(detections)`` frames have been written or when the input video ends,
    whichever comes first. If the input video cannot be opened, no frames are
    written.

    Args:
        detections: the detection sequence
        input_filename: name of the input video file
        output_filename: name of the output video file
        fps: frames per second of the output video
        first_frame: frame number corresponding to detections[0].
    """
    colors = get_color_palettes()
    # Index one past the last input frame that has detections.
    last_frame = first_frame + len(detections)

    input_video = cv2.VideoCapture(input_filename)
    try:
        width = int(input_video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(input_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        output_video = cv2.VideoWriter(output_filename, 0, fps, (width, height))
        try:
            frame_index = 0
            # Stop before decoding frames that lie past the detections.
            while input_video.isOpened() and frame_index < last_frame:
                ret, frame = input_video.read()
                if not ret:
                    break

                if frame_index >= first_frame:
                    framedets = detections[frame_index - first_frame]
                    ids = [int(_id) for _id in framedets.ids]
                    boxes = [framedets.xyxy[k] for k in range(framedets.size)]
                    for _id, box in zip(ids, boxes, strict=True):
                        if _id == 0:
                            continue
                        color = 244 * colors[_id % MAX_COLORS]
                        frame = draw_detection(frame, box, str(_id), color)
                    output_video.write(frame)

                frame_index += 1
        finally:
            # Release the writer even if drawing or writing failed.
            output_video.release()
    finally:
        # Release the capture even if an error occurred above.
        input_video.release()
    cv2.destroyAllWindows()
def make_gif_with_detections(
    detections: DetectionSequence,
    input_filename: str,
    output_filename: str,
    fps: int = 50,
    first_frame: int = 0,
) -> None:
    """Draws the detection boxes over the frames of the original video.

    Each ID is shown in a different color; IDs that differ by a multiple of
    ``MAX_COLORS`` share a color. Detections with ID 0 are not drawn.

    If the detection sequence does not start at the beginning of the input
    video, ``first_frame`` needs to be specified. Frames before
    ``first_frame`` are skipped, and processing stops after
    ``len(detections)`` frames have been collected or when the input video
    ends, whichever comes first. All collected frames are held in memory
    before the gif is written.

    Args:
        detections: the detection sequence
        input_filename: name of the input video file
        output_filename: name of the output gif
        fps: frames per second of the output gif
        first_frame: frame number corresponding to detections[0].
    """
    colors = get_color_palettes()
    # Index one past the last input frame that has detections.
    last_frame = first_frame + len(detections)

    frames = []
    input_video = cv2.VideoCapture(input_filename)
    try:
        frame_index = 0
        # Stop before decoding frames that lie past the detections.
        while input_video.isOpened() and frame_index < last_frame:
            ret, frame = input_video.read()
            if not ret:
                break

            if frame_index >= first_frame:
                framedets = detections[frame_index - first_frame]
                ids = [int(_id) for _id in framedets.ids]
                boxes = [framedets.xyxy[k] for k in range(framedets.size)]
                for _id, box in zip(ids, boxes, strict=True):
                    if _id == 0:
                        continue
                    color = 244 * colors[_id % MAX_COLORS]
                    frame = draw_detection(frame, box, str(_id), color)
                # OpenCV decodes frames in BGR channel order while imageio
                # expects RGB; convert so the gif shows the original colors
                # (and the same box colors as the video output).
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            frame_index += 1
    finally:
        # Release the capture even if drawing failed.
        input_video.release()

    imageio.mimsave(output_filename, cast("list[Any]", frames), fps=fps)
    cv2.destroyAllWindows()