-
Notifications
You must be signed in to change notification settings - Fork 0
/
detector2cvat.py
137 lines (112 loc) · 4.88 KB
/
detector2cvat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import argparse
import cv2
from tqdm import tqdm
from kabr_tools.utils.yolo import YOLOv8
from kabr_tools.utils.tracker import Tracker, Tracks
from kabr_tools.utils.object import Object
from kabr_tools.utils.draw import Draw
def _find_videos(path_to_videos: str) -> "list[str]":
    """Collect ``.mp4`` files found anywhere under *path_to_videos*.

    A video is skipped when its own file name, or the name of the folder that
    directly contains it, starts with ``"!"`` (an opt-out marker). Only the
    immediate parent folder is checked, not every ancestor.
    """
    videos = []
    for root, _dirs, files in os.walk(path_to_videos):
        folder = os.path.basename(root)
        for file in files:
            if os.path.splitext(file)[1] != ".mp4":
                continue
            if folder.startswith("!") or file.startswith("!"):
                continue
            videos.append(os.path.join(root, file))
    return videos


def _output_folder(video: str, path_to_save: str) -> str:
    """Mirror the last two directory levels of *video* under *path_to_save*.

    E.g. ``data/site/day1/clip.mp4`` -> ``<path_to_save>/site/day1``.
    The path is normalized first so trailing or duplicated separators (and
    Windows separators) do not shift which folders are picked.
    """
    parent = os.path.dirname(os.path.normpath(video))
    components = [part for part in parent.split(os.sep) if part]
    return os.path.join(path_to_save, *components[-2:])


def _process_video(yolo, video: str, output_folder: str, name: str,
                   output_path: str, show: bool) -> None:
    """Detect and track objects in one video.

    Writes an annotated ``<name>_demo.mp4`` into *output_folder* and saves the
    tracks in CVAT format to *output_path*.

    Raises:
        IOError: if OpenCV cannot open *video*.
    """
    vc = cv2.VideoCapture(video)
    if not vc.isOpened():
        vc.release()
        raise IOError(f"cannot open video {video}")
    size = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(vc.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(vc.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Match the demo's frame rate to the source; fall back to the previous
    # hard-coded rate when the container reports 0.
    fps = vc.get(cv2.CAP_PROP_FPS) or 29.97
    vw = cv2.VideoWriter(os.path.join(output_folder, f"{name}_demo.mp4"),
                         cv2.VideoWriter_fourcc("m", "p", "4", "v"),
                         fps, (width, height))

    max_disappeared = 40
    tracker = Tracker(max_disappeared=max_disappeared, max_distance=300)
    tracks = Tracks(max_disappeared=max_disappeared, interpolation=True,
                    video_name=name, video_size=size,
                    video_width=width, video_height=height)

    index = 0
    pbar = tqdm(total=size)
    try:
        vc.set(cv2.CAP_PROP_POS_FRAMES, index)
        while vc.isOpened():
            returned, frame = vc.read()
            if not returned:
                break
            visualization = frame.copy()

            # Each prediction is indexed as (box, confidence, label).
            predictions = yolo.forward(frame)
            centroids = [YOLOv8.get_centroid(p[0]) for p in predictions]
            attributes = [{"box": p[0], "confidence": p[1], "label": p[2]}
                          for p in predictions]

            objects, colors = tracker.update(centroids)
            objects = Object.object_factory(objects, centroids, colors,
                                            attributes=attributes)
            tracks.update(objects, index)

            for obj in objects:
                Draw.track(visualization, tracks[obj.object_id].centroids,
                           obj.color, 20)
                Draw.bounding_box(visualization, obj)
                Draw.object_id(visualization, obj)
            cv2.putText(visualization, f"Frame: {index}", (50, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 3,
                        cv2.LINE_AA)

            vw.write(visualization)
            index += 1
            pbar.update(1)

            # Keyboard polling only works with a HighGUI window, and the
            # highgui calls may raise on headless OpenCV builds, so they are
            # only used when a window is displayed.
            if show:
                cv2.imshow("detector2cvat", cv2.resize(
                    visualization, (int(width // 2.5), int(height // 2.5))))
                if cv2.waitKey(1) == 27:  # ESC stops this video early
                    break
    finally:
        # Always release handles, even if detection/tracking raised.
        pbar.close()
        vc.release()
        vw.release()
        if show:
            cv2.destroyAllWindows()

    tracks.save(output_path, "cvat")


def detector2cvat(path_to_videos: str, path_to_save: str, show: bool) -> None:
    """Detect objects with YOLOv8, track them, and export the tracks to CVAT.

    Every ``.mp4`` under *path_to_videos* is processed unless its file name,
    or its parent folder's name, starts with ``"!"``. For a video at
    ``.../<a>/<b>/<name>.mp4``, outputs go to ``<path_to_save>/<a>/<b>/``:
    ``<name>.xml`` (CVAT tracks) and ``<name>_demo.mp4`` (visualization).

    A failure on one video is reported and processing continues with the
    next one. ``KeyboardInterrupt`` is not swallowed.

    Parameters:
        path_to_videos - str. Path to the folder containing videos.
        path_to_save - str. Path to the folder to save output xml & mp4 files.
        show - bool. Flag to display detector's visualization.
    """
    videos = _find_videos(path_to_videos)
    if not videos:
        # Nothing to do; avoid loading the (large) detector model.
        return

    yolo = YOLOv8(weights="yolov8x.pt", imgsz=3840, conf=0.5)
    for i, video in enumerate(videos):
        name = os.path.splitext(os.path.basename(video))[0]
        output_folder = _output_folder(video, path_to_save)
        output_path = os.path.join(output_folder, f"{name}.xml")
        print(f"{i + 1}/{len(videos)}: {video} -> {output_path}")
        try:
            os.makedirs(output_folder, exist_ok=True)
            _process_video(yolo, video, output_folder, name, output_path, show)
        except Exception as err:  # best effort: report and move on
            print(f"Something went wrong with {video}: {err!r}")
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the detector2cvat script.

    Returns:
        Namespace with ``video`` (input folder), ``save`` (output folder)
        and ``imshow`` (bool, False unless the flag is given).
    """
    parser = argparse.ArgumentParser()
    required_paths = (
        ("--video", "path to folder containing videos"),
        ("--save", "path to save output xml & mp4 files"),
    )
    for flag, description in required_paths:
        parser.add_argument(flag, type=str, help=description, required=True)
    parser.add_argument("--imshow", action="store_true",
                        help="flag to display detector's visualization")
    return parser.parse_args()
def main() -> None:
    """Command-line entry point: read the options and run the conversion."""
    options = parse_args()
    detector2cvat(
        path_to_videos=options.video,
        path_to_save=options.save,
        show=options.imshow,
    )
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()