Upload model
Browse files- hf_model.py +6 -11
hf_model.py
CHANGED
|
@@ -23,18 +23,13 @@ from .model import create_model_from_args
|
|
| 23 |
from .input_conditioner import get_default_conditioner, InputConditioner
|
| 24 |
|
| 25 |
|
| 26 |
-
resource_map = {
|
| 27 |
-
'radio_v1': 'https://huggingface.co/nvidia/RADIO/raw/main/radio_v1.pth.tar'
|
| 28 |
-
}
|
| 29 |
-
|
| 30 |
-
|
| 31 |
class RADIOConfig(PretrainedConfig):
|
| 32 |
"""Pretrained Hugging Face configuration for RADIO models."""
|
| 33 |
|
| 34 |
def __init__(
|
| 35 |
self,
|
| 36 |
args: Optional[dict] = None,
|
| 37 |
-
version: Optional[str]="v1",
|
| 38 |
return_summary: Optional[bool] = True,
|
| 39 |
return_spatial_features: Optional[bool] = True,
|
| 40 |
**kwargs,
|
|
@@ -68,12 +63,12 @@ class RADIOModel(PreTrainedModel):
|
|
| 68 |
if isinstance(y, (list, tuple)):
|
| 69 |
summary, all_feat = y
|
| 70 |
elif isinstance(self.model, VisionTransformer):
|
| 71 |
-
patch_gen = getattr(self.model,
|
| 72 |
if patch_gen is not None:
|
| 73 |
-
summary = y[:, :patch_gen.num_cls_tokens].flatten(1)
|
| 74 |
-
all_feat = y[:, patch_gen.num_skip:]
|
| 75 |
-
elif self.model.global_pool ==
|
| 76 |
-
summary = y[:, self.model.num_prefix_tokens:].mean(dim=1)
|
| 77 |
all_feat = y
|
| 78 |
else:
|
| 79 |
summary = y[:, 0]
|
|
|
|
| 23 |
from .input_conditioner import get_default_conditioner, InputConditioner
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
class RADIOConfig(PretrainedConfig):
|
| 27 |
"""Pretrained Hugging Face configuration for RADIO models."""
|
| 28 |
|
| 29 |
def __init__(
|
| 30 |
self,
|
| 31 |
args: Optional[dict] = None,
|
| 32 |
+
version: Optional[str] = "v1",
|
| 33 |
return_summary: Optional[bool] = True,
|
| 34 |
return_spatial_features: Optional[bool] = True,
|
| 35 |
**kwargs,
|
|
|
|
| 63 |
if isinstance(y, (list, tuple)):
|
| 64 |
summary, all_feat = y
|
| 65 |
elif isinstance(self.model, VisionTransformer):
|
| 66 |
+
patch_gen = getattr(self.model, "patch_generator", None)
|
| 67 |
if patch_gen is not None:
|
| 68 |
+
summary = y[:, : patch_gen.num_cls_tokens].flatten(1)
|
| 69 |
+
all_feat = y[:, patch_gen.num_skip :]
|
| 70 |
+
elif self.model.global_pool == "avg":
|
| 71 |
+
summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
|
| 72 |
all_feat = y
|
| 73 |
else:
|
| 74 |
summary = y[:, 0]
|