Update eval_onnx.py
#2
by
hangyang-amd
- opened
- eval_onnx.py +2 -1
eval_onnx.py
CHANGED
|
@@ -34,6 +34,7 @@ parser.add_argument(
|
|
| 34 |
default="vaip_config.json",
|
| 35 |
help="Path of the config file for seting provider_options.",
|
| 36 |
)
|
|
|
|
| 37 |
args = parser.parse_args()
|
| 38 |
|
| 39 |
class AverageMeter(object):
|
|
@@ -144,7 +145,7 @@ def val_imagenet():
|
|
| 144 |
val_loader = tqdm(val_loader, file=sys.stdout)
|
| 145 |
with torch.no_grad():
|
| 146 |
for batch_idx, (images, targets) in enumerate(val_loader):
|
| 147 |
-
inputs, targets = images.numpy(), targets
|
| 148 |
ort_inputs = {ort_session.get_inputs()[0].name: inputs}
|
| 149 |
|
| 150 |
outputs = ort_session.run(None, ort_inputs)
|
|
|
|
| 34 |
default="vaip_config.json",
|
| 35 |
help="Path of the config file for seting provider_options.",
|
| 36 |
)
|
| 37 |
+
parser.add_argument('--data_format', type=str, choices=["nchw", "nhwc"], default="nchw")
|
| 38 |
args = parser.parse_args()
|
| 39 |
|
| 40 |
class AverageMeter(object):
|
|
|
|
| 145 |
val_loader = tqdm(val_loader, file=sys.stdout)
|
| 146 |
with torch.no_grad():
|
| 147 |
for batch_idx, (images, targets) in enumerate(val_loader):
|
| 148 |
+
inputs, targets = images.numpy() if args.data_format == "nchw" else images.permute((0, 2, 3, 1)).numpy(), targets
|
| 149 |
ort_inputs = {ort_session.get_inputs()[0].name: inputs}
|
| 150 |
|
| 151 |
outputs = ort_session.run(None, ort_inputs)
|