LocateAnything-3B — ONNX WebGPU (INT4 + 4-bit embeddings)

In-browser (onnxruntime-web / WebGPU) build of nvidia/LocateAnything-3B, a visual-grounding / open-vocabulary detector. The language tower is weight-only INT4 and the embedding table is true group-wise 4-bit.

Why this repo exists

The naive "4-bit ONNX" export of this model was ~3 GB because the model has tied word embeddings (vocab 152681 × hidden 2048 = 1.25 GB in fp32). ORT's MatMulNBits INT4 quantizer compresses the tied lm_head MatMul but leaves the input-embedding Gather at full fp32, so the 1.25 GB fp32 embedding table stayed in the package.

This build fixes that with a custom quantized embedding gather:

  1. The language graph was surgically rewired to consume inputs_embeds directly (the fp32 embedding Gather and its 1.25 GB initializer are removed). It still takes input_ids, which feed only the equality (==) comparisons that build the SDLM block mask rather than a Gather, and visual_features, which are spliced in at the image-token positions.
  2. The embedding table ships as a group-wise symmetric INT4 blob ((q-8)·scale, block size 32): embed_tokens_int4_packed.bin (uint8 nibble-packed) + embed_tokens_int4_scales.bin (fp16).
  3. The browser does the gather + dequant in JS to build inputs_embeds, then runs the INT4 language graph.
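For illustration, the step-2 layout can be round-tripped in a few lines of JavaScript. This is a sketch, not the exporter used for this repo: the card specifies only the (q − 8)·scale form, block size 32 and nibble order, so the per-group absmax/7 scale, the clamp to 0..15 and the helper names quantizeRowInt4 / dequantizeRowInt4 are assumptions.

```javascript
// Hypothetical producer for one row in the (q - 8) * scale layout described above.
// Assumption: per-group absmax scaling (scale = max|x| / 7); the real exporter may differ.
function quantizeRowInt4(row, groupSize = 32) {
  const H = row.length;
  const packed = new Uint8Array(H / 2);           // two 4-bit codes per byte
  const scales = new Float32Array(H / groupSize); // the shipped blob stores these as fp16
  for (let g = 0; g < H / groupSize; g++) {
    const start = g * groupSize, end = start + groupSize;
    let absMax = 0;
    for (let j = start; j < end; j++) absMax = Math.max(absMax, Math.abs(row[j]));
    const scale = absMax > 0 ? absMax / 7 : 1;    // avoid dividing by zero for all-zero groups
    scales[g] = scale;
    for (let j = start; j < end; j++) {
      const q = Math.min(15, Math.max(0, Math.round(row[j] / scale) + 8)); // code in 0..15
      packed[j >> 1] |= (j & 1) ? q << 4 : q;     // low nibble = even index, high = odd index
    }
  }
  return { packed, scales };
}

// Inverse: (q - 8) * scale with one scale per group.
function dequantizeRowInt4(packed, scales, H, groupSize = 32) {
  const out = new Float32Array(H);
  for (let j = 0; j < H; j++) {
    const b = packed[j >> 1];
    const q = (j & 1) ? b >> 4 : b & 0x0f;
    out[j] = (q - 8) * scales[Math.floor(j / groupSize)];
  }
  return out;
}

// Round trip on a synthetic 64-wide row (two groups).
const row = Float32Array.from({ length: 64 }, (_, i) => Math.sin(i) * 0.05);
const { packed, scales } = quantizeRowInt4(row);
const back = dequantizeRowInt4(packed, scales, row.length);
let maxErr = 0;
for (let j = 0; j < row.length; j++) maxErr = Math.max(maxErr, Math.abs(back[j] - row[j]));
console.log(packed.length, scales.length, maxErr);
```

With absmax/7 the code 0 is never produced (codes stay in 1..15), so one of the 16 levels goes unused. Using absmax/8 with clamping is an equally plausible choice, and the shipped blob may use either.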

Files (browser-facing)

File                                     Size      Notes
onnx/vision_mlp.onnx (+.data)            ~1.73 GB  MoonViT + projector, fp32 (see note below)
onnx/language_tail_int4.onnx (+.data)    ~1.69 GB  Qwen2 language tower + tied lm_head, weight-only INT4 (block 128)
onnx/embed_tokens_int4_packed.bin        ~156 MB   INT4 embedding table, uint8 nibble-packed [152681, 1024]
onnx/embed_tokens_int4_scales.bin        ~19.5 MB  fp16 group scales [152681, 64]
onnx/embed_tokens_int4_meta.json         -         layout / dequant scheme
web_config.json                          -         runtime wiring, token ids, tail size

Total browser payload ≈ 3.6 GB. The big win is on the language side: the embedding table dropped from 1.25 GB (fp32) to 176 MB (INT4 codes + fp16 scales), and the language tail from 2.9 GB to 1.69 GB.
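The sizes above follow from the tensor shapes in the file list. A minimal check, assuming decimal units (1 MB = 10^6 bytes) as the card appears to use:

```javascript
// Recompute embedding-table sizes from the shapes stated in this card.
const VOCAB = 152681, HIDDEN = 2048, GROUP = 32;

const fp32Bytes = VOCAB * HIDDEN * 4;             // fp32 table kept by the naive build
const packedBytes = VOCAB * (HIDDEN / 2);         // two 4-bit codes per byte -> [152681, 1024] uint8
const scaleBytes = VOCAB * (HIDDEN / GROUP) * 2;  // one fp16 scale per 32-wide group -> [152681, 64]

const MB = 1e6, GB = 1e9;                         // decimal units (assumed)
console.log(`fp32 table:  ${(fp32Bytes / GB).toFixed(2)} GB`);
console.log(`int4 packed: ${(packedBytes / MB).toFixed(1)} MB`);
console.log(`fp16 scales: ${(scaleBytes / MB).toFixed(1)} MB`);
console.log(`int4 total:  ${((packedBytes + scaleBytes) / MB).toFixed(0)} MB`);
```

It prints 1.25 GB, 156.3 MB, 19.5 MB and 176 MB. Per weight, the INT4 table costs 4 bits plus 16/32 = 0.5 bit of scale overhead, i.e. 4.5 bits, about 7.1× smaller than fp32.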

Vision precision note. The vision tower's linears export as ONNX Gemm or dynamic MatMul nodes, which ORT's MatMulNBits INT4 quantizer cannot compress, so the tower ships in fp32. fp16 conversion is blocked by the explicit .float() Cast islands in the ONNX-friendly MoonViT RoPE patch: post-hoc fp16 conversion produces type clashes, and native fp16 export hits a torch/MPS expand_as + float64 limitation. A native mixed-precision re-export of the vision tower (Conv in fp32, everything else in fp16) is the planned follow-up to cut this to ~0.9 GB.

Embedding gather / dequant (JS reference)

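function embedRow(packed, scales, tokenId, H = 2048, G = 32) { // packed: Uint8Array[V*H/2]; scales: fp32 copy of the fp16 scales [V*H/G]
  const out = new Float32Array(H), row = tokenId * (H / 2), srow = tokenId * (H / G);
  for (let j = 0; j < H; j++) {
    const b = packed[row + (j >> 1)], q = (j & 1) ? (b >> 4) : (b & 0x0F); // two nibbles per byte: low = even idx, high = odd idx
    out[j] = (q - 8) * scales[srow + Math.floor(j / G)];                  // fp32[hidden]; one scale per 32-wide group
  }
  return out;
}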

The language graph then splices visual_features over the image-token positions and applies the SDLM block mask from input_ids.
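Building the full inputs_embeds tensor also requires reading the fp16 scales, which JavaScript cannot do natively in every runtime (Float16Array is not yet universally available). The sketch below decodes the scales manually and gathers a whole token sequence; the helper names, the little-endian byte order of the .bin files and the tiny demo tables are assumptions for illustration.

```javascript
// Hypothetical helpers; layout from this card:
//   packed: uint8 [VOCAB, HIDDEN / 2], low nibble = even index, high nibble = odd index
//   scales: fp16  [VOCAB, HIDDEN / 32], assumed little-endian
function fp16ToFp32(h) {
  const sign = h & 0x8000 ? -1 : 1;
  const exp = (h >> 10) & 0x1f;
  const frac = h & 0x3ff;
  if (exp === 0) return sign * frac * 2 ** -24;          // zero / subnormal
  if (exp === 0x1f) return frac ? NaN : sign * Infinity; // Inf / NaN
  return sign * (1 + frac / 1024) * 2 ** (exp - 15);     // normal number
}

// Widen the raw bytes of embed_tokens_int4_scales.bin to a Float32Array.
function decodeScales(buffer) {
  const view = new DataView(buffer);
  const out = new Float32Array(buffer.byteLength / 2);
  for (let i = 0; i < out.length; i++) out[i] = fp16ToFp32(view.getUint16(2 * i, true));
  return out;
}

// Gather + dequantize every token id into one row-major [seqLen, hidden] buffer.
function buildInputsEmbeds(tokenIds, packed, scales, hidden = 2048, group = 32) {
  const out = new Float32Array(tokenIds.length * hidden);
  for (let t = 0; t < tokenIds.length; t++) {
    const id = Number(tokenIds[t]); // accepts numbers or BigInts (e.g. from a BigInt64Array)
    const rowOff = id * (hidden / 2), scaleOff = id * (hidden / group), outOff = t * hidden;
    for (let j = 0; j < hidden; j++) {
      const b = packed[rowOff + (j >> 1)];
      const q = (j & 1) ? (b >> 4) : (b & 0x0f);
      out[outOff + j] = (q - 8) * scales[scaleOff + Math.floor(j / group)];
    }
  }
  return out;
}

// Tiny demo tables: vocab 2, hidden 4, group 2.
const scaleBuf = new ArrayBuffer(8);
const sv = new DataView(scaleBuf);
[0x3c00, 0x4000, 0x3800, 0x3400].forEach((h, i) => sv.setUint16(2 * i, h, true)); // 1, 2, 0.5, 0.25
const demoScales = decodeScales(scaleBuf);
const demoPacked = Uint8Array.from([0x9a, 0x7f, 0x08, 0x80]); // row 0 codes a,9,f,7; row 1 codes 8,0,0,8
const embeds = buildInputsEmbeds([1, 0], demoPacked, demoScales, 4, 2);
console.log(Array.from(embeds)); // → [0, -4, -2, 0, 2, 1, 14, -2]
```

In onnxruntime-web the resulting buffer would be wrapped as new ort.Tensor("float32", embeds, [1, seqLen, 2048]) and passed under the inputs_embeds name, alongside input_ids and visual_features. Check web_config.json for the actual wiring.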

Validation

Validated against the fp32 PyTorch model on the sample image (slow / autoregressive mode):

  • Next-token argmax matches PyTorch exactly (token 151672 = <ref> start).
  • INT4 embedding gather error vs fp32 embeddings: max_abs 0.017, mean_rel ≈ 10% (per element). Its effect on the final logits is only ~0.98, and the argmax is preserved.
  • The dominant INT4 weight error (~12.6 max logit delta) is unchanged from the baseline INT4 build.
  • fp16 vision vs fp32 vision: see validation_report.json.
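The card does not define max_abs or mean_rel. A minimal sketch of how such metrics and the argmax check could be computed, assuming max_abs = max |a − b| and mean_rel = mean |a − b| / (|b| + eps) with a hypothetical eps:

```javascript
// Assumed metric definitions (the card does not spell them out):
//   max_abs  = max_i |a_i - b_i|
//   mean_rel = mean_i |a_i - b_i| / (|b_i| + eps)
function errorStats(a, b, eps = 1e-6) {
  if (a.length !== b.length) throw new Error("length mismatch");
  let maxAbs = 0, relSum = 0;
  for (let i = 0; i < a.length; i++) {
    const d = Math.abs(a[i] - b[i]);
    maxAbs = Math.max(maxAbs, d);
    relSum += d / (Math.abs(b[i]) + eps);
  }
  return { maxAbs, meanRel: relSum / a.length };
}

// Index of the largest entry; used to check that two logit vectors pick the same token.
const argmax = (v) => v.reduce((best, val, i) => (val > v[best] ? i : best), 0);
const sameArgmax = (x, y) => argmax(x) === argmax(y);

console.log(errorStats([1.1, 2], [1, 2]), sameArgmax([1, 5, 2], [0, 3, 1]));
```

Relative error is sensitive to near-zero reference entries, so the choice of eps changes the reported mean_rel.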

Generation mode: use slow mode (autoregressive, greedy). The fast/MTP path relies on the prefill tail positions, whose predictions diverge under INT4, so it is not used here.
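Slow mode amounts to a plain greedy loop. In the sketch below the model call is abstracted as a hypothetical async stepLogits(ids) callback that returns last-position logits. The stop id is a parameter because this card does not list the end-of-sequence token.

```javascript
// Index of the largest logit (the greedy choice).
function argmax(logits) {
  let best = 0;
  for (let i = 1; i < logits.length; i++) if (logits[i] > logits[best]) best = i;
  return best;
}

// Greedy autoregressive decoding. stepLogits(ids) is a hypothetical callback that runs the
// language graph (full sequence, or last token with a KV cache) and resolves to the
// last-position logits as an array-like of length vocab.
async function greedyDecode(promptIds, stepLogits, { eosId, maxNewTokens = 256 } = {}) {
  const ids = [...promptIds];
  for (let step = 0; step < maxNewTokens; step++) {
    const next = argmax(await stepLogits(ids));
    ids.push(next);
    if (next === eosId) break;
  }
  return ids.slice(promptIds.length); // generated tokens only
}

// Toy stand-in model: always prefers (last token + 1) mod 5; token 4 plays the role of EOS.
const fakeStep = async (ids) => {
  const logits = new Float32Array(5);
  logits[(ids[ids.length - 1] + 1) % 5] = 1;
  return logits;
};
greedyDecode([0], fakeStep, { eosId: 4 }).then((out) => console.log(out)); // → [1, 2, 3, 4]
```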

Intended use

Open-vocabulary detection / visual grounding: given an image and a category prompt (Locate all the instances that matches the following description: <category>.), the model emits <ref>label</ref><box>x1 y1 x2 y2</box> with coordinates normalized to 0–1000.
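Turning the emitted text into pixel boxes is a small parsing step. This sketch assumes exactly one <box> per <ref>, space-separated numeric coordinates in x1 y1 x2 y2 order, and a linear 0–1000 scale (pixel = value × size / 1000). parseDetections is a hypothetical helper name.

```javascript
// Hypothetical parser for "<ref>label</ref><box>x1 y1 x2 y2</box>" output.
// Assumes one box per ref and coordinates on a linear 0-1000 scale.
const NUM = String.raw`(\d+(?:\.\d+)?)`;
const DET_RE = new RegExp(
  String.raw`<ref>(.*?)</ref>\s*<box>\s*` + [NUM, NUM, NUM, NUM].join(String.raw`\s+`) + String.raw`\s*</box>`,
  "g",
);

function parseDetections(text, imageWidth, imageHeight) {
  const dets = [];
  for (const m of text.matchAll(DET_RE)) {
    const [x1, y1, x2, y2] = m.slice(2, 6).map(Number);
    dets.push({
      label: m[1].trim(),
      // Scale 0-1000 coordinates to pixels (multiply first so integer results stay exact).
      box: [(x1 * imageWidth) / 1000, (y1 * imageHeight) / 1000, (x2 * imageWidth) / 1000, (y2 * imageHeight) / 1000],
    });
  }
  return dets;
}

const sample = "<ref>cat</ref><box>100 200 500 800</box><ref>dog</ref><box>0 0 1000 1000</box>";
console.log(parseDetections(sample, 640, 480));
```

For a 640×480 image this maps the cat box to [64, 96, 320, 384] and the dog box to the full frame [0, 0, 640, 480].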

KV-cache graph for in-browser use

onnx/language_tail_kv_int4.onnx (+.data, ~1.65 GB) is the KV-cache version of the language tail, used by the live demo. Its inputs are inputs_embeds, input_ids (for the plain causal mask), position_ids, and 36 × 2 past_key/value GQA tensors of shape [1, 2, seq, 128]. Its outputs are the logits for the last position plus present_key/value.

Prefill passes a length-0 past; decode passes the growing cache. This makes autoregressive decoding ~13× faster than with the cache-less graph.

Validated: prefill (empty past) followed by cached decode reproduces the <ref>label</ref><box>…</box> detections, and the next-token argmax matches the fp32 torch model. See kv_validation_report.json.
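The feed layout for the KV-cache graph can be planned up front. The shapes (36 layers, 2 KV heads, head_dim 128) come from this card. The input names produced by pastName, the fp32 cache dtype and the position_ids continuation rule are assumptions, so check session.inputNames on the real graph.

```javascript
// Shapes from this card: 36 layers, 2 KV heads (GQA), head_dim 128.
const LAYERS = 36, KV_HEADS = 2, HEAD_DIM = 128;

// Hypothetical naming scheme; read the real names from session.inputNames.
const pastName = (layer, kind) => `past_key_values.${layer}.${kind}`;

// Plan one step: prefill when pastLen === 0, decode (usually newTokens === 1) otherwise.
function planStep(pastLen, newTokens) {
  const pastDims = {};
  for (let layer = 0; layer < LAYERS; layer++) {
    for (const kind of ["key", "value"]) {
      pastDims[pastName(layer, kind)] = [1, KV_HEADS, pastLen, HEAD_DIM];
    }
  }
  // Assumption: position ids continue from the current cache length.
  const positionIds = Array.from({ length: newTokens }, (_, i) => pastLen + i);
  return { pastDims, positionIds };
}

// Cache growth per token, assuming fp32 storage (4 bytes per element).
const bytesPerToken = LAYERS * 2 * KV_HEADS * HEAD_DIM * 4;

const prefill = planStep(0, 5); // 5 prompt tokens, empty past
const decode = planStep(5, 1);  // one new token on top of the 5-token cache
console.log(Object.keys(prefill.pastDims).length, prefill.positionIds, decode.positionIds, bytesPerToken);
```

For the 5-token prefill this yields 72 past tensors with seq = 0 and positions 0..4. The following decode step uses seq = 5 and position 5. At fp32 the cache grows by 73,728 bytes (72 KiB) per token. In onnxruntime-web an empty past could be expressed as new ort.Tensor("float32", new Float32Array(0), [1, 2, 0, 128]).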

Live in-browser demo (WebGPU): https://huggingface.co/spaces/Reza2kn/LocateAnything-3B-WebGPU

Source model & license: nvidia/LocateAnything-3B.

Model tree for Reza2kn/LocateAnything-3B-ONNX-WebGPU-INT4: base model Qwen/Qwen2.5-3B; this repo is listed as one of its quantized derivatives.
