Commit
·
88ea3d6
1
Parent(s):
21794d5
Update eval_onnx.py (#2)
Browse files- Update eval_onnx.py (08568220b5021edb55619b43a7a55d5d0cbd5ea6)
Co-authored-by: ziheng <[email protected]>
- eval_onnx.py +12 -2
eval_onnx.py
CHANGED
|
@@ -514,18 +514,28 @@ if __name__ == '__main__':
|
|
| 514 |
data_loader = data.getEvalDataloader()
|
| 515 |
# Load MoveNet model using ONNX runtime
|
| 516 |
model = rt.InferenceSession(MODEL_DIR, providers=providers, provider_options=provider_options)
|
| 517 |
-
|
| 518 |
correct = 0
|
| 519 |
total = 0
|
| 520 |
# Loop through the data loader for evaluation
|
| 521 |
for batch_idx, (imgs, labels, kps_mask, img_names) in enumerate(data_loader):
|
|
|
|
| 522 |
if batch_idx%100 == 0:
|
| 523 |
print('Finish ',batch_idx)
|
|
|
|
| 524 |
imgs = imgs.detach().cpu().numpy()
|
| 525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
pre = movenetDecode(output, kps_mask,mode='output',img_size=IMG_SIZE)
|
| 527 |
gt = movenetDecode(labels, kps_mask,mode='label',img_size=IMG_SIZE)
|
|
|
|
|
|
|
| 528 |
acc = myAcc(pre, gt)
|
|
|
|
| 529 |
correct += sum(acc)
|
| 530 |
total += len(acc)
|
| 531 |
# Compute and print accuracy based on evaluated data
|
|
|
|
| 514 |
data_loader = data.getEvalDataloader()
|
| 515 |
# Load MoveNet model using ONNX runtime
|
| 516 |
model = rt.InferenceSession(MODEL_DIR, providers=providers, provider_options=provider_options)
|
| 517 |
+
|
| 518 |
correct = 0
|
| 519 |
total = 0
|
| 520 |
# Loop through the data loader for evaluation
|
| 521 |
for batch_idx, (imgs, labels, kps_mask, img_names) in enumerate(data_loader):
|
| 522 |
+
|
| 523 |
if batch_idx%100 == 0:
|
| 524 |
print('Finish ',batch_idx)
|
| 525 |
+
|
| 526 |
imgs = imgs.detach().cpu().numpy()
|
| 527 |
+
imgs = imgs.transpose((0,2,3,1))
|
| 528 |
+
output = model.run(['1548_transpose','1607_transpose','1665_transpose','1723_transpose'],{'blob.1':imgs})
|
| 529 |
+
output[0] = output[0].transpose((0,3,1,2))
|
| 530 |
+
output[1] = output[1].transpose((0,3,1,2))
|
| 531 |
+
output[2] = output[2].transpose((0,3,1,2))
|
| 532 |
+
output[3] = output[3].transpose((0,3,1,2))
|
| 533 |
pre = movenetDecode(output, kps_mask,mode='output',img_size=IMG_SIZE)
|
| 534 |
gt = movenetDecode(labels, kps_mask,mode='label',img_size=IMG_SIZE)
|
| 535 |
+
|
| 536 |
+
#n
|
| 537 |
acc = myAcc(pre, gt)
|
| 538 |
+
|
| 539 |
correct += sum(acc)
|
| 540 |
total += len(acc)
|
| 541 |
# Compute and print accuracy based on evaluated data
|