Update README.md
Browse files
README.md
CHANGED
@@ -76,20 +76,22 @@ pip install transformers torch decord soundfile qwen_omni_utils
|
|
76 |
### Inference
|
77 |
|
78 |
```python
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
)
|
|
|
|
|
93 |
|
94 |
# Example video path (replace with your actual video file path)
|
95 |
video_path = "path/to/your/video.mp4"
|
@@ -116,30 +118,159 @@ messages = [
|
|
116 |
}
|
117 |
]
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
)
|
|
|
|
|
|
|
123 |
|
124 |
-
|
125 |
-
text=[text],
|
126 |
-
videos=[video_path], # Pass video path directly to the processor
|
127 |
-
return_tensors="pt",
|
128 |
-
).to(model.device)
|
129 |
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
-
# Decode generated text
|
134 |
-
output_text = processor.batch_decode(
|
135 |
-
generated_ids[:, inputs["input_ids"].shape[1]:], # Exclude prompt from output
|
136 |
-
skip_special_tokens=True,
|
137 |
-
clean_up_tokenization_spaces=False
|
138 |
-
)[0]
|
139 |
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
```
|
142 |
|
|
|
143 |
## Evaluation
|
144 |
|
145 |
### Final Caption prompt (for inference)
|
|
|
76 |
### Inference
|
77 |
|
78 |
```python
|
79 |
+
import soundfile as sf
|
80 |
+
|
81 |
+
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
|
82 |
+
from qwen_omni_utils import process_mm_info
|
83 |
+
|
84 |
+
model = Qwen2_5OmniForConditionalGeneration.from_pretrained("openinterx/UGC-VideoCaptioner", torch_dtype="auto", device_map="auto")
|
85 |
+
|
86 |
+
# We recommend enabling flash_attention_2 for better acceleration and memory saving.
|
87 |
+
# model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
88 |
+
# "Qwen/Qwen2.5-Omni-3B",
|
89 |
+
# torch_dtype="auto",
|
90 |
+
# device_map="auto",
|
91 |
+
# attn_implementation="flash_attention_2",
|
92 |
+
# )
|
93 |
+
|
94 |
+
processor = Qwen2_5OmniProcessor.from_pretrained("openinterx/UGC-VideoCaptioner")
|
95 |
|
96 |
# Example video path (replace with your actual video file path)
|
97 |
video_path = "path/to/your/video.mp4"
|
|
|
118 |
}
|
119 |
]
|
120 |
|
121 |
+
|
122 |
+
# set use audio in video
|
123 |
+
USE_AUDIO_IN_VIDEO = True
|
124 |
+
|
125 |
+
# Preparation for inference
|
126 |
+
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
|
127 |
+
audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)
|
128 |
+
inputs = processor(text=text, audio=audios, images=images, videos=videos, return_tensors="pt", padding=True, use_audio_in_video=USE_AUDIO_IN_VIDEO)
|
129 |
+
inputs = inputs.to(model.device).to(model.dtype)
|
130 |
+
|
131 |
+
# Inference: Generation of the output text and audio
|
132 |
+
text_ids, audio = model.generate(**inputs, use_audio_in_video=USE_AUDIO_IN_VIDEO)
|
133 |
+
|
134 |
+
text = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
135 |
+
print(text)
|
136 |
+
sf.write(
|
137 |
+
"output.wav",
|
138 |
+
audio.reshape(-1).detach().cpu().numpy(),
|
139 |
+
samplerate=24000,
|
140 |
)
|
141 |
+
```
|
142 |
+
|
143 |
+
|
144 |
|
145 |
+
### vllm inference
|
|
|
|
|
|
|
|
|
146 |
|
147 |
+
```python
|
148 |
+
# pip install vllm
|
149 |
+
# pip install transformers==4.52.3
|
150 |
+
|
151 |
+
|
152 |
+
import os
|
153 |
+
import json
|
154 |
+
import re
|
155 |
+
from tqdm import tqdm
|
156 |
+
from vllm import LLM, SamplingParams
|
157 |
+
from vllm.assets.video import VideoAsset
|
158 |
+
from vllm.utils import FlexibleArgumentParser
|
159 |
+
|
160 |
+
VIDEO_DIR = "/workspace/benchmark/video"
|
161 |
+
OUTPUT_JSONL = "/workspace/benchmark/omni_vllm_sft_result_same_parameter.jsonl"
|
162 |
+
USE_AUDIO_IN_VIDEO = True
|
163 |
+
MAX_RETRY = 3
|
164 |
+
|
165 |
+
# Ensure output file exists
|
166 |
+
def ensure_output_file(path: str):
|
167 |
+
if not os.path.exists(path):
|
168 |
+
open(path, "w", encoding="utf-8").close()
|
169 |
+
|
170 |
+
# Load processed video IDs to skip
|
171 |
+
|
172 |
+
def load_processed_ids(jsonl_path: str) -> set[str]:
|
173 |
+
processed = set()
|
174 |
+
with open(jsonl_path, "r", encoding="utf-8") as fin:
|
175 |
+
for line in fin:
|
176 |
+
try:
|
177 |
+
data = json.loads(line)
|
178 |
+
vid = data.get("video_id")
|
179 |
+
if vid:
|
180 |
+
processed.add(vid)
|
181 |
+
except json.JSONDecodeError:
|
182 |
+
continue
|
183 |
+
return processed
|
184 |
+
|
185 |
+
# Regex to verify level tag at end of caption
|
186 |
+
# 没有level
|
187 |
+
LEVEL_PATTERN = re.compile(r"<level>[A-F]</level>\s*$")
|
188 |
+
|
189 |
+
PROMPT_TEMPLATE = (
|
190 |
+
f"<|im_start|>system\n" +
|
191 |
+
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech. Please make sure that the content within the answer is long and detailed enough." +
|
192 |
+
"<|im_end|>\n"
|
193 |
+
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
|
194 |
+
"You are given a short video with both audio and visual content. Write a detailed and coherent paragraph that naturally integrates all modalities. "
|
195 |
+
"Your description should include: (1) the primary scene and background setting; (2) key characters or objects and their actions or interactions; "
|
196 |
+
"(3) significant audio cues such as voices, background music, sound effects, and their emotional tone; "
|
197 |
+
"(4) any on-screen text (OCR) and its role in the video context; and (5) the overall theme or purpose of the video. "
|
198 |
+
"Ensure the output is a fluent and objective paragraph, not a bullet-point list, and captures the video's content in a human-like, narrative style. <|im_end|>\n"
|
199 |
+
"<|im_start|>assistant\n"
|
200 |
+
)
|
201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
+
def process_video_folder(model_name: str, seed: int = None):
|
204 |
+
ensure_output_file(OUTPUT_JSONL)
|
205 |
+
video_files = sorted(f for f in os.listdir(VIDEO_DIR) if f.lower().endswith(".mp4"))
|
206 |
+
processed_ids = load_processed_ids(OUTPUT_JSONL)
|
207 |
+
|
208 |
+
llm = LLM(
|
209 |
+
model=model_name,
|
210 |
+
max_model_len=20000,
|
211 |
+
max_num_seqs=5,
|
212 |
+
limit_mm_per_prompt={"video": 1, "audio": 1},
|
213 |
+
seed=seed,
|
214 |
+
)
|
215 |
+
sampling_params = SamplingParams(temperature=0.2, max_tokens=1024)
|
216 |
+
|
217 |
+
with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout:
|
218 |
+
for fname in tqdm(video_files, desc="Processing videos"):
|
219 |
+
video_id = os.path.splitext(fname)[0]
|
220 |
+
if video_id in processed_ids:
|
221 |
+
print(f"[Skip] {fname} already processed, skipping.")
|
222 |
+
continue
|
223 |
+
|
224 |
+
fpath = os.path.join(VIDEO_DIR, fname)
|
225 |
+
valid_caption = None
|
226 |
+
try:
|
227 |
+
video_asset = VideoAsset(path=fpath, num_frames=32)
|
228 |
+
audio = video_asset.get_audio(sampling_rate=16000)
|
229 |
+
|
230 |
+
inputs = {
|
231 |
+
"prompt": PROMPT_TEMPLATE,
|
232 |
+
"multi_modal_data": {"video": video_asset.np_ndarrays, "audio": audio},
|
233 |
+
"mm_processor_kwargs": {"use_audio_in_video": USE_AUDIO_IN_VIDEO},
|
234 |
+
}
|
235 |
+
|
236 |
+
for attempt in range(MAX_RETRY):
|
237 |
+
outputs = llm.generate(inputs, sampling_params=sampling_params)
|
238 |
+
text = outputs[0].outputs[0].text.strip()
|
239 |
+
if text and LEVEL_PATTERN.search(text):
|
240 |
+
valid_caption = text
|
241 |
+
break
|
242 |
+
else:
|
243 |
+
print(f"[Retry] Attempt {attempt+1} for {fname} did not end with level tag, retrying...")
|
244 |
+
|
245 |
+
if not valid_caption:
|
246 |
+
print(f"[Warning] {fname} failed to get valid level tag after {MAX_RETRY} attempts, skipping.")
|
247 |
+
continue
|
248 |
+
|
249 |
+
fout.write(json.dumps({"video_id": video_id, "caption": valid_caption}, ensure_ascii=False) + "\n")
|
250 |
+
fout.flush()
|
251 |
+
processed_ids.add(video_id)
|
252 |
+
|
253 |
+
except Exception as e:
|
254 |
+
print(f"[Error] Failed to process {fname}: {e}")
|
255 |
+
continue
|
256 |
+
|
257 |
+
print(f"✅ Done! Processed videos with skipping and level validation. Output written to {OUTPUT_JSONL}")
|
258 |
+
|
259 |
+
|
260 |
+
def parse_args():
|
261 |
+
parser = FlexibleArgumentParser(description="Batch inference for a folder of videos using Qwen2.5-Omni + vLLM.")
|
262 |
+
parser.add_argument("--model-name", type=str, default="/workspace/output_model/tiktok_caption/omni_sft_20k_level/v1-20250701-150049/checkpoint-2404-merged", help="Model path or name.")
|
263 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.")
|
264 |
+
return parser.parse_args()
|
265 |
+
|
266 |
+
|
267 |
+
if __name__ == "__main__":
|
268 |
+
args = parse_args()
|
269 |
+
process_video_folder(args.model_name, args.seed)
|
270 |
+
|
271 |
```
|
272 |
|
273 |
+
|
274 |
## Evaluation
|
275 |
|
276 |
### Final Caption prompt (for inference)
|