from typing import Any, cast
import cv2
import imageio
import matplotlib as mpl
import numpy as np
from cv2.typing import MatLike
from numpy.typing import NDArray
from tno.quantum.problems.mot import DetectionSequence
# Use matplotlib's non-interactive Agg backend; this module only samples
# colormaps and never opens a figure window.
mpl.use(backend="Agg")
# Number of distinct palette colors. Track IDs are mapped onto the palette
# modulo this value, so IDs that differ by a multiple of it share a color.
MAX_COLORS: int = 20
def get_color_palettes(
    cmap_name: str = "hsv", num: int = MAX_COLORS
) -> NDArray[np.float64]:
    """Gets a list of colors from a given color map.

    The color map is sampled at ``num`` evenly spaced positions in ``[0, 1)``.
    The positions are then shuffled with a fixed seed, so the palette is
    reproducible between calls and consecutive indices tend to get dissimilar
    colors.

    The upper end point ``1.0`` is deliberately excluded. For cyclic color maps
    such as the default ``"hsv"``, the colors at ``0.0`` and ``1.0`` are
    (almost) identical, and sampling both would give two palette entries the
    same color.

    Args:
        cmap_name: color map name based on `matplotlib.colormaps`
        num: number of colors

    Returns:
        Array of ``num`` RGBA colors (shape ``(num, 4)``), with components in
        ``[0, 1]`` as returned by the matplotlib color map.
    """
    rng = np.random.default_rng(seed=0)
    cmap = mpl.colormaps.get_cmap(cmap_name)
    # endpoint=False: avoid sampling both ends of a cyclic map, which coincide.
    positions = np.linspace(0.0, 1.0, num, endpoint=False)
    rng.shuffle(positions)
    return cmap(positions)
def draw_detection(
    frame: MatLike,
    box: tuple[float, float, float, float],
    label: str,
    color: tuple[float, ...] | NDArray[np.float64],
) -> MatLike:
    """Draws the detection box and its label in the specified color.

    Args:
        frame: the frame to draw on. OpenCV draws into it in place, and the
            same frame is also returned.
        box: box coordinates ``(x1, y1, x2, y2)`` in bottom-left/top-right
            format, given as fractions of the frame width/height. They are
            scaled to pixels and truncated to ints. The label is anchored at
            ``(x1, y1)`` and drawn above it.
        label: label text
        color: color of the box and of the label background. It is passed
            unchanged to OpenCV, so its channels are interpreted in the frame's
            channel order (BGR for frames read with ``cv2.VideoCapture``).

    Returns:
        The frame with the drawn box
    """
    height, width = frame.shape[:2]
    x1, y1, x2, y2 = box
    x1, x2 = int(x1 * width), int(x2 * width)
    y1, y2 = int(y1 * height), int(y2 * height)

    # draw detection box
    frame = cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness=7)

    # draw label
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    scale = 1
    thickness = 4
    # Labels that are not exactly two characters get a leading space
    # (presumably to centre single-digit IDs in the label box).
    text = label if len(label) == 2 else " " + label
    # The filled label background is 40 px tall and at least 50 px wide. It
    # grows with the rendered text so longer labels do not overflow it.
    (text_width, _), _ = cv2.getTextSize(text, font_face, scale, thickness)
    label_width = max(50, text_width)
    frame = cv2.rectangle(
        frame, (x1, y1), (x1 + label_width, y1 - 40), color, thickness=-1
    )
    return cv2.putText(
        frame,
        text,
        (x1, y1 - 10),
        font_face,
        scale,
        (244, 244, 244),
        thickness,
        cv2.LINE_AA,
    )
[docs]
def make_video_with_detections(
    detections: DetectionSequence,
    input_filename: str,
    output_filename: str,
    fps: int = 50,
    first_frame: int = 0,
) -> None:
    """Draws the detection boxes over the frames of the original videos.

    Each ID is shown in a different color. IDs wrap around the palette modulo
    ``MAX_COLORS``, and detections with ID ``0`` are not drawn.

    If the detection sequence does not span the whole input video, use
    ``first_frame`` to give the input frame that ``detections[0]`` belongs to.
    The last annotated frame is ``first_frame + len(detections) - 1``. Input
    frames outside that range are not written to the output video.

    Args:
        detections: the detection sequence
        input_filename: name of the input video file
        output_filename: name of the output video file
        fps: frames per second of the output video
        first_frame: frame number corresponding to detections[0].
    """
    colors = get_color_palettes()
    input_video = cv2.VideoCapture(input_filename)
    try:
        width = int(input_video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(input_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # fourcc=0 is passed unchanged from the original implementation.
        # NOTE(review): the writer is not checked with isOpened(), so an
        # unsupported container/codec combination fails silently.
        output_video = cv2.VideoWriter(output_filename, 0, fps, (width, height))
        try:
            end_frame = first_frame + len(detections)
            frame_index = -1
            while input_video.isOpened():
                frame_index += 1
                ret, frame = input_video.read()
                if not ret or frame_index >= end_frame:
                    break
                if frame_index < first_frame:
                    continue
                framedets = detections[frame_index - first_frame]
                ids = [int(_id) for _id in framedets.ids]
                boxes = [framedets.xyxy[k] for k in range(framedets.size)]
                for _id, box in zip(ids, boxes, strict=True):
                    if _id == 0:
                        continue
                    # Scale the [0, 1] RGBA palette entry to the 0-244 range.
                    color = 244 * colors[_id % MAX_COLORS]
                    frame = draw_detection(frame, box, str(_id), color)
                output_video.write(frame)
        finally:
            # Always flush/close the output file, even if drawing fails.
            output_video.release()
    finally:
        input_video.release()
[docs]
def make_gif_with_detections(
    detections: DetectionSequence,
    input_filename: str,
    output_filename: str,
    fps: int = 50,
    first_frame: int = 0,
) -> None:
    """Draws the detection boxes over the frames of the original videos.

    Each ID is shown in a different color. IDs wrap around the palette modulo
    ``MAX_COLORS``, and detections with ID ``0`` are not drawn.

    If the detection sequence does not span the whole input video, use
    ``first_frame`` to give the input frame that ``detections[0]`` belongs to.
    The last annotated frame is ``first_frame + len(detections) - 1``. Input
    frames outside that range are not included in the GIF.

    OpenCV decodes frames in BGR channel order, while imageio expects RGB, so
    every annotated frame is converted to RGB before it is saved. All selected
    frames are kept in memory until the GIF is written.

    Args:
        detections: the detection sequence
        input_filename: name of the input video file
        output_filename: name of the output gif
        fps: frames per second of the output gif
        first_frame: frame number corresponding to detections[0].
    """
    colors = get_color_palettes()
    input_video = cv2.VideoCapture(input_filename)
    frames: list[MatLike] = []
    try:
        end_frame = first_frame + len(detections)
        frame_index = -1
        while input_video.isOpened():
            frame_index += 1
            ret, frame = input_video.read()
            if not ret or frame_index >= end_frame:
                break
            if frame_index < first_frame:
                continue
            framedets = detections[frame_index - first_frame]
            ids = [int(_id) for _id in framedets.ids]
            boxes = [framedets.xyxy[k] for k in range(framedets.size)]
            for _id, box in zip(ids, boxes, strict=True):
                if _id == 0:
                    continue
                # Scale the [0, 1] RGBA palette entry to the 0-244 range.
                color = 244 * colors[_id % MAX_COLORS]
                frame = draw_detection(frame, box, str(_id), color)
            # imageio interprets arrays as RGB; OpenCV frames are BGR.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        input_video.release()
    imageio.mimsave(output_filename, cast("list[Any]", frames), fps=fps)