Upload 6 files

- .gitattributes +1 -0
- ax650/mixformer_v2.axmodel +3 -0
- car.avi +3 -0
- config.json +0 -0
- onnx/mixformer_v2_sim.onnx +3 -0
- run_mixformer2_axmodel.py +370 -0
- run_mixformer2_onnx.py +364 -0
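
For reference, the two scripts can be exercised against the uploaded assets roughly as follows (paths assume the repository layout above and mirror the usage strings embedded in the scripts themselves; adjust to your checkout):

  python3 run_mixformer2_axmodel.py -m ax650/mixformer_v2.axmodel -f car.avi
  python3 run_mixformer2_onnx.py -m onnx/mixformer_v2_sim.onnx -f car.avi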
.gitattributes
CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 *.axmodel filter=lfs diff=lfs merge=lfs -text
+car.avi filter=lfs diff=lfs merge=lfs -text
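(The added line is what a command like "git lfs track car.avi" would typically write into .gitattributes, so the video is stored as a Git LFS object rather than directly in the git history.)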
ax650/mixformer_v2.axmodel
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7ae9b045601ccb716a7bcec3a72a601e279db421a98de779f0283d750694f84d
size 23041126
car.avi
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12085e86f5791f9936f382027d51ed02e2485636a03480aeaec2ac03e2bd2bd1
size 65474094
config.json
ADDED
File without changes
onnx/mixformer_v2_sim.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9863932c9d18e6c75adcbec1871d293440bd350bc08f76707ca8088312e2175c
size 64950790
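
Note that the three binary files above are stored as Git LFS pointer stanzas (version/oid/size), not as the payloads themselves; after cloning, the actual model and video bytes are fetched with "git lfs install" followed by "git lfs pull".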
run_mixformer2_axmodel.py
ADDED
@@ -0,0 +1,370 @@
import argparse
import os
import re
import sys
import time
import cv2
import math
import glob
import numpy as np

import axengine as axe
from axengine import axclrt_provider_name, axengine_provider_name

def load_model(model_path: str | os.PathLike, selected_provider: str, selected_device_id: int = 0):
    if selected_provider == 'AUTO':
        # Use AUTO to let the pyengine choose the first available provider
        return axe.InferenceSession(model_path)

    providers = []
    if selected_provider == axclrt_provider_name:
        provider_options = {"device_id": selected_device_id}
        providers.append((axclrt_provider_name, provider_options))
    if selected_provider == axengine_provider_name:
        providers.append(axengine_provider_name)

    return axe.InferenceSession(model_path, providers=providers)

+
|
29 |
+
def get_frames(video_name):
|
30 |
+
"""获取视频帧
|
31 |
+
|
32 |
+
Args:
|
33 |
+
video_name (_type_): _description_
|
34 |
+
|
35 |
+
Yields:
|
36 |
+
_type_: _description_
|
37 |
+
"""
|
38 |
+
if not video_name:
|
39 |
+
rtsp = "rtsp://%s:%s@%s:554/cam/realmonitor?channel=1&subtype=1" % ("admin", "123456", "192.168.1.108")
|
40 |
+
cap = cv2.VideoCapture(rtsp) if rtsp else cv2.VideoCapture()
|
41 |
+
|
42 |
+
# warmup
|
43 |
+
for i in range(5):
|
44 |
+
cap.read()
|
45 |
+
while True:
|
46 |
+
ret, frame = cap.read()
|
47 |
+
if ret:
|
48 |
+
# print('读取成功===>>>', frame.shape)
|
49 |
+
yield cv2.resize(frame,(800, 600))
|
50 |
+
else:
|
51 |
+
break
|
52 |
+
elif video_name.endswith('avi') or \
|
53 |
+
video_name.endswith('mp4'):
|
54 |
+
cap = cv2.VideoCapture(video_name)
|
55 |
+
while True:
|
56 |
+
ret, frame = cap.read()
|
57 |
+
if ret:
|
58 |
+
yield frame
|
59 |
+
else:
|
60 |
+
break
|
61 |
+
else:
|
62 |
+
images = sorted(glob(os.path.join(video_name, 'img', '*.jp*')))
|
63 |
+
for img in images:
|
64 |
+
frame = cv2.imread(img)
|
65 |
+
yield frame
|
66 |
+
|
67 |
+
|

class Preprocessor_wo_mask(object):
    def __init__(self):
        self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1)).astype(np.float32)
        self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1)).astype(np.float32)

    def process(self, img_arr: np.ndarray):
        # Normalize the image patch: HWC uint8 -> (1,3,H,W) float32 in [0,1], then ImageNet mean/std
        img_tensor = img_arr.transpose((2, 0, 1)).reshape((1, 3, img_arr.shape[0], img_arr.shape[1])).astype(np.float32) / 255.0
        img_tensor_norm = (img_tensor - self.mean) / self.std  # (1,3,H,W)
        return img_tensor_norm


class MFTrackerORT:
    def __init__(self, model_path, fp16=False, provider='AUTO', device_id=0) -> None:
        self.debug = True
        self.model_path = model_path
        self.fp16 = fp16
        self.provider = provider
        self.device_id = device_id

        self.init_track_net()
        self.preprocessor = Preprocessor_wo_mask()
        self.max_score_decay = 1.0
        self.search_factor = 4.5
        self.search_size = 224
        self.template_factor = 2.0
        self.template_size = 112
        self.update_interval = 200
        self.online_size = 1

    def init_track_net(self):
        """Initialize the tracker network with the configured parameters."""
        self.ax_session = load_model(self.model_path, selected_provider=self.provider, selected_device_id=self.device_id)

    def track_init(self, frame, target_pos=None, target_sz=None):
        """Initialize the tracker from the first frame.

        Args:
            frame: first BGR frame.
            target_pos: initial [x, y] of the target box. Defaults to None.
            target_sz: initial [w, h] of the target box. Defaults to None.
        """
        self.trace_list = []
        try:
            # [x, y, w, h]
            init_state = [target_pos[0], target_pos[1], target_sz[0], target_sz[1]]
            z_patch_arr, _, z_amask_arr = self.sample_target(frame, init_state, self.template_factor, output_sz=self.template_size)
            template = self.preprocessor.process(z_patch_arr)
            self.template = template
            self.online_template = template

            self.online_state = init_state
            self.online_image = frame
            self.max_pred_score = -1.0
            self.online_max_template = template
            self.online_forget_id = 0

            # save states
            self.state = init_state
            self.frame_id = 0
            print("First-frame initialization done!")
        except Exception:
            print("First-frame initialization failed!")
            sys.exit(1)

    def track(self, image, info: dict = None):
        H, W, _ = image.shape
        self.frame_id += 1
        x_patch_arr, resize_factor, x_amask_arr = self.sample_target(image, self.state, self.search_factor,
                                                                     output_sz=self.search_size)  # (x1, y1, w, h)
        search = self.preprocessor.process(x_patch_arr)

        # run the model through the axengine session
        ort_inputs = {'img_t': self.template, 'img_ot': self.online_template, 'img_search': search}

        ort_outs = self.ax_session.run(None, ort_inputs)

        # print(f">>> length of ort_outs: {len(ort_outs)}")
        pred_boxes = ort_outs[0]
        pred_score = ort_outs[1]
        # print(f">>> box and score: {pred_boxes} {pred_score}")
        # Baseline: take the mean of all pred boxes as the final result
        pred_box = (np.mean(pred_boxes, axis=0) * self.search_size / resize_factor).tolist()  # (cx, cy, w, h) [0,1]
        # get the final box result
        self.state = self.clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        self.max_pred_score = self.max_pred_score * self.max_score_decay
        # update template
        if pred_score > 0.5 and pred_score > self.max_pred_score:
            z_patch_arr, _, z_amask_arr = self.sample_target(image, self.state,
                                                             self.template_factor,
                                                             output_sz=self.template_size)  # (x1, y1, w, h)
            self.online_max_template = self.preprocessor.process(z_patch_arr)
            self.max_pred_score = pred_score

        if self.frame_id % self.update_interval == 0:
            if self.online_size == 1:
                self.online_template = self.online_max_template
            else:
                self.online_template[self.online_forget_id:self.online_forget_id+1] = self.online_max_template
                self.online_forget_id = (self.online_forget_id + 1) % self.online_size

            self.max_pred_score = -1
            self.online_max_template = self.template

        # for debug
        if self.debug:
            x1, y1, w, h = self.state
            # image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.rectangle(image, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color=(0, 0, 255), thickness=2)

        return {"target_bbox": self.state, "conf_score": pred_score}

    def map_box_back(self, pred_box: list, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: np.ndarray, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.T  # (N,4) --> (N,)
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return np.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], axis=-1)

    def sample_target(self, im, target_bb, search_area_factor, output_sz=None, mask=None):
        """Extract a square crop centered at target_bb, of area search_area_factor^2 times the target_bb area.

        args:
            im - cv image
            target_bb - target box [x, y, w, h]
            search_area_factor - ratio of crop size to target size
            output_sz - (float) size to which the extracted crop is resized (always square). If None, no resizing is done.

        returns:
            cv image - extracted crop
            float - the factor by which the crop has been resized to make the crop size equal output_sz
        """
        if not isinstance(target_bb, list):
            x, y, w, h = target_bb.tolist()
        else:
            x, y, w, h = target_bb
        # Crop image
        crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)

        if crop_sz < 1:
            raise ValueError('Too small bounding box.')

        x1 = int(round(x + 0.5 * w - crop_sz * 0.5))
        x2 = int(x1 + crop_sz)

        y1 = int(round(y + 0.5 * h - crop_sz * 0.5))
        y2 = int(y1 + crop_sz)

        x1_pad = int(max(0, -x1))
        x2_pad = int(max(x2 - im.shape[1] + 1, 0))

        y1_pad = int(max(0, -y1))
        y2_pad = int(max(y2 - im.shape[0] + 1, 0))

        # Crop target
        im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
        if mask is not None:
            mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]

        # Pad
        im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)
        # deal with attention mask
        H, W, _ = im_crop_padded.shape
        att_mask = np.ones((H, W))
        end_x, end_y = -x2_pad, -y2_pad
        if y2_pad == 0:
            end_y = None
        if x2_pad == 0:
            end_x = None
        att_mask[y1_pad:end_y, x1_pad:end_x] = 0
        if mask is not None:
            mask_crop_padded = cv2.copyMakeBorder(mask_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)

        if output_sz is not None:
            resize_factor = output_sz / crop_sz
            im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz))
            att_mask = cv2.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
            if mask is None:
                return im_crop_padded, resize_factor, att_mask
            mask_crop_padded = cv2.resize(mask_crop_padded, (output_sz, output_sz))
            return im_crop_padded, resize_factor, att_mask, mask_crop_padded

        else:
            if mask is None:
                return im_crop_padded, 1.0, att_mask.astype(np.bool_)
            return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded

    def clip_box(self, box: list, H, W, margin=0):
        x1, y1, w, h = box
        x2, y2 = x1 + w, y1 + h
        x1 = min(max(0, x1), W - margin)
        x2 = min(max(margin, x2), W)
        y1 = min(max(0, y1), H - margin)
        y2 = min(max(margin, y2), H)
        w = max(margin, x2 - x1)
        h = max(margin, y2 - y1)
        return [x1, y1, w, h]


def main(model_path, frame_path, repeat, selected_provider, selected_device_id):
    Tracker = MFTrackerORT(model_path=model_path, fp16=False, provider=selected_provider, device_id=selected_device_id)
    first_frame = True
    Tracker.video_name = frame_path

    frame_id = 0
    total_fps = 0
    for frame in get_frames(Tracker.video_name):
        # print(f"frame shape {frame.shape}")

        # break out of the loop once the requested number of frames is reached
        if repeat is not None and frame_id >= repeat:
            print(f"Reached the maximum number of frames ({repeat}). Exiting loop.")
            break

        tic = cv2.getTickCount()
        if first_frame:
            # x, y, w, h = cv2.selectROI(video_name, frame, fromCenter=False)
            # left, top, width, height
            x, y, w, h = 1079, 482, 99, 106

            target_pos = [x, y]
            target_sz = [w, h]
            print('====================type=================', target_pos, type(target_pos), type(target_sz))
            Tracker.track_init(frame, target_pos, target_sz)
            first_frame = False
        else:
            state = Tracker.track(frame)
        frame_id += 1

        os.makedirs('axmodel_output', exist_ok=True)
        cv2.imwrite(f'axmodel_output/{frame_id}.png', frame)

        fps = cv2.getTickFrequency() / (cv2.getTickCount() - tic)
        total_fps += fps
        print('Video: {:12s} {:3.1f} fps'.format('tracking', fps))

    print('Video: final average tracking {:3.1f} fps'.format(total_fps / max(frame_id - 1, 1)))


class ExampleParser(argparse.ArgumentParser):
    def error(self, message):
        self.print_usage(sys.stderr)
        print(f"\nError: {message}")
        print("\nExample usage:")
        print("  python3 run_mixformer2_axmodel.py -m <model_file> -f <frame_file>")
        print("  python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi")
        print(f"  python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi -p {axengine_provider_name}")
        print(f"  python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi -p {axclrt_provider_name}")
        sys.exit(1)


if __name__ == "__main__":
    ap = ExampleParser()
    ap.add_argument('-m', '--model-path', type=str, help='model path', required=True)
    ap.add_argument('-f', '--frame-path', type=str, help='frame path', required=True)
    ap.add_argument('-r', '--repeat', type=int, help='maximum number of frames to process', default=100)
    ap.add_argument(
        '-p',
        '--provider',
        type=str,
        choices=["AUTO", axclrt_provider_name, axengine_provider_name],
        help=f'"AUTO", "{axclrt_provider_name}", "{axengine_provider_name}"',
        default='AUTO'
    )
    ap.add_argument(
        '-d',
        '--device-id',
        type=int,
        help='axclrt device index, depends on how many cards are inserted',
        default=0
    )
    args = ap.parse_args()

    model_file = args.model_path
    frame_file = args.frame_path

    # check that the model and input video exist
    assert os.path.exists(model_file), f"model file path {model_file} does not exist"
    assert os.path.exists(frame_file), f"frame file path {frame_file} does not exist"

    repeat = args.repeat

    provider = args.provider
    device_id = args.device_id

    main(model_file, frame_file, repeat, provider, device_id)
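
As a quick aid to reading sample_target above, a standalone numeric sketch of its crop geometry; the box size matches the ROI hard-coded in main, and the search settings come from MFTrackerORT (values are illustrative):

import math

# Illustrative values: a 99x106 target with search_factor=4.5, search_size=224
w, h = 99, 106
search_factor, output_sz = 4.5, 224

# Side of the square crop taken around the target in the full frame
crop_sz = math.ceil(math.sqrt(w * h) * search_factor)  # 461 for these values
# Scale applied when the crop is resized to the 224x224 search patch;
# track()/map_box_back() divide by this factor to map boxes back to the frame
resize_factor = output_sz / crop_sz                    # ~0.486
print(crop_sz, resize_factor)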
run_mixformer2_onnx.py
ADDED
@@ -0,0 +1,364 @@
import os
import sys
import argparse
import numpy as np
import time
import cv2
import glob
import onnxruntime
import torch
import torch.nn.functional as F
import math

prj_path = os.path.join(os.path.dirname(__file__), '..')
if prj_path not in sys.path:
    sys.path.append(prj_path)

def get_frames(video_name):
    """Yield video frames from an RTSP stream, a video file, or an image directory.

    Args:
        video_name: path to an .avi/.mp4 file, a directory containing an
            'img' sub-directory of JPEGs, or empty/None to read from RTSP.

    Yields:
        BGR frames as numpy arrays.
    """
    if not video_name:
        rtsp = "rtsp://%s:%s@%s:554/cam/realmonitor?channel=1&subtype=1" % ("admin", "123456", "192.168.1.108")
        cap = cv2.VideoCapture(rtsp)

        # warmup
        for i in range(5):
            cap.read()
        while True:
            ret, frame = cap.read()
            if ret:
                # print('frame read OK ===>>>', frame.shape)
                yield cv2.resize(frame, (800, 600))
            else:
                break
    elif video_name.endswith('avi') or \
            video_name.endswith('mp4'):
        cap = cv2.VideoCapture(video_name)
        while True:
            ret, frame = cap.read()
            if ret:
                yield frame
            else:
                break
    else:
        images = sorted(glob.glob(os.path.join(video_name, 'img', '*.jp*')))
        for img in images:
            frame = cv2.imread(img)
            yield frame

class Preprocessor_wo_mask(object):
    def __init__(self):
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1))
        self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1))

    def process(self, img_arr: np.ndarray):
        # Normalize the image patch: HWC uint8 -> (1,3,H,W) float in [0,1], then ImageNet mean/std
        img_tensor = torch.tensor(img_arr).float().permute((2, 0, 1)).unsqueeze(dim=0)
        img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std  # (1,3,H,W)
        return img_tensor_norm.contiguous()

class MFTrackerORT:
    def __init__(self, model_path, fp16=False, provider='AUTO', device_id=0) -> None:
        self.debug = True
        self.model_path = model_path
        self.fp16 = fp16
        if provider == 'AUTO':
            self.providers = onnxruntime.get_available_providers()
            self.provider_options = None
        else:
            self.providers = [provider]
            self.provider_options = [{"device_id": str(device_id)}] if provider == 'CUDAExecutionProvider' else None

        self.init_track_net()
        self.preprocessor = Preprocessor_wo_mask()
        self.max_score_decay = 1.0
        self.search_factor = 4.5
        self.search_size = 224
        self.template_factor = 2.0
        self.template_size = 112
        self.update_interval = 200
        self.online_size = 1

    def init_track_net(self):
        """Initialize the tracker network with the configured parameters."""
        self.ort_session = onnxruntime.InferenceSession(self.model_path, providers=self.providers, provider_options=self.provider_options)

    def track_init(self, frame, target_pos=None, target_sz=None):
        """Initialize the tracker from the first frame.

        Args:
            frame: first BGR frame.
            target_pos: initial [x, y] of the target box. Defaults to None.
            target_sz: initial [w, h] of the target box. Defaults to None.
        """
        self.trace_list = []
        try:
            # [x, y, w, h]
            init_state = [target_pos[0], target_pos[1], target_sz[0], target_sz[1]]
            z_patch_arr, _, z_amask_arr = self.sample_target(frame, init_state, self.template_factor, output_sz=self.template_size)
            template = self.preprocessor.process(z_patch_arr)
            self.template = template
            self.online_template = template

            self.online_state = init_state
            self.online_image = frame
            self.max_pred_score = -1.0
            self.online_max_template = template
            self.online_forget_id = 0

            # save states
            self.state = init_state
            self.frame_id = 0
            print("First-frame initialization done!")
        except Exception:
            print("First-frame initialization failed!")
            sys.exit(1)

    def track(self, image, info: dict = None):
        H, W, _ = image.shape
        self.frame_id += 1
        x_patch_arr, resize_factor, x_amask_arr = self.sample_target(image, self.state, self.search_factor,
                                                                     output_sz=self.search_size)  # (x1, y1, w, h)
        search = self.preprocessor.process(x_patch_arr)

        # compute ONNX Runtime output prediction
        ort_inputs = {'img_t': self.to_numpy(self.template), 'img_ot': self.to_numpy(self.online_template), 'img_search': self.to_numpy(search)}

        ort_outs = self.ort_session.run(None, ort_inputs)

        # print(f">>> length of ort_outs: {len(ort_outs)}")
        pred_boxes = torch.from_numpy(ort_outs[0])
        pred_score = torch.from_numpy(ort_outs[1])
        # print(f">>> box and score: {pred_boxes} {pred_score}")
        # Baseline: take the mean of all pred boxes as the final result
        pred_box = (pred_boxes.mean(dim=0) * self.search_size / resize_factor).tolist()  # (cx, cy, w, h) [0,1]
        # get the final box result
        self.state = self.clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        self.max_pred_score = self.max_pred_score * self.max_score_decay
        # update template
        if pred_score > 0.5 and pred_score > self.max_pred_score:
            z_patch_arr, _, z_amask_arr = self.sample_target(image, self.state,
                                                             self.template_factor,
                                                             output_sz=self.template_size)  # (x1, y1, w, h)
            self.online_max_template = self.preprocessor.process(z_patch_arr)
            self.max_pred_score = pred_score

        if self.frame_id % self.update_interval == 0:
            if self.online_size == 1:
                self.online_template = self.online_max_template
            else:
                self.online_template[self.online_forget_id:self.online_forget_id+1] = self.online_max_template
                self.online_forget_id = (self.online_forget_id + 1) % self.online_size

            self.max_pred_score = -1
            self.online_max_template = self.template

        # for debug
        if self.debug:
            x1, y1, w, h = self.state
            # image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.rectangle(image, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color=(0, 0, 255), thickness=2)

        return {"target_bbox": self.state, "conf_score": pred_score}

    def map_box_back(self, pred_box: list, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.unbind(-1)  # (N,4) --> (N,)
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)

    def to_numpy(self, tensor):
        if self.fp16:
            return tensor.detach().cpu().half().numpy() if tensor.requires_grad else tensor.cpu().half().numpy()
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    def sample_target(self, im, target_bb, search_area_factor, output_sz=None, mask=None):
        """Extract a square crop centered at target_bb, of area search_area_factor^2 times the target_bb area.

        args:
            im - cv image
            target_bb - target box [x, y, w, h]
            search_area_factor - ratio of crop size to target size
            output_sz - (float) size to which the extracted crop is resized (always square). If None, no resizing is done.

        returns:
            cv image - extracted crop
            float - the factor by which the crop has been resized to make the crop size equal output_sz
        """
        if not isinstance(target_bb, list):
            x, y, w, h = target_bb.tolist()
        else:
            x, y, w, h = target_bb
        # Crop image
        crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)

        if crop_sz < 1:
            raise ValueError('Too small bounding box.')

        x1 = int(round(x + 0.5 * w - crop_sz * 0.5))
        x2 = int(x1 + crop_sz)

        y1 = int(round(y + 0.5 * h - crop_sz * 0.5))
        y2 = int(y1 + crop_sz)

        x1_pad = int(max(0, -x1))
        x2_pad = int(max(x2 - im.shape[1] + 1, 0))

        y1_pad = int(max(0, -y1))
        y2_pad = int(max(y2 - im.shape[0] + 1, 0))

        # Crop target
        im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
        if mask is not None:
            mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]

        # Pad
        im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)
        # deal with attention mask
        H, W, _ = im_crop_padded.shape
        att_mask = np.ones((H, W))
        end_x, end_y = -x2_pad, -y2_pad
        if y2_pad == 0:
            end_y = None
        if x2_pad == 0:
            end_x = None
        att_mask[y1_pad:end_y, x1_pad:end_x] = 0
        if mask is not None:
            mask_crop_padded = F.pad(mask_crop, pad=(x1_pad, x2_pad, y1_pad, y2_pad), mode='constant', value=0)

        if output_sz is not None:
            resize_factor = output_sz / crop_sz
            im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz))
            att_mask = cv2.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
            if mask is None:
                return im_crop_padded, resize_factor, att_mask
            mask_crop_padded = \
                F.interpolate(mask_crop_padded[None, None], (output_sz, output_sz), mode='bilinear', align_corners=False)[0, 0]
            return im_crop_padded, resize_factor, att_mask, mask_crop_padded

        else:
            if mask is None:
                return im_crop_padded, 1.0, att_mask.astype(np.bool_)
            return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded

    def clip_box(self, box: list, H, W, margin=0):
        x1, y1, w, h = box
        x2, y2 = x1 + w, y1 + h
        x1 = min(max(0, x1), W - margin)
        x2 = min(max(margin, x2), W)
        y1 = min(max(0, y1), H - margin)
        y2 = min(max(margin, y2), H)
        w = max(margin, x2 - x1)
        h = max(margin, y2 - y1)
        return [x1, y1, w, h]


def main(model_path, frame_path, repeat, selected_provider, selected_device_id):
    Tracker = MFTrackerORT(model_path=model_path, fp16=False, provider=selected_provider, device_id=selected_device_id)
    first_frame = True
    Tracker.video_name = frame_path

    frame_id = 0
    total_fps = 0
    for frame in get_frames(Tracker.video_name):
        # print(f"frame shape {frame.shape}")

        # break out of the loop once the requested number of frames is reached
        if repeat is not None and frame_id >= repeat:
            print(f"Reached the maximum number of frames ({repeat}). Exiting loop.")
            break

        tic = cv2.getTickCount()
        if first_frame:
            # x, y, w, h = cv2.selectROI(video_name, frame, fromCenter=False)
            # left, top, width, height
            x, y, w, h = 1079, 482, 99, 106

            target_pos = [x, y]
            target_sz = [w, h]
            print('====================type=================', target_pos, type(target_pos), type(target_sz))
            Tracker.track_init(frame, target_pos, target_sz)
            first_frame = False
        else:
            state = Tracker.track(frame)
        frame_id += 1

        os.makedirs('onnx_output', exist_ok=True)
        cv2.imwrite(f'onnx_output/{frame_id}.png', frame)

        fps = cv2.getTickFrequency() / (cv2.getTickCount() - tic)
        total_fps += fps
        print('Video: {:12s} {:3.1f} fps'.format('tracking', fps))

    print('Video: final average tracking {:3.1f} fps'.format(total_fps / max(frame_id - 1, 1)))


class ExampleParser(argparse.ArgumentParser):
    def error(self, message):
        self.print_usage(sys.stderr)
        print(f"\nError: {message}")
        print("\nExample usage:")
        print("  python3 run_mixformer2_onnx.py -m <model_file> -f <frame_file>")
        print("  python3 run_mixformer2_onnx.py -m mixformer_v2_sim.onnx -f car.avi")
        print("  python3 run_mixformer2_onnx.py -m mixformer_v2_sim.onnx -f car.avi -p CUDAExecutionProvider")
        print("  python3 run_mixformer2_onnx.py -m mixformer_v2_sim.onnx -f car.avi -p CPUExecutionProvider")
        sys.exit(1)


if __name__ == "__main__":
    ap = ExampleParser()
    ap.add_argument('-m', '--model-path', type=str, help='model path', required=True)
    ap.add_argument('-f', '--frame-path', type=str, help='frame path', required=True)
    ap.add_argument('-r', '--repeat', type=int, help='maximum number of frames to process', default=100)
    ap.add_argument(
        '-p',
        '--provider',
        type=str,
        choices=["AUTO", "CUDAExecutionProvider", "CPUExecutionProvider"],
        help='"AUTO", "CUDAExecutionProvider", "CPUExecutionProvider"',
        default='AUTO'
    )
    ap.add_argument(
        '-d',
        '--device-id',
        type=int,
        help='CUDA device index, depends on how many cards are inserted',
        default=0
    )
    args = ap.parse_args()

    model_file = args.model_path
    frame_file = args.frame_path

    # check that the model and input video exist
    assert os.path.exists(model_file), f"model file path {model_file} does not exist"
    assert os.path.exists(frame_file), f"frame file path {frame_file} does not exist"

    repeat = args.repeat

    provider = args.provider
    device_id = args.device_id

    main(model_file, frame_file, repeat, provider, device_id)
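
The two scripts normalize patches with the same ImageNet statistics, once in numpy and once in torch; a small standalone sanity check (not part of the upload) confirming the two code paths agree:

import numpy as np
import torch

img = np.random.randint(0, 256, (112, 112, 3), dtype=np.uint8)  # a fake BGR patch

# numpy path, as in run_mixformer2_axmodel.py
mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1)).astype(np.float32)
std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1)).astype(np.float32)
out_np = (img.transpose((2, 0, 1))[None].astype(np.float32) / 255.0 - mean) / std

# torch path, as in run_mixformer2_onnx.py
t = torch.tensor(img).float().permute((2, 0, 1)).unsqueeze(dim=0)
out_t = ((t / 255.0) - torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1))) \
        / torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1))

assert np.allclose(out_np, out_t.numpy(), atol=1e-5)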