MLLMSeg: Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder

This repository contains the MLLMSeg_InternVL2_5_8B_RES model presented in the paper Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder.

Abstract: Referring Expression Segmentation (RES) aims to segment image regions specified by referring expressions and has gained popularity with the rise of multimodal large language models (MLLMs). While MLLMs excel at semantic understanding, their token-generation paradigm struggles with pixel-level dense prediction. Existing RES methods either couple MLLMs with the parameter-heavy Segment Anything Model (SAM), which has 632M network parameters, or adopt SAM-free lightweight pipelines that sacrifice accuracy. To address this trade-off between performance and cost, we propose MLLMSeg, a novel framework that fully exploits the visual detail features already encoded by the MLLM vision encoder, without introducing an extra visual encoder. We further propose a detail-enhanced and semantic-consistent feature fusion module (DSFF) that fully integrates the detail-related visual features with the semantic-related features output by the large language model (LLM) of the MLLM. Finally, we establish a light-weight mask decoder with only 34M network parameters that optimally leverages detailed spatial features from the visual encoder and semantic features from the LLM to achieve precise mask prediction. Extensive experiments demonstrate that our method generally surpasses both SAM-based and SAM-free competitors, striking a better balance between performance and cost.

Code: Find the official implementation and full details on GitHub: https://github.com/jcwang0602/MLLMSeg
arXiv: http://arxiv.org/abs/2508.04107


Quick Start

This section provides instructions on how to run inference with our pre-trained models.

Notes: Our models accept images of any size as input. Model outputs are normalized to relative coordinates in the 0-1000 range (either a center point or a bounding box defined by its top-left and bottom-right corners). For visualization, remember to convert these relative coordinates back to the original image dimensions.
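
As a reference, here is a minimal conversion sketch for that rescaling step. The helper name scale_to_pixels and the sample values are illustrative only, not part of the released API:

def scale_to_pixels(coords, img_width, img_height):
    # coords: flat list of relative values in [0, 1000], alternating x, y
    # e.g. [x, y] for a center point or [x1, y1, x2, y2] for a bounding box
    pixel_coords = []
    for i, v in enumerate(coords):
        size = img_width if i % 2 == 0 else img_height
        pixel_coords.append(round(v / 1000 * size))
    return pixel_coords

# Example: a box predicted as [250, 100, 750, 400] on a 1920x1080 image
# maps to [480, 108, 1440, 432] in pixel coordinates.
print(scale_to_pixels([250, 100, 750, 400], 1920, 1080))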

Installation

First, create a conda environment and install PyTorch, the transformers library, and the other dependencies as specified by the official repository:

conda create -n mllmseg python==3.10.18 -y
conda activate mllmseg
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
# If you encounter any problems during the installation of datasets, please install this first.
# conda install -c conda-forge pyarrow
pip install -r requirements.txt
pip install flash-attn==2.3.6 --no-build-isolation # Note: a GPU is required to build flash-attn
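
Optionally, a quick sanity check confirms that PyTorch sees the GPU and that flash-attn was built correctly (assuming the environment created above):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import flash_attn; print('flash-attn import OK')"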

Inference Example

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import os

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

path = 'jcwang0602/MLLMSeg_InternVL2_5_8B_RES' # or 'jcwang0602/MLLMSeg_InternVL2_5_8B_GRES'
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# Example image. You need to download an example image, e.g., from the GitHub repo's assets:
# https://raw.githubusercontent.com/jcwang0602/MLLMSeg/main/examples/images/web_dfacd48d-d2c2-492f-b94c-41e6a34ea99f.png
# Save it to a local path like 'examples/images/web_dfacd48d-d2c2-492f-b94c-41e6a34ea99f.png'
# For demonstration, you might need to create dummy directories or replace the path.
image_path = './examples/images/web_dfacd48d-d2c2-492f-b94c-41e6a34ea99f.png'
if not os.path.exists(image_path):
    print(f"Warning: Image not found at {image_path}. Please download it from the GitHub repo.")
    # Fallback for demonstration if image is not present
    dummy_image_data = np.zeros((1024, 768, 3), dtype=np.uint8)
    dummy_image = Image.fromarray(dummy_image_data)
    pixel_values = build_transform(input_size=448)(dummy_image).unsqueeze(0).to(torch.bfloat16).cuda()
else:
    pixel_values = load_image(image_path, max_num=6).to(torch.bfloat16).cuda()

generation_config = dict(max_new_tokens=1024, do_sample=True)

question = "In the screenshot of this web page, please give me the coordinates of the element I want to click on according to my instructions(with point).\
\\\"'Champions League' link\\\""
response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')
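
To use the prediction, you can recover pixel coordinates from the response by extracting the integers and rescaling them from the 0-1000 relative range. This is only a sketch under the assumption that the response contains the coordinates as plain integers; the exact output format may differ, so see the GitHub repository for the canonical parsing code:

import re

# Assumption: the response reports coordinates as integers in the 0-1000 relative range.
orig_w, orig_h = Image.open(image_path).size if os.path.exists(image_path) else (768, 1024)
rel_coords = [int(v) for v in re.findall(r'\d+', response)]
pix_coords = [round(v / 1000 * (orig_w if i % 2 == 0 else orig_h)) for i, v in enumerate(rel_coords)]
print(f'Relative: {rel_coords} -> Pixel: {pix_coords}')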

Performance Metrics

The full results tables for Referring Expression Segmentation, Referring Expression Comprehension, and Generalized Referring Expression Segmentation are available in the GitHub repository.

Visualization

Qualitative visualizations for Referring Expression Segmentation, Referring Expression Comprehension, and Generalized Referring Expression Segmentation are available in the GitHub repository.

Citation

If our work is useful for your research, please consider citing:

@misc{wang2025unlockingpotentialmllmsreferring,
      title={Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder}, 
      author={Jingchao Wang and Zhijian Wu and Dingjiang Huang and Yefeng Zheng and Hong Wang},
      year={2025},
      eprint={2508.04107},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2508.04107}, 
}