Upload 2 files
Browse files- ResNet_int_NHWC.onnx +3 -0
- eval_onnx.py +1 -4
ResNet_int_NHWC.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d7f27d4fb250334533478861df64a192e141d4c7877ec465ce1cafd241a7dcf7
|
| 3 |
+
size 102275699
|
eval_onnx.py
CHANGED
|
@@ -68,7 +68,6 @@ def accuracy(output: torch.Tensor,
|
|
| 68 |
output: Prediction of the model.
|
| 69 |
target: Ground truth labels.
|
| 70 |
topk: Topk accuracy to compute.
|
| 71 |
-
|
| 72 |
Returns:
|
| 73 |
Accuracy results according to 'topk'.
|
| 74 |
"""
|
|
@@ -91,13 +90,11 @@ def prepare_data_loader(data_dir: str,
|
|
| 91 |
batch_size: int = 100,
|
| 92 |
workers: int = 8) -> torch.utils.data.DataLoader:
|
| 93 |
"""Returns a validation data loader of ImageNet by given `data_dir`.
|
| 94 |
-
|
| 95 |
Args:
|
| 96 |
data_dir: Directory where images stores. There must be a subdirectory named
|
| 97 |
'validation' that stores the validation set of ImageNet.
|
| 98 |
batch_size: Batch size of data loader.
|
| 99 |
workers: How many subprocesses to use for data loading.
|
| 100 |
-
|
| 101 |
Returns:
|
| 102 |
An object of torch.utils.data.DataLoader.
|
| 103 |
"""
|
|
@@ -144,7 +141,7 @@ def val_imagenet():
|
|
| 144 |
val_loader = tqdm(val_loader, file=sys.stdout)
|
| 145 |
with torch.no_grad():
|
| 146 |
for batch_idx, (images, targets) in enumerate(val_loader):
|
| 147 |
-
inputs, targets = images.numpy(), targets
|
| 148 |
ort_inputs = {ort_session.get_inputs()[0].name: inputs}
|
| 149 |
|
| 150 |
outputs = ort_session.run(None, ort_inputs)
|
|
|
|
| 68 |
output: Prediction of the model.
|
| 69 |
target: Ground truth labels.
|
| 70 |
topk: Topk accuracy to compute.
|
|
|
|
| 71 |
Returns:
|
| 72 |
Accuracy results according to 'topk'.
|
| 73 |
"""
|
|
|
|
| 90 |
batch_size: int = 100,
|
| 91 |
workers: int = 8) -> torch.utils.data.DataLoader:
|
| 92 |
"""Returns a validation data loader of ImageNet by given `data_dir`.
|
|
|
|
| 93 |
Args:
|
| 94 |
data_dir: Directory where images stores. There must be a subdirectory named
|
| 95 |
'validation' that stores the validation set of ImageNet.
|
| 96 |
batch_size: Batch size of data loader.
|
| 97 |
workers: How many subprocesses to use for data loading.
|
|
|
|
| 98 |
Returns:
|
| 99 |
An object of torch.utils.data.DataLoader.
|
| 100 |
"""
|
|
|
|
| 141 |
val_loader = tqdm(val_loader, file=sys.stdout)
|
| 142 |
with torch.no_grad():
|
| 143 |
for batch_idx, (images, targets) in enumerate(val_loader):
|
| 144 |
+
inputs, targets = images.numpy().transpose(0, 2, 3, 1), targets
|
| 145 |
ort_inputs = {ort_session.get_inputs()[0].name: inputs}
|
| 146 |
|
| 147 |
outputs = ort_session.run(None, ort_inputs)
|