Upload 6 files

- .gitattributes +1 -0
- ax650/mixformer_v2.axmodel +3 -0
- car.avi +3 -0
- config.json +0 -0
- onnx/mixformer_v2_sim.onnx +3 -0
- run_mixformer2_axmodel.py +370 -0
- run_mixformer2_onnx.py +364 -0
.gitattributes CHANGED

@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 *.axmodel filter=lfs diff=lfs merge=lfs -text
+car.avi filter=lfs diff=lfs merge=lfs -text
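Note: the new pattern is needed because none of the pre-existing LFS patterns match the video. A minimal sketch of that check (an approximation; real .gitattributes matching uses gitignore-style rules, which for these simple patterns behaves like fnmatch):

from fnmatch import fnmatch

# Patterns from the hunk above; only the newly added one matches the upload.
patterns = ["*.zst", "*tfevents*", "*.axmodel", "car.avi"]
print([p for p in patterns if fnmatch("car.avi", p)])  # -> ['car.avi']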
ax650/mixformer_v2.axmodel ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7ae9b045601ccb716a7bcec3a72a601e279db421a98de779f0283d750694f84d
+size 23041126
car.avi ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12085e86f5791f9936f382027d51ed02e2485636a03480aeaec2ac03e2bd2bd1
+size 65474094
config.json ADDED

File without changes
onnx/mixformer_v2_sim.onnx ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9863932c9d18e6c75adcbec1871d293440bd350bc08f76707ca8088312e2175c
+size 64950790
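Each of the three binary files above is checked in as a Git LFS pointer rather than raw bytes. A minimal sketch (assuming only the v1 pointer format shown in the hunks) that reads back the oid and size:

def parse_lfs_pointer(path):
    """Parse a Git LFS v1 pointer file (the three-line format shown above)."""
    fields = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    oid = fields["oid"].split(":", 1)[1]  # strip the "sha256:" prefix
    size = int(fields["size"])
    return oid, size

# Before `git lfs pull`, parse_lfs_pointer("car.avi") would return the
# sha256 oid above and size 65474094.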
run_mixformer2_axmodel.py ADDED

import argparse
import glob
import math
import os
import sys

import cv2
import numpy as np

import axengine as axe
from axengine import axclrt_provider_name, axengine_provider_name


def load_model(model_path: str | os.PathLike, selected_provider: str, selected_device_id: int = 0):
    if selected_provider == 'AUTO':
        # Use AUTO to let the engine choose the first available provider
        return axe.InferenceSession(model_path)

    providers = []
    if selected_provider == axclrt_provider_name:
        provider_options = {"device_id": selected_device_id}
        providers.append((axclrt_provider_name, provider_options))
    if selected_provider == axengine_provider_name:
        providers.append(axengine_provider_name)

    return axe.InferenceSession(model_path, providers=providers)


def get_frames(video_name):
    """Yield video frames.

    Args:
        video_name: path to an .avi/.mp4 file, a sequence directory
            containing an 'img' folder, or empty to read an RTSP stream.

    Yields:
        np.ndarray: the next BGR frame.
    """
    if not video_name:
        rtsp = "rtsp://%s:%s@%s:554/cam/realmonitor?channel=1&subtype=1" % ("admin", "123456", "192.168.1.108")
        cap = cv2.VideoCapture(rtsp)

        # warmup
        for _ in range(5):
            cap.read()
        while True:
            ret, frame = cap.read()
            if ret:
                yield cv2.resize(frame, (800, 600))
            else:
                break
    elif video_name.endswith('avi') or video_name.endswith('mp4'):
        cap = cv2.VideoCapture(video_name)
        while True:
            ret, frame = cap.read()
            if ret:
                yield frame
            else:
                break
    else:
        images = sorted(glob.glob(os.path.join(video_name, 'img', '*.jp*')))
        for img in images:
            yield cv2.imread(img)


class Preprocessor_wo_mask(object):
    def __init__(self):
        self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1)).astype(np.float32)
        self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1)).astype(np.float32)

    def process(self, img_arr: np.ndarray):
        # HWC uint8 -> (1,3,H,W) float32 in [0,1], then ImageNet normalization
        img_tensor = img_arr.transpose((2, 0, 1)).reshape((1, 3, img_arr.shape[0], img_arr.shape[1])).astype(np.float32) / 255.0
        img_tensor_norm = (img_tensor - self.mean) / self.std  # (1,3,H,W)
        return img_tensor_norm


class MFTrackerORT:
    def __init__(self, model_path, provider="AUTO", device_id=0, fp16=False) -> None:
        self.debug = True
        self.model_path = model_path
        self.provider = provider
        self.device_id = device_id
        self.fp16 = fp16

        self.init_track_net()
        self.preprocessor = Preprocessor_wo_mask()
        self.max_score_decay = 1.0
        self.search_factor = 4.5
        self.search_size = 224
        self.template_factor = 2.0
        self.template_size = 112
        self.update_interval = 200
        self.online_size = 1

    def init_track_net(self):
        """Initialize the tracker network with the configured provider."""
        self.ax_session = load_model(self.model_path, selected_provider=self.provider,
                                     selected_device_id=self.device_id)

    def track_init(self, frame, target_pos=None, target_sz=None):
        """Initialize tracking from the first frame.

        Args:
            frame: first BGR frame.
            target_pos: [x, y] top-left corner of the target box.
            target_sz: [w, h] size of the target box.
        """
        self.trace_list = []
        try:
            # [x, y, w, h]
            init_state = [target_pos[0], target_pos[1], target_sz[0], target_sz[1]]
            z_patch_arr, _, z_amask_arr = self.sample_target(frame, init_state, self.template_factor,
                                                             output_sz=self.template_size)
            template = self.preprocessor.process(z_patch_arr)
            self.template = template
            self.online_template = template

            self.online_state = init_state
            self.online_image = frame
            self.max_pred_score = -1.0
            self.online_max_template = template
            self.online_forget_id = 0

            # save states
            self.state = init_state
            self.frame_id = 0
            print("First-frame initialization done!")
        except Exception as e:
            print(f"First-frame initialization failed: {e}")
            sys.exit(1)

    def track(self, image, info: dict = None):
        H, W, _ = image.shape
        self.frame_id += 1
        x_patch_arr, resize_factor, x_amask_arr = self.sample_target(image, self.state, self.search_factor,
                                                                     output_sz=self.search_size)  # (x1, y1, w, h)
        search = self.preprocessor.process(x_patch_arr)

        # run the compiled model
        ort_inputs = {'img_t': self.template, 'img_ot': self.online_template, 'img_search': search}
        ort_outs = self.ax_session.run(None, ort_inputs)

        pred_boxes = ort_outs[0]
        pred_score = ort_outs[1]
        # Baseline: take the mean of all predicted boxes as the final result
        pred_box = (np.mean(pred_boxes, axis=0) * self.search_size / resize_factor).tolist()  # (cx, cy, w, h)
        # map back to image coordinates and clip to the frame
        self.state = self.clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        self.max_pred_score = self.max_pred_score * self.max_score_decay
        # keep the best-scoring crop as the online template candidate
        if pred_score > 0.5 and pred_score > self.max_pred_score:
            z_patch_arr, _, z_amask_arr = self.sample_target(image, self.state,
                                                             self.template_factor,
                                                             output_sz=self.template_size)  # (x1, y1, w, h)
            self.online_max_template = self.preprocessor.process(z_patch_arr)
            self.max_pred_score = pred_score

        if self.frame_id % self.update_interval == 0:
            if self.online_size == 1:
                self.online_template = self.online_max_template
            else:
                self.online_template[self.online_forget_id:self.online_forget_id + 1] = self.online_max_template
                self.online_forget_id = (self.online_forget_id + 1) % self.online_size

            self.max_pred_score = -1
            self.online_max_template = self.template

        # for debug
        if self.debug:
            x1, y1, w, h = self.state
            cv2.rectangle(image, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color=(0, 0, 255), thickness=2)

        return {"target_bbox": self.state, "conf_score": pred_score}

    def map_box_back(self, pred_box: list, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: np.ndarray, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.T  # (N,4) --> (N,)
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return np.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], axis=-1)

    def sample_target(self, im, target_bb, search_area_factor, output_sz=None, mask=None):
        """Extract a square crop centered on target_bb, of area search_area_factor^2 times the target area.

        Args:
            im: cv image.
            target_bb: target box [x, y, w, h].
            search_area_factor: ratio of crop size to target size.
            output_sz: (float) side length to which the extracted crop is resized (always square).
                If None, no resizing is done.

        Returns:
            cv image: extracted crop.
            float: the factor by which the crop has been resized to make its side equal output_sz.
            np.ndarray: attention mask (True where the crop was padded).
        """
        if not isinstance(target_bb, list):
            x, y, w, h = target_bb.tolist()
        else:
            x, y, w, h = target_bb
        # Crop image
        crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)

        if crop_sz < 1:
            raise ValueError('Too small bounding box.')

        x1 = int(round(x + 0.5 * w - crop_sz * 0.5))
        x2 = int(x1 + crop_sz)

        y1 = int(round(y + 0.5 * h - crop_sz * 0.5))
        y2 = int(y1 + crop_sz)

        x1_pad = int(max(0, -x1))
        x2_pad = int(max(x2 - im.shape[1] + 1, 0))

        y1_pad = int(max(0, -y1))
        y2_pad = int(max(y2 - im.shape[0] + 1, 0))

        # Crop target
        im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
        if mask is not None:
            mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]

        # Pad
        im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)
        # Build the attention mask (ones mark padded pixels)
        H, W, _ = im_crop_padded.shape
        att_mask = np.ones((H, W))
        end_x, end_y = -x2_pad, -y2_pad
        if y2_pad == 0:
            end_y = None
        if x2_pad == 0:
            end_x = None
        att_mask[y1_pad:end_y, x1_pad:end_x] = 0
        if mask is not None:
            mask_crop_padded = cv2.copyMakeBorder(mask_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)

        if output_sz is not None:
            resize_factor = output_sz / crop_sz
            im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz))
            att_mask = cv2.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
            if mask is None:
                return im_crop_padded, resize_factor, att_mask
            mask_crop_padded = cv2.resize(mask_crop_padded, (output_sz, output_sz))
            return im_crop_padded, resize_factor, att_mask, mask_crop_padded

        # keep the (crop, resize_factor, mask) return order expected by the callers
        if mask is None:
            return im_crop_padded, 1.0, att_mask.astype(np.bool_)
        return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded

    def clip_box(self, box: list, H, W, margin=0):
        x1, y1, w, h = box
        x2, y2 = x1 + w, y1 + h
        x1 = min(max(0, x1), W - margin)
        x2 = min(max(margin, x2), W)
        y1 = min(max(0, y1), H - margin)
        y2 = min(max(margin, y2), H)
        w = max(margin, x2 - x1)
        h = max(margin, y2 - y1)
        return [x1, y1, w, h]


def main(model_path, frame_path, repeat, selected_provider, selected_device_id):
    tracker = MFTrackerORT(model_path=model_path, provider=selected_provider,
                           device_id=selected_device_id, fp16=False)
    first_frame = True
    tracker.video_name = frame_path

    frame_id = 0
    total_fps = 0.0
    for frame in get_frames(tracker.video_name):
        # stop once the requested number of frames has been processed
        if repeat is not None and frame_id >= repeat:
            print(f"Reached the maximum number of frames ({repeat}). Exiting loop.")
            break

        tic = cv2.getTickCount()
        if first_frame:
            # x, y, w, h = cv2.selectROI(video_name, frame, fromCenter=False)
            x, y, w, h = 1079, 482, 99, 106

            target_pos = [x, y]
            target_sz = [w, h]
            tracker.track_init(frame, target_pos, target_sz)
            first_frame = False
        else:
            tracker.track(frame)
            frame_id += 1

            os.makedirs('axmodel_output', exist_ok=True)
            cv2.imwrite(f'axmodel_output/{frame_id}.png', frame)

            fps = cv2.getTickFrequency() / (cv2.getTickCount() - tic)
            total_fps += fps
            print('Video: {:12s} {:3.1f} fps'.format('tracking', fps))

    print('Video: {} {:3.1f} fps'.format('final average tracking fps', total_fps / max(frame_id, 1)))


class ExampleParser(argparse.ArgumentParser):
    def error(self, message):
        self.print_usage(sys.stderr)
        print(f"\nError: {message}")
        print("\nExample usage:")
        print("  python3 run_mixformer2_axmodel.py -m <model_file> -f <frame_file>")
        print("  python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi")
        print(f"  python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi -p {axengine_provider_name}")
        print(f"  python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi -p {axclrt_provider_name}")
        sys.exit(1)


if __name__ == "__main__":
    ap = ExampleParser()
    ap.add_argument('-m', '--model-path', type=str, help='model path', required=True)
    ap.add_argument('-f', '--frame-path', type=str, help='frame path', required=True)
    ap.add_argument('-r', '--repeat', type=int, help='max number of frames to track', default=100)
    ap.add_argument(
        '-p',
        '--provider',
        type=str,
        choices=["AUTO", axclrt_provider_name, axengine_provider_name],
        help=f'"AUTO", "{axclrt_provider_name}", "{axengine_provider_name}"',
        default='AUTO'
    )
    ap.add_argument(
        '-d',
        '--device-id',
        type=int,
        help='axclrt device index, depends on how many cards are inserted',
        default=0
    )
    args = ap.parse_args()

    model_file = args.model_path
    frame_file = args.frame_path

    # check that the model and frame source exist
    assert os.path.exists(model_file), f"model file path {model_file} does not exist"
    assert os.path.exists(frame_file), f"frame file path {frame_file} does not exist"

    repeat = args.repeat

    provider = args.provider
    device_id = args.device_id

    main(model_file, frame_file, repeat, provider, device_id)
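For reference, the script can also be driven directly from Python instead of the CLI; a minimal sketch, assuming the file layout of this upload:

# Hypothetical direct call, mirroring the CLI defaults of run_mixformer2_axmodel.py.
from run_mixformer2_axmodel import main

main(
    model_path="ax650/mixformer_v2.axmodel",
    frame_path="car.avi",
    repeat=100,               # same as the -r default: stop after 100 tracked frames
    selected_provider="AUTO",
    selected_device_id=0,
)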
run_mixformer2_onnx.py ADDED

import argparse
import glob
import math
import os
import sys

import cv2
import numpy as np
import onnxruntime
import torch
import torch.nn.functional as F

prj_path = os.path.join(os.path.dirname(__file__), '..')
if prj_path not in sys.path:
    sys.path.append(prj_path)


def get_frames(video_name):
    """Yield video frames (same helper as in run_mixformer2_axmodel.py).

    Args:
        video_name: path to an .avi/.mp4 file, a sequence directory
            containing an 'img' folder, or empty to read an RTSP stream.

    Yields:
        np.ndarray: the next BGR frame.
    """
    if not video_name:
        rtsp = "rtsp://%s:%s@%s:554/cam/realmonitor?channel=1&subtype=1" % ("admin", "123456", "192.168.1.108")
        cap = cv2.VideoCapture(rtsp)

        # warmup
        for _ in range(5):
            cap.read()
        while True:
            ret, frame = cap.read()
            if ret:
                yield cv2.resize(frame, (800, 600))
            else:
                break
    elif video_name.endswith('avi') or video_name.endswith('mp4'):
        cap = cv2.VideoCapture(video_name)
        while True:
            ret, frame = cap.read()
            if ret:
                yield frame
            else:
                break
    else:
        images = sorted(glob.glob(os.path.join(video_name, 'img', '*.jp*')))
        for img in images:
            yield cv2.imread(img)


class Preprocessor_wo_mask(object):
    def __init__(self):
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1))
        self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1))

    def process(self, img_arr: np.ndarray):
        # HWC uint8 -> (1,3,H,W) float in [0,1], then ImageNet normalization
        img_tensor = torch.tensor(img_arr).float().permute((2, 0, 1)).unsqueeze(dim=0)
        img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std  # (1,3,H,W)
        return img_tensor_norm.contiguous()


class MFTrackerORT:
    def __init__(self, model_path, provider="AUTO", device_id=0, fp16=False) -> None:
        self.debug = True
        if provider == "AUTO":
            self.providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            self.providers = [provider]
        self.provider_options = [
            {"device_id": str(device_id)} if p == "CUDAExecutionProvider" else {}
            for p in self.providers
        ]
        self.model_path = model_path
        self.fp16 = fp16

        self.init_track_net()
        self.preprocessor = Preprocessor_wo_mask()
        self.max_score_decay = 1.0
        self.search_factor = 4.5
        self.search_size = 224
        self.template_factor = 2.0
        self.template_size = 112
        self.update_interval = 200
        self.online_size = 1

    def init_track_net(self):
        """Initialize the tracker network with the configured providers."""
        self.ort_session = onnxruntime.InferenceSession(self.model_path, providers=self.providers,
                                                        provider_options=self.provider_options)

    def track_init(self, frame, target_pos=None, target_sz=None):
        """Initialize tracking from the first frame.

        Args:
            frame: first BGR frame.
            target_pos: [x, y] top-left corner of the target box.
            target_sz: [w, h] size of the target box.
        """
        self.trace_list = []
        try:
            # [x, y, w, h]
            init_state = [target_pos[0], target_pos[1], target_sz[0], target_sz[1]]
            z_patch_arr, _, z_amask_arr = self.sample_target(frame, init_state, self.template_factor,
                                                             output_sz=self.template_size)
            template = self.preprocessor.process(z_patch_arr)
            self.template = template
            self.online_template = template

            self.online_state = init_state
            self.online_image = frame
            self.max_pred_score = -1.0
            self.online_max_template = template
            self.online_forget_id = 0

            # save states
            self.state = init_state
            self.frame_id = 0
            print("First-frame initialization done!")
        except Exception as e:
            print(f"First-frame initialization failed: {e}")
            sys.exit(1)

    def track(self, image, info: dict = None):
        H, W, _ = image.shape
        self.frame_id += 1
        x_patch_arr, resize_factor, x_amask_arr = self.sample_target(image, self.state, self.search_factor,
                                                                     output_sz=self.search_size)  # (x1, y1, w, h)
        search = self.preprocessor.process(x_patch_arr)

        # compute ONNX Runtime output prediction
        ort_inputs = {'img_t': self.to_numpy(self.template), 'img_ot': self.to_numpy(self.online_template),
                      'img_search': self.to_numpy(search)}
        ort_outs = self.ort_session.run(None, ort_inputs)

        pred_boxes = torch.from_numpy(ort_outs[0])
        pred_score = torch.from_numpy(ort_outs[1])
        # Baseline: take the mean of all predicted boxes as the final result
        pred_box = (pred_boxes.mean(dim=0) * self.search_size / resize_factor).tolist()  # (cx, cy, w, h)
        # map back to image coordinates and clip to the frame
        self.state = self.clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        self.max_pred_score = self.max_pred_score * self.max_score_decay
        # keep the best-scoring crop as the online template candidate
        if pred_score > 0.5 and pred_score > self.max_pred_score:
            z_patch_arr, _, z_amask_arr = self.sample_target(image, self.state,
                                                             self.template_factor,
                                                             output_sz=self.template_size)  # (x1, y1, w, h)
            self.online_max_template = self.preprocessor.process(z_patch_arr)
            self.max_pred_score = pred_score

        if self.frame_id % self.update_interval == 0:
            if self.online_size == 1:
                self.online_template = self.online_max_template
            else:
                self.online_template[self.online_forget_id:self.online_forget_id + 1] = self.online_max_template
                self.online_forget_id = (self.online_forget_id + 1) % self.online_size

            self.max_pred_score = -1
            self.online_max_template = self.template

        # for debug
        if self.debug:
            x1, y1, w, h = self.state
            cv2.rectangle(image, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color=(0, 0, 255), thickness=2)

        return {"target_bbox": self.state, "conf_score": pred_score}

    def map_box_back(self, pred_box: list, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.unbind(-1)  # (N,4) --> (N,)
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)

    def to_numpy(self, tensor):
        if self.fp16:
            return tensor.detach().cpu().half().numpy() if tensor.requires_grad else tensor.cpu().half().numpy()
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    def sample_target(self, im, target_bb, search_area_factor, output_sz=None, mask=None):
        """Extract a square crop centered on target_bb, of area search_area_factor^2 times the target area.

        Args:
            im: cv image.
            target_bb: target box [x, y, w, h].
            search_area_factor: ratio of crop size to target size.
            output_sz: (float) side length to which the extracted crop is resized (always square).
                If None, no resizing is done.

        Returns:
            cv image: extracted crop.
            float: the factor by which the crop has been resized to make its side equal output_sz.
            np.ndarray: attention mask (True where the crop was padded).
        """
        if not isinstance(target_bb, list):
            x, y, w, h = target_bb.tolist()
        else:
            x, y, w, h = target_bb
        # Crop image
        crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)

        if crop_sz < 1:
            raise ValueError('Too small bounding box.')

        x1 = int(round(x + 0.5 * w - crop_sz * 0.5))
        x2 = int(x1 + crop_sz)

        y1 = int(round(y + 0.5 * h - crop_sz * 0.5))
        y2 = int(y1 + crop_sz)

        x1_pad = int(max(0, -x1))
        x2_pad = int(max(x2 - im.shape[1] + 1, 0))

        y1_pad = int(max(0, -y1))
        y2_pad = int(max(y2 - im.shape[0] + 1, 0))

        # Crop target
        im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
        if mask is not None:
            mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]

        # Pad
        im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)
        # Build the attention mask (ones mark padded pixels)
        H, W, _ = im_crop_padded.shape
        att_mask = np.ones((H, W))
        end_x, end_y = -x2_pad, -y2_pad
        if y2_pad == 0:
            end_y = None
        if x2_pad == 0:
            end_x = None
        att_mask[y1_pad:end_y, x1_pad:end_x] = 0
        if mask is not None:
            mask_crop_padded = F.pad(mask_crop, pad=(x1_pad, x2_pad, y1_pad, y2_pad), mode='constant', value=0)

        if output_sz is not None:
            resize_factor = output_sz / crop_sz
            im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz))
            att_mask = cv2.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
            if mask is None:
                return im_crop_padded, resize_factor, att_mask
            mask_crop_padded = F.interpolate(mask_crop_padded[None, None], (output_sz, output_sz),
                                             mode='bilinear', align_corners=False)[0, 0]
            return im_crop_padded, resize_factor, att_mask, mask_crop_padded

        # keep the (crop, resize_factor, mask) return order expected by the callers
        if mask is None:
            return im_crop_padded, 1.0, att_mask.astype(np.bool_)
        return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded

    def clip_box(self, box: list, H, W, margin=0):
        x1, y1, w, h = box
        x2, y2 = x1 + w, y1 + h
        x1 = min(max(0, x1), W - margin)
        x2 = min(max(margin, x2), W)
        y1 = min(max(0, y1), H - margin)
        y2 = min(max(margin, y2), H)
        w = max(margin, x2 - x1)
        h = max(margin, y2 - y1)
        return [x1, y1, w, h]


def main(model_path, frame_path, repeat, selected_provider, selected_device_id):
    tracker = MFTrackerORT(model_path=model_path, provider=selected_provider,
                           device_id=selected_device_id, fp16=False)
    first_frame = True
    tracker.video_name = frame_path

    frame_id = 0
    total_fps = 0.0
    for frame in get_frames(tracker.video_name):
        # stop once the requested number of frames has been processed
        if repeat is not None and frame_id >= repeat:
            print(f"Reached the maximum number of frames ({repeat}). Exiting loop.")
            break

        tic = cv2.getTickCount()
        if first_frame:
            # x, y, w, h = cv2.selectROI(video_name, frame, fromCenter=False)
            # left, top, width, height
            x, y, w, h = 1079, 482, 99, 106

            target_pos = [x, y]
            target_sz = [w, h]
            tracker.track_init(frame, target_pos, target_sz)
            first_frame = False
        else:
            tracker.track(frame)
            frame_id += 1

            os.makedirs('onnx_output', exist_ok=True)
            cv2.imwrite(f'onnx_output/{frame_id}.png', frame)

            fps = cv2.getTickFrequency() / (cv2.getTickCount() - tic)
            total_fps += fps
            print('Video: {:12s} {:3.1f} fps'.format('tracking', fps))

    print('Video: {} {:3.1f} fps'.format('final average tracking fps', total_fps / max(frame_id, 1)))


class ExampleParser(argparse.ArgumentParser):
    def error(self, message):
        self.print_usage(sys.stderr)
        print(f"\nError: {message}")
        print("\nExample usage:")
        print("  python3 run_mixformer2_onnx.py -m <model_file> -f <frame_file>")
        print("  python3 run_mixformer2_onnx.py -m mixformer_v2_sim.onnx -f car.avi")
        print("  python3 run_mixformer2_onnx.py -m mixformer_v2_sim.onnx -f car.avi -p CUDAExecutionProvider")
        print("  python3 run_mixformer2_onnx.py -m mixformer_v2_sim.onnx -f car.avi -p CPUExecutionProvider")
        sys.exit(1)


if __name__ == "__main__":
    ap = ExampleParser()
    ap.add_argument('-m', '--model-path', type=str, help='model path', required=True)
    ap.add_argument('-f', '--frame-path', type=str, help='frame path', required=True)
    ap.add_argument('-r', '--repeat', type=int, help='max number of frames to track', default=100)
    ap.add_argument(
        '-p',
        '--provider',
        type=str,
        choices=["AUTO", "CUDAExecutionProvider", "CPUExecutionProvider"],
        help='"AUTO", "CUDAExecutionProvider", "CPUExecutionProvider"',
        default='AUTO'
    )
    ap.add_argument(
        '-d',
        '--device-id',
        type=int,
        help='CUDA device index, depends on how many cards are inserted',
        default=0
    )
    args = ap.parse_args()

    model_file = args.model_path
    frame_file = args.frame_path

    # check that the model and frame source exist
    assert os.path.exists(model_file), f"model file path {model_file} does not exist"
    assert os.path.exists(frame_file), f"frame file path {frame_file} does not exist"

    repeat = args.repeat

    provider = args.provider
    device_id = args.device_id

    main(model_file, frame_file, repeat, provider, device_id)
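As a quick check of the crop geometry shared by both scripts: sample_target crops a square of side ceil(sqrt(w*h) * factor) around the box center, then resizes it to the fixed template/search resolution. A minimal sketch for the hard-coded first-frame box:

import math

w, h = 99, 106                      # hard-coded first-frame box size in both demos
template_factor, search_factor = 2.0, 4.5

template_crop = math.ceil(math.sqrt(w * h) * template_factor)  # 205 px -> resized to 112
search_crop = math.ceil(math.sqrt(w * h) * search_factor)      # 461 px -> resized to 224
print(template_crop, search_crop)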