qqc1989 committed
Commit 4494462 · verified · 1 Parent(s): f3741a9

Upload 6 files

.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 *.axmodel filter=lfs diff=lfs merge=lfs -text
+ car.avi filter=lfs diff=lfs merge=lfs -text
ax650/mixformer_v2.axmodel ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7ae9b045601ccb716a7bcec3a72a601e279db421a98de779f0283d750694f84d
size 23041126
car.avi ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12085e86f5791f9936f382027d51ed02e2485636a03480aeaec2ac03e2bd2bd1
size 65474094
config.json ADDED
File without changes
onnx/mixformer_v2_sim.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9863932c9d18e6c75adcbec1871d293440bd350bc08f76707ca8088312e2175c
size 64950790
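
Note: the three large binaries above are stored as Git LFS pointer files (version / oid / size), not as the actual data. After cloning, a minimal sketch like the one below can confirm that a fetched file matches the oid and size recorded in its pointer; the helper functions and the local path are illustrative, not part of this commit.

import hashlib
from pathlib import Path

def read_lfs_pointer(pointer_text: str) -> dict:
    # Parse the "key value" lines of a Git LFS pointer (version, oid sha256:..., size).
    fields = dict(line.split(" ", 1) for line in pointer_text.strip().splitlines())
    return {"oid": fields["oid"].split(":", 1)[1], "size": int(fields["size"])}

def verify_lfs_file(local_path: str, pointer_text: str) -> bool:
    # Compare the local file's byte count and sha256 digest against the pointer metadata.
    meta = read_lfs_pointer(pointer_text)
    data = Path(local_path).read_bytes()
    return len(data) == meta["size"] and hashlib.sha256(data).hexdigest() == meta["oid"]

# Example: check car.avi against the pointer content shown above.
pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:12085e86f5791f9936f382027d51ed02e2485636a03480aeaec2ac03e2bd2bd1
size 65474094"""
print(verify_lfs_file("car.avi", pointer))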
run_mixformer2_axmodel.py ADDED
@@ -0,0 +1,370 @@
import argparse
import os
import re
import sys
import time
import cv2
import math
import glob
import numpy as np

import axengine as axe
from axengine import axclrt_provider_name, axengine_provider_name


def load_model(model_path: str | os.PathLike, selected_provider: str, selected_device_id: int = 0):
    if selected_provider == 'AUTO':
        # Use AUTO to let the pyengine choose the first available provider
        return axe.InferenceSession(model_path)

    providers = []
    if selected_provider == axclrt_provider_name:
        provider_options = {"device_id": selected_device_id}
        providers.append((axclrt_provider_name, provider_options))
    if selected_provider == axengine_provider_name:
        providers.append(axengine_provider_name)

    return axe.InferenceSession(model_path, providers=providers)


def get_frames(video_name):
    """Yield video frames from a file, an image-sequence directory, or an RTSP stream.

    Args:
        video_name: path to a .avi/.mp4 file or to a sequence directory containing
            an 'img' sub-folder; an empty value falls back to a hard-coded RTSP stream.

    Yields:
        np.ndarray: BGR frames.
    """
    if not video_name:
        rtsp = "rtsp://%s:%s@%s:554/cam/realmonitor?channel=1&subtype=1" % ("admin", "123456", "192.168.1.108")
        cap = cv2.VideoCapture(rtsp) if rtsp else cv2.VideoCapture()

        # warmup
        for i in range(5):
            cap.read()
        while True:
            ret, frame = cap.read()
            if ret:
                # print('frame read OK ===>>>', frame.shape)
                yield cv2.resize(frame, (800, 600))
            else:
                break
    elif video_name.endswith('avi') or \
            video_name.endswith('mp4'):
        cap = cv2.VideoCapture(video_name)
        while True:
            ret, frame = cap.read()
            if ret:
                yield frame
            else:
                break
    else:
        images = sorted(glob.glob(os.path.join(video_name, 'img', '*.jp*')))
        for img in images:
            frame = cv2.imread(img)
            yield frame


class Preprocessor_wo_mask(object):
    def __init__(self):
        self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1)).astype(np.float32)
        self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1)).astype(np.float32)

    def process(self, img_arr: np.ndarray):
        # Deal with the image patch
        img_tensor = img_arr.transpose((2, 0, 1)).reshape((1, 3, img_arr.shape[0], img_arr.shape[1])).astype(np.float32) / 255.0
        img_tensor_norm = (img_tensor - self.mean) / self.std  # (1,3,H,W)
        return img_tensor_norm


class MFTrackerORT:
    def __init__(self, model_path, fp16=False) -> None:
        self.debug = True
        self.gpu_id = 0
        self.providers = ["CUDAExecutionProvider"]
        self.provider_options = [{"device_id": str(self.gpu_id)}]
        self.model_path = model_path
        self.fp16 = fp16

        self.init_track_net()
        self.preprocessor = Preprocessor_wo_mask()
        self.max_score_decay = 1.0
        self.search_factor = 4.5
        self.search_size = 224
        self.template_factor = 2.0
        self.template_size = 112
        self.update_interval = 200
        self.online_size = 1

    def init_track_net(self):
        """Initialize the tracker network with the configured parameters."""
        self.ax_session = load_model(self.model_path, selected_provider="AUTO")

    def track_init(self, frame, target_pos=None, target_sz=None):
        """Initialize the tracker with the first frame.

        Args:
            frame: first frame of the sequence.
            target_pos (optional): target top-left corner [x, y]. Defaults to None.
            target_sz (optional): target size [w, h]. Defaults to None.
        """
        self.trace_list = []
        try:
            # [x, y, w, h]
            init_state = [target_pos[0], target_pos[1], target_sz[0], target_sz[1]]
            z_patch_arr, _, z_amask_arr = self.sample_target(frame, init_state, self.template_factor, output_sz=self.template_size)
            template = self.preprocessor.process(z_patch_arr)
            self.template = template
            self.online_template = template

            self.online_state = init_state
            self.online_image = frame
            self.max_pred_score = -1.0
            self.online_max_template = template
            self.online_forget_id = 0

            # save states
            self.state = init_state
            self.frame_id = 0
            print("First-frame initialization done!")
        except Exception:
            print("First-frame initialization failed!")
            exit()

    def track(self, image, info: dict = None):
        H, W, _ = image.shape
        self.frame_id += 1
        x_patch_arr, resize_factor, x_amask_arr = self.sample_target(image, self.state, self.search_factor,
                                                                     output_sz=self.search_size)  # (x1, y1, w, h)
        search = self.preprocessor.process(x_patch_arr)

        # compute model output prediction
        ort_inputs = {'img_t': self.template, 'img_ot': self.online_template, 'img_search': search}

        ort_outs = self.ax_session.run(None, ort_inputs)

        # print(f">>> length of outputs: {ort_outs}")
        pred_boxes = ort_outs[0]
        pred_score = ort_outs[1]
        # print(f">>> box and score: {pred_boxes} {pred_score}")
        # Baseline: Take the mean of all pred boxes as the final result
        pred_box = (np.mean(pred_boxes, axis=0) * self.search_size / resize_factor).tolist()  # (cx, cy, w, h) [0,1]
        # get the final box result
        self.state = self.clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        self.max_pred_score = self.max_pred_score * self.max_score_decay
        # update template
        if pred_score > 0.5 and pred_score > self.max_pred_score:
            z_patch_arr, _, z_amask_arr = self.sample_target(image, self.state,
                                                             self.template_factor,
                                                             output_sz=self.template_size)  # (x1, y1, w, h)
            self.online_max_template = self.preprocessor.process(z_patch_arr)
            self.max_pred_score = pred_score

        if self.frame_id % self.update_interval == 0:
            if self.online_size == 1:
                self.online_template = self.online_max_template
            else:
                self.online_template[self.online_forget_id:self.online_forget_id+1] = self.online_max_template
                self.online_forget_id = (self.online_forget_id + 1) % self.online_size

            self.max_pred_score = -1
            self.online_max_template = self.template

        # for debug
        if self.debug:
            x1, y1, w, h = self.state
            # image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.rectangle(image, (int(x1), int(y1)), (int(x1+w), int(y1+h)), color=(0, 0, 255), thickness=2)

        return {"target_bbox": self.state, "conf_score": pred_score}

    def map_box_back(self, pred_box: list, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: np.ndarray, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.T  # (N,4) --> (N,)
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return np.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], axis=-1)

    def sample_target(self, im, target_bb, search_area_factor, output_sz=None, mask=None):
        """Extracts a square crop centered at target_bb box, of area search_area_factor^2 times target_bb area

        args:
            im - cv image
            target_bb - target box [x, y, w, h]
            search_area_factor - Ratio of crop size to target size
            output_sz - (float) Size to which the extracted crop is resized (always square). If None, no resizing is done.

        returns:
            cv image - extracted crop
            float - the factor by which the crop has been resized to make the crop size equal output_sz
        """
        if not isinstance(target_bb, list):
            x, y, w, h = target_bb.tolist()
        else:
            x, y, w, h = target_bb
        # Crop image
        crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)

        if crop_sz < 1:
            raise Exception('Too small bounding box.')

        x1 = int(round(x + 0.5 * w - crop_sz * 0.5))
        x2 = int(x1 + crop_sz)

        y1 = int(round(y + 0.5 * h - crop_sz * 0.5))
        y2 = int(y1 + crop_sz)

        x1_pad = int(max(0, -x1))
        x2_pad = int(max(x2 - im.shape[1] + 1, 0))

        y1_pad = int(max(0, -y1))
        y2_pad = int(max(y2 - im.shape[0] + 1, 0))

        # Crop target
        im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
        if mask is not None:
            mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]

        # Pad
        im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)
        # deal with attention mask
        H, W, _ = im_crop_padded.shape
        att_mask = np.ones((H, W))
        end_x, end_y = -x2_pad, -y2_pad
        if y2_pad == 0:
            end_y = None
        if x2_pad == 0:
            end_x = None
        att_mask[y1_pad:end_y, x1_pad:end_x] = 0
        if mask is not None:
            mask_crop_padded = cv2.copyMakeBorder(mask_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)

        if output_sz is not None:
            resize_factor = output_sz / crop_sz
            im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz))
            att_mask = cv2.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
            if mask is None:
                return im_crop_padded, resize_factor, att_mask
            mask_crop_padded = cv2.resize(mask_crop_padded, (output_sz, output_sz))
            return im_crop_padded, resize_factor, att_mask, mask_crop_padded

        else:
            if mask is None:
                return im_crop_padded, att_mask.astype(np.bool_), 1.0
            return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded

    def clip_box(self, box: list, H, W, margin=0):
        x1, y1, w, h = box
        x2, y2 = x1 + w, y1 + h
        x1 = min(max(0, x1), W-margin)
        x2 = min(max(margin, x2), W)
        y1 = min(max(0, y1), H-margin)
        y2 = min(max(margin, y2), H)
        w = max(margin, x2-x1)
        h = max(margin, y2-y1)
        return [x1, y1, w, h]


def main(model_path, frame_path, repeat, selected_provider, selected_device_id):
    Tracker = MFTrackerORT(model_path=model_path, fp16=False)
    first_frame = True
    Tracker.video_name = frame_path

    frame_id = 0
    total_time = 0
    for frame in get_frames(Tracker.video_name):
        # print(f"frame shape {frame.shape}")

        # Stop once the configured frame limit has been reached
        if repeat is not None and frame_id >= repeat:
            print(f"Reached the maximum number of frames ({repeat}). Exiting loop.")
            break

        tic = cv2.getTickCount()
        if first_frame:
            # x, y, w, h = cv2.selectROI(video_name, frame, fromCenter=False)
            x, y, w, h = 1079, 482, 99, 106

            target_pos = [x, y]
            target_sz = [w, h]
            print('====================type=================', target_pos, type(target_pos), type(target_sz))
            Tracker.track_init(frame, target_pos, target_sz)
            first_frame = False
        else:
            state = Tracker.track(frame)
            frame_id += 1

            os.makedirs('axmodel_output', exist_ok=True)
            cv2.imwrite(f'axmodel_output/{str(frame_id)}.png', frame)

            toc = cv2.getTickCount() - tic
            toc = int(1 / (toc / cv2.getTickFrequency()))
            total_time += toc
            print('Video: {:12s} {:3.1f}fps'.format('tracking', toc))

    print('Video: {:s} {:3.1f} fps'.format('final average tracking fps', total_time / (frame_id - 1)))


class ExampleParser(argparse.ArgumentParser):
    def error(self, message):
        self.print_usage(sys.stderr)
        print(f"\nError: {message}")
        print("\nExample usage:")
        print(" python3 run_mixformer2_axmodel.py -m <model_file> -f <frame_file>")
        print(" python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi")
        print(f" python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi -p {axengine_provider_name}")
        print(f" python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi -p {axclrt_provider_name}")
        sys.exit(1)


if __name__ == "__main__":
    ap = ExampleParser()
    ap.add_argument('-m', '--model-path', type=str, help='model path', required=True)
    ap.add_argument('-f', '--frame-path', type=str, help='frame path', required=True)
    ap.add_argument('-r', '--repeat', type=int, help='repeat times', default=100)
    ap.add_argument(
        '-p',
        '--provider',
        type=str,
        choices=["AUTO", axclrt_provider_name, axengine_provider_name],
        help=f'"AUTO", "{axclrt_provider_name}", "{axengine_provider_name}"',
        default='AUTO'
    )
    ap.add_argument(
        '-d',
        '--device-id',
        type=int,
        help='axclrt device index, depends on how many cards inserted',
        default=0
    )
    args = ap.parse_args()

    model_file = args.model_path
    frame_file = args.frame_path

    # check if the model and image exist
    assert os.path.exists(model_file), f"model file path {model_file} does not exist"
    assert os.path.exists(frame_file), f"image file path {frame_file} does not exist"

    repeat = args.repeat

    provider = args.provider
    device_id = args.device_id

    main(model_file, frame_file, repeat, provider, device_id)
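
Before running the full tracking demo, it can help to smoke-test the compiled model on dummy inputs. The sketch below is a minimal example, assuming the repo layout above (ax650/mixformer_v2.axmodel) and reusing only the axengine calls the script itself makes (InferenceSession and run with None output names); the zero-filled shapes mirror template_size=112 and search_size=224 from MFTrackerORT.

import numpy as np
import axengine as axe

# Load the compiled model with the auto-selected provider, as load_model() does for "AUTO".
session = axe.InferenceSession("ax650/mixformer_v2.axmodel")

# Zero-filled dummy tensors with the shapes the tracker feeds:
# template / online template are 1x3x112x112, the search region is 1x3x224x224.
dummy_inputs = {
    "img_t": np.zeros((1, 3, 112, 112), dtype=np.float32),
    "img_ot": np.zeros((1, 3, 112, 112), dtype=np.float32),
    "img_search": np.zeros((1, 3, 224, 224), dtype=np.float32),
}

outputs = session.run(None, dummy_inputs)
print("number of outputs:", len(outputs))
for i, out in enumerate(outputs):
    print(f"output[{i}] shape: {out.shape}")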
run_mixformer2_onnx.py ADDED
@@ -0,0 +1,364 @@
import os
import sys
import argparse
import numpy as np
import time
import cv2
import glob
import onnxruntime
import torch
import torch.nn.functional as F
import math

prj_path = os.path.join(os.path.dirname(__file__), '..')
if prj_path not in sys.path:
    sys.path.append(prj_path)


def get_frames(video_name):
    """Yield video frames from a file, an image-sequence directory, or an RTSP stream.

    Args:
        video_name: path to a .avi/.mp4 file or to a sequence directory containing
            an 'img' sub-folder; an empty value falls back to a hard-coded RTSP stream.

    Yields:
        np.ndarray: BGR frames.
    """
    if not video_name:
        rtsp = "rtsp://%s:%s@%s:554/cam/realmonitor?channel=1&subtype=1" % ("admin", "123456", "192.168.1.108")
        cap = cv2.VideoCapture(rtsp) if rtsp else cv2.VideoCapture()

        # warmup
        for i in range(5):
            cap.read()
        while True:
            ret, frame = cap.read()
            if ret:
                # print('frame read OK ===>>>', frame.shape)
                yield cv2.resize(frame, (800, 600))
            else:
                break
    elif video_name.endswith('avi') or \
            video_name.endswith('mp4'):
        cap = cv2.VideoCapture(video_name)
        while True:
            ret, frame = cap.read()
            if ret:
                yield frame
            else:
                break
    else:
        images = sorted(glob.glob(os.path.join(video_name, 'img', '*.jp*')))
        for img in images:
            frame = cv2.imread(img)
            yield frame


class Preprocessor_wo_mask(object):
    def __init__(self):
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1))
        self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1))

    def process(self, img_arr: np.ndarray):
        # Deal with the image patch
        img_tensor = torch.tensor(img_arr).float().permute((2, 0, 1)).unsqueeze(dim=0)
        img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std  # (1,3,H,W)
        return img_tensor_norm.contiguous()


class MFTrackerORT:
    def __init__(self, model_path, fp16=False) -> None:
        self.debug = True
        self.gpu_id = 0
        self.providers = ["CUDAExecutionProvider"]
        self.provider_options = [{"device_id": str(self.gpu_id)}]
        self.model_path = model_path
        self.fp16 = fp16

        self.init_track_net()
        self.preprocessor = Preprocessor_wo_mask()
        self.max_score_decay = 1.0
        self.search_factor = 4.5
        self.search_size = 224
        self.template_factor = 2.0
        self.template_size = 112
        self.update_interval = 200
        self.online_size = 1

    def init_track_net(self):
        """Initialize the tracker network with the configured parameters."""
        self.ort_session = onnxruntime.InferenceSession(self.model_path, providers=self.providers, provider_options=self.provider_options)

    def track_init(self, frame, target_pos=None, target_sz=None):
        """Initialize the tracker with the first frame.

        Args:
            frame: first frame of the sequence.
            target_pos (optional): target top-left corner [x, y]. Defaults to None.
            target_sz (optional): target size [w, h]. Defaults to None.
        """
        self.trace_list = []
        try:
            # [x, y, w, h]
            init_state = [target_pos[0], target_pos[1], target_sz[0], target_sz[1]]
            z_patch_arr, _, z_amask_arr = self.sample_target(frame, init_state, self.template_factor, output_sz=self.template_size)
            template = self.preprocessor.process(z_patch_arr)
            self.template = template
            self.online_template = template

            self.online_state = init_state
            self.online_image = frame
            self.max_pred_score = -1.0
            self.online_max_template = template
            self.online_forget_id = 0

            # save states
            self.state = init_state
            self.frame_id = 0
            print("First-frame initialization done!")
        except Exception:
            print("First-frame initialization failed!")
            exit()

    def track(self, image, info: dict = None):
        H, W, _ = image.shape
        self.frame_id += 1
        x_patch_arr, resize_factor, x_amask_arr = self.sample_target(image, self.state, self.search_factor,
                                                                     output_sz=self.search_size)  # (x1, y1, w, h)
        search = self.preprocessor.process(x_patch_arr)

        # compute ONNX Runtime output prediction
        ort_inputs = {'img_t': self.to_numpy(self.template), 'img_ot': self.to_numpy(self.online_template), 'img_search': self.to_numpy(search)}

        ort_outs = self.ort_session.run(None, ort_inputs)

        # print(f">>> length of outputs: {ort_outs}")
        pred_boxes = torch.from_numpy(ort_outs[0])
        pred_score = torch.from_numpy(ort_outs[1])
        # print(f">>> box and score: {pred_boxes} {pred_score}")
        # Baseline: Take the mean of all pred boxes as the final result
        pred_box = (pred_boxes.mean(dim=0) * self.search_size / resize_factor).tolist()  # (cx, cy, w, h) [0,1]
        # get the final box result
        self.state = self.clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        self.max_pred_score = self.max_pred_score * self.max_score_decay
        # update template
        if pred_score > 0.5 and pred_score > self.max_pred_score:
            z_patch_arr, _, z_amask_arr = self.sample_target(image, self.state,
                                                             self.template_factor,
                                                             output_sz=self.template_size)  # (x1, y1, w, h)
            self.online_max_template = self.preprocessor.process(z_patch_arr)
            self.max_pred_score = pred_score

        if self.frame_id % self.update_interval == 0:
            if self.online_size == 1:
                self.online_template = self.online_max_template
            else:
                self.online_template[self.online_forget_id:self.online_forget_id+1] = self.online_max_template
                self.online_forget_id = (self.online_forget_id + 1) % self.online_size

            self.max_pred_score = -1
            self.online_max_template = self.template

        # for debug
        if self.debug:
            x1, y1, w, h = self.state
            # image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.rectangle(image, (int(x1), int(y1)), (int(x1+w), int(y1+h)), color=(0, 0, 255), thickness=2)

        return {"target_bbox": self.state, "conf_score": pred_score}

    def map_box_back(self, pred_box: list, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.unbind(-1)  # (N,4) --> (N,)
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)

    def to_numpy(self, tensor):
        if self.fp16:
            return tensor.detach().cpu().half().numpy() if tensor.requires_grad else tensor.cpu().half().numpy()
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    def sample_target(self, im, target_bb, search_area_factor, output_sz=None, mask=None):
        """Extracts a square crop centered at target_bb box, of area search_area_factor^2 times target_bb area

        args:
            im - cv image
            target_bb - target box [x, y, w, h]
            search_area_factor - Ratio of crop size to target size
            output_sz - (float) Size to which the extracted crop is resized (always square). If None, no resizing is done.

        returns:
            cv image - extracted crop
            float - the factor by which the crop has been resized to make the crop size equal output_sz
        """
        if not isinstance(target_bb, list):
            x, y, w, h = target_bb.tolist()
        else:
            x, y, w, h = target_bb
        # Crop image
        crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)

        if crop_sz < 1:
            raise Exception('Too small bounding box.')

        x1 = int(round(x + 0.5 * w - crop_sz * 0.5))
        x2 = int(x1 + crop_sz)

        y1 = int(round(y + 0.5 * h - crop_sz * 0.5))
        y2 = int(y1 + crop_sz)

        x1_pad = int(max(0, -x1))
        x2_pad = int(max(x2 - im.shape[1] + 1, 0))

        y1_pad = int(max(0, -y1))
        y2_pad = int(max(y2 - im.shape[0] + 1, 0))

        # Crop target
        im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
        if mask is not None:
            mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]

        # Pad
        im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)
        # deal with attention mask
        H, W, _ = im_crop_padded.shape
        att_mask = np.ones((H, W))
        end_x, end_y = -x2_pad, -y2_pad
        if y2_pad == 0:
            end_y = None
        if x2_pad == 0:
            end_x = None
        att_mask[y1_pad:end_y, x1_pad:end_x] = 0
        if mask is not None:
            mask_crop_padded = F.pad(mask_crop, pad=(x1_pad, x2_pad, y1_pad, y2_pad), mode='constant', value=0)

        if output_sz is not None:
            resize_factor = output_sz / crop_sz
            im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz))
            att_mask = cv2.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
            if mask is None:
                return im_crop_padded, resize_factor, att_mask
            mask_crop_padded = \
                F.interpolate(mask_crop_padded[None, None], (output_sz, output_sz), mode='bilinear', align_corners=False)[0, 0]
            return im_crop_padded, resize_factor, att_mask, mask_crop_padded

        else:
            if mask is None:
                return im_crop_padded, att_mask.astype(np.bool_), 1.0
            return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded

    def clip_box(self, box: list, H, W, margin=0):
        x1, y1, w, h = box
        x2, y2 = x1 + w, y1 + h
        x1 = min(max(0, x1), W-margin)
        x2 = min(max(margin, x2), W)
        y1 = min(max(0, y1), H-margin)
        y2 = min(max(margin, y2), H)
        w = max(margin, x2-x1)
        h = max(margin, y2-y1)
        return [x1, y1, w, h]


def main(model_path, frame_path, repeat, selected_provider, selected_device_id):
    Tracker = MFTrackerORT(model_path=model_path, fp16=False)
    first_frame = True
    Tracker.video_name = frame_path

    frame_id = 0
    total_time = 0
    for frame in get_frames(Tracker.video_name):
        # print(f"frame shape {frame.shape}")

        # Stop once the configured frame limit has been reached
        if repeat is not None and frame_id >= repeat:
            print(f"Reached the maximum number of frames ({repeat}). Exiting loop.")
            break

        tic = cv2.getTickCount()
        if first_frame:
            # x, y, w, h = cv2.selectROI(video_name, frame, fromCenter=False)
            # left, top, width, height
            x, y, w, h = 1079, 482, 99, 106

            target_pos = [x, y]
            target_sz = [w, h]
            print('====================type=================', target_pos, type(target_pos), type(target_sz))
            Tracker.track_init(frame, target_pos, target_sz)
            first_frame = False
        else:
            state = Tracker.track(frame)
            frame_id += 1

            os.makedirs('onnx_output', exist_ok=True)
            cv2.imwrite(f'onnx_output/{str(frame_id)}.png', frame)

            toc = cv2.getTickCount() - tic
            toc = int(1 / (toc / cv2.getTickFrequency()))
            total_time += toc
            print('Video: {:12s} {:3.1f}fps'.format('tracking', toc))

    print('Video: {:s} {:3.1f} fps'.format('final average tracking fps', total_time / (frame_id - 1)))


class ExampleParser(argparse.ArgumentParser):
    def error(self, message):
        self.print_usage(sys.stderr)
        print(f"\nError: {message}")
        print("\nExample usage:")
        print(" python3 run_mixformer2_onnx.py -m <model_file> -f <frame_file>")
        print(" python3 run_mixformer2_onnx.py -m mixformer_v2_sim.onnx -f car.avi")
        print(" python3 run_mixformer2_onnx.py -m mixformer_v2_sim.onnx -f car.avi -p CUDAExecutionProvider")
        print(" python3 run_mixformer2_onnx.py -m mixformer_v2_sim.onnx -f car.avi -p CPUExecutionProvider")
        sys.exit(1)


if __name__ == "__main__":
    ap = ExampleParser()
    ap.add_argument('-m', '--model-path', type=str, help='model path', required=True)
    ap.add_argument('-f', '--frame-path', type=str, help='frame path', required=True)
    ap.add_argument('-r', '--repeat', type=int, help='repeat times', default=100)
    ap.add_argument(
        '-p',
        '--provider',
        type=str,
        choices=["AUTO", "CUDAExecutionProvider", "CPUExecutionProvider"],
        help='"AUTO", "CUDAExecutionProvider", "CPUExecutionProvider"',
        default='AUTO'
    )
    ap.add_argument(
        '-d',
        '--device-id',
        type=int,
        help='CUDA device index, depends on how many cards inserted',
        default=0
    )
    args = ap.parse_args()

    model_file = args.model_path
    frame_file = args.frame_path

    # check if the model and image exist
    assert os.path.exists(model_file), f"model file path {model_file} does not exist"
    assert os.path.exists(frame_file), f"image file path {frame_file} does not exist"

    repeat = args.repeat

    provider = args.provider
    device_id = args.device_id

    main(model_file, frame_file, repeat, provider, device_id)
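
As a quick cross-check between the two demo scripts, the ONNX model's input names should match what the tracker feeds ('img_t', 'img_ot', 'img_search'). A minimal sketch, assuming onnxruntime with the CPU provider is available locally and using the onnx/mixformer_v2_sim.onnx path from this commit:

import onnxruntime

# Open the simplified ONNX model on CPU just to inspect its interface.
sess = onnxruntime.InferenceSession("onnx/mixformer_v2_sim.onnx", providers=["CPUExecutionProvider"])

# The printed input names should be 'img_t', 'img_ot' and 'img_search';
# the outputs correspond to the predicted boxes and the confidence score.
for node in sess.get_inputs():
    print("input :", node.name, node.shape, node.type)
for node in sess.get_outputs():
    print("output:", node.name, node.shape, node.type)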