Fix `pos_embed` device
Browse files- modeling_aimv2.py +1 -1
modeling_aimv2.py
CHANGED
|
@@ -102,7 +102,7 @@ class AIMv2ViTPreprocessor(nn.Module):
|
|
| 102 |
pos_embed = get_sincos_pos_embed(
|
| 103 |
H // self.patch_h, W // self.patch_w, embed_dim=self.embed_dim
|
| 104 |
)
|
| 105 |
-
tokens = tokens + pos_embed
|
| 106 |
return tokens
|
| 107 |
|
| 108 |
|
|
|
|
| 102 |
pos_embed = get_sincos_pos_embed(
|
| 103 |
H // self.patch_h, W // self.patch_w, embed_dim=self.embed_dim
|
| 104 |
)
|
| 105 |
+
tokens = tokens + pos_embed.to(tokens.device)
|
| 106 |
return tokens
|
| 107 |
|
| 108 |
|