merve (HF Staff) committed 26e8458 (verified) · Parent(s): b1f181c

Upload gemma3n_fine_tuning_on_all_modalities.py

gemma3n_fine_tuning_on_all_modalities.py (added):
# -*- coding: utf-8 -*-
"""Gemma3n Fine-tuning on All Modalities.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1iEZUJuvKJpGU8t50BqfkiCQmGkaR6gd4

# Fine-tune Gemma3n on FineVideo

In this notebook, we will see how to fine-tune Gemma3n on videos that contain audio.
Using all three modalities is very costly compute-wise, so keep in mind that this is an educational tutorial that fits the model into 40GB of VRAM.
"""

!pip install -U -q timm transformers trl peft datasets

import io
import os
import zipfile

import torch
from datasets import load_dataset
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration

from trl import (
    SFTConfig,
    SFTTrainer,
)

"""## Download videos and preprocessing

FineVideo is quite a large dataset and we don't need a ton of examples, so we stream the dataset, check each sample's duration, and only download the videos shorter than 30 seconds.
"""

from datasets import load_dataset
import json
import os

dataset = load_dataset("HuggingFaceFV/finevideo", split="train", streaming=True)


os.makedirs("videos", exist_ok=True)
os.makedirs("metadata", exist_ok=True)

for idx, sample in enumerate(dataset):
    data = sample["json"]
    duration = data.get("duration_seconds", 0)
    if duration < 30:
        video_filename = f"videos/sample_{idx}.mp4"
        with open(video_filename, 'wb') as video_file:
            video_file.write(sample['mp4'])

        json_filename = f"metadata/sample_{idx}.json"
        with open(json_filename, 'w') as json_file:
            json.dump(sample['json'], json_file)

print(f"Number of items in content/videos: {len(os.listdir('videos'))}")

"""Some frames in FineVideo are dark, so we downsample each video to a handful of frames and remove the videos where we can't extract enough meaningful frames."""

import cv2
from PIL import Image
import numpy as np

def is_dark(frame, threshold=10):
    return np.max(frame) < threshold  # all pixels are very close to 0

def downsample_video(video_path):
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)

    frames = []

    # Generate 8 evenly spaced indices and drop the first and last (6 target frames)
    full_indices = np.linspace(0, total_frames - 1, 8, dtype=int)[1:-1]

    for i in full_indices:
        found_valid = False
        for offset in [0, -1, 1, -2, 2]:  # Try nearby frames if the original is dark
            candidate_idx = i + offset
            if 0 <= candidate_idx < total_frames:
                vidcap.set(cv2.CAP_PROP_POS_FRAMES, candidate_idx)
                success, image = vidcap.read()
                if success:
                    if not is_dark(image):
                        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                        pil_image = Image.fromarray(image)
                        timestamp = round(candidate_idx / fps, 2)
                        frames.append((pil_image, timestamp))
                        found_valid = True
                        break
        if not found_valid:
            print(f"Warning: Could not find non-dark frame near index {i}")

    # If we still have fewer than 6 frames, try to top off by scanning from the start
    if len(frames) < 6:
        print("Trying to top off with additional non-dark frames...")
        idx = 0
        while len(frames) < 8 and idx < total_frames:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            success, image = vidcap.read()
            if success and not is_dark(image):
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(image)
                timestamp = round(idx / fps, 2)
                # Avoid adding duplicate timestamps
                if not any(ts == timestamp for _, ts in frames):
                    frames.append((pil_image, timestamp))
            idx += 1

    vidcap.release()  # release only after the top-off pass, which still reads from vidcap

    return frames[:8]  # Cap at 8 frames

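"""To see what the helper returns, we can run it on one of the downloaded clips (an optional check, not part of the original notebook); each entry is a (PIL image, timestamp in seconds) pair."""

sample_frames = downsample_video(os.path.join("videos", os.listdir("videos")[0]))
print([timestamp for _, timestamp in sample_frames])
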
import os
import glob

def remove_dark_videos(video_dir, metadata_dir, audio_dir):
    """
    Remove videos (and their metadata/audio files) if all frames are dark.
    """
    video_paths = glob.glob(os.path.join(video_dir, "*.mp4"))

    for video_path in video_paths:
        filename = os.path.basename(video_path)
        base_name = os.path.splitext(filename)[0]

        frames = downsample_video(video_path)
        if len(frames) < 6:
            try:
                os.remove(video_path)
                print(f"Deleted: {video_path}")
            except Exception as e:
                print(f"Failed to delete {video_path}: {e}")

            metadata_path = os.path.join(metadata_dir, f"{base_name}.json")
            if os.path.exists(metadata_path):
                os.remove(metadata_path)

            # Remove audio
            audio_path = os.path.join(audio_dir, f"{base_name}.wav")
            if os.path.exists(audio_path):
                os.remove(audio_path)

remove_dark_videos(
    video_dir="videos",
    metadata_dir="metadata",
    audio_dir="audios"
)

"""Gemma-3n accepts video (image frames) and audio separately, so we strip the audio track from each video."""

import os
import subprocess

video_dir = "videos"
audio_dir = "audios"
os.makedirs(audio_dir, exist_ok=True)

for filename in os.listdir(video_dir):
    if not filename.endswith(".mp4"):
        continue

    idx = filename.split("_")[1].split(".")[0]
    video_path = os.path.join(video_dir, filename)
    audio_path = os.path.join(audio_dir, f"sample_{idx}.wav")

    subprocess.run([
        "ffmpeg", "-i", video_path,
        "-q:a", "0", "-map", "a",
        audio_path,
        "-y"
    ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

"""Construct a new dataset with video, audio, and metadata (video categories). FineVideo is a very rich dataset: it also contains question-answer pairs, captions and more, so get creative if you have the GPU VRAM for it. Here we solve an easier task for educational purposes."""

from datasets import Dataset
import json

def gen():
    meta_dir = "metadata"
    for filename in os.listdir(meta_dir):
        if not filename.endswith(".json"):
            continue

        idx = filename.split("_")[1].split(".")[0]
        if os.path.exists(f"videos/sample_{idx}.mp4"):
            video_filename = f"sample_{idx}.mp4"
            audio_filename = f"sample_{idx}.wav"
            json_path = os.path.join(meta_dir, filename)

            with open(json_path, "r") as f:
                metadata = json.load(f)

            yield {
                "video": video_filename,
                "audio": audio_filename,
                "content_parent_category": metadata["content_parent_category"],
                "sample_index": int(idx)
            }

dataset = Dataset.from_generator(gen)

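"""Peeking at one row is a quick way to verify that the generator matched video, audio and category correctly (an optional check, not in the original notebook)."""

print(dataset[0])
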
"""We speed up, downsample and truncate the audios to save memory during training."""

import torchaudio
from torchaudio.transforms import Resample
import os
import torch

def preprocess_audio(audio_path, target_sample_rate=16000, max_duration_sec=5, speedup_factor=1.25):
    waveform, sample_rate = torchaudio.load(audio_path)

    # Convert to mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample to the target sample rate
    if sample_rate != target_sample_rate:
        resampler = Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)
        sample_rate = target_sample_rate

    # Naive speed-up by dropping samples
    if speedup_factor > 1.0:
        indices = torch.arange(0, waveform.shape[1], step=speedup_factor).long()
        if indices[-1] >= waveform.shape[1]:
            indices = indices[:-1]
        waveform = waveform[:, indices]

    # Truncate to the maximum duration
    max_length = int(target_sample_rate * max_duration_sec)
    if waveform.shape[1] > max_length:
        waveform = waveform[:, :max_length]

    torchaudio.save(audio_path, waveform, sample_rate)

for file_name in os.listdir("audios"):
    if file_name.lower().endswith(".wav"):
        audio_path = os.path.join("audios", file_name)
        preprocess_audio(audio_path)

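"""As a quick sanity check (a small sketch added here, not in the original notebook), we can confirm that a processed clip is mono, 16 kHz, and at most about 5 seconds long."""

example_wav = os.path.join("audios", os.listdir("audios")[0])
waveform_check, sr_check = torchaudio.load(example_wav)
print(waveform_check.shape, sr_check, round(waveform_check.shape[1] / sr_check, 2), "seconds")
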
dataset = dataset.train_test_split(test_size=0.10, seed=42)

"""### Load the model

Make sure you have your Hugging Face token in your Colab secrets.
"""

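"""If you are not running this in Colab with the token stored as a secret, you can log in programmatically instead (an optional step, not part of the original notebook); `login()` is the standard `huggingface_hub` helper and will prompt for a token."""

from huggingface_hub import login

login()  # needed to access the gated Gemma weights and to push to the Hub later
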
model = Gemma3nForConditionalGeneration.from_pretrained(
    "google/gemma-3n-E2B-it", torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(
    "google/gemma-3n-E2B-it",
)
processor.tokenizer.padding_side = "right"

# Inspect the special token ids; we will mask them out of the labels in the collator
processor.tokenizer.all_special_ids

"""Write our dataset collator. We will train the model to predict the category of a video, which is a relatively easy task. You can do much more: FineVideo also has a QnA section, so you could train this model for open-ended QnA if you have a lot of VRAM and patience. Open-ended tasks are harder to work with, and this notebook is meant as an educational example of feeding different modalities.

In the collator we also downsample each video to 6 frames with the helper we wrote above. For better results you need more frames.
"""

def collate_fn(examples):
    video_path = examples[0]["video"]
    audio_path = examples[0]["audio"]
    sample_idx = video_path.split("_")[1].split(".")[0]
    frames = downsample_video(f"videos/{video_path}")

    text = "Based on the video, predict the category of it."
    message = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text}
            ],
        },
    ]
    # this is how video inference should be formatted in Gemma3n:
    # interleave "Frame <timestamp>:" texts with the corresponding frame images
    for frame in frames:
        image, timestamp = frame
        message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
        timestamp = str(timestamp).replace(".", "_")
        image.save(f"image_idx_{sample_idx}_{timestamp}.png")
        message[0]["content"].append({"type": "image", "url": f"image_idx_{sample_idx}_{timestamp}.png"})

    message[0]["content"].append({"type": "audio", "audio": f"audios/{audio_path}"})
    message.append({"role": "assistant", "content": [{"type": "text", "text": examples[0]["content_parent_category"]}]})
    inputs = processor.apply_chat_template(
        message,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    # Mask special tokens (image/audio placeholders, padding, etc.) out of the labels
    labels = inputs["input_ids"].clone()
    special_token_ids = processor.tokenizer.all_special_ids

    special_token_ids_tensor = torch.tensor(special_token_ids, device=labels.device)
    mask = torch.isin(labels, special_token_ids_tensor)
    labels[mask] = -100

    inputs["labels"] = labels
    if torch.all(inputs["pixel_values"] == 0):
        print("Frames are dark")

    return inputs

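"""A quick way to sanity-check the collator (an optional sketch, not part of the original notebook) is to run it on a single training example and inspect the tensors it produces."""

sample_batch = collate_fn([dataset["train"][0]])
print({k: v.shape for k, v in sample_batch.items() if hasattr(v, "shape")})
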
"""## Training

We do LoRA fine-tuning to save memory.
"""

from peft import LoraConfig
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    target_modules="all-linear",
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    use_rslora=False,
    use_dora=False,
    modules_to_save=None
)

model.gradient_checkpointing_disable()

model.config.use_cache = False

training_args = SFTConfig(
    output_dir="/content/gemma-3n-finevideo",
    eval_strategy='epoch',
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=False,
    learning_rate=1e-05,
    num_train_epochs=3.0,
    logging_steps=10,
    save_steps=100,
    bf16=True,
    report_to=["tensorboard"],
    dataset_kwargs={'skip_prepare_dataset': True},
    remove_unused_columns=False,
    max_seq_length=None,
    push_to_hub=True,
    dataloader_pin_memory=False,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"] if training_args.eval_strategy != "no" else None,
    processing_class=processor.tokenizer,
    peft_config=peft_config,
)

trainer.train()

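"""Because `push_to_hub=True` is set, the trainer pushes checkpoints to the Hub as it saves them; you can also save the LoRA adapter to a local directory explicitly (an optional step, sketched here with an arbitrary path)."""

trainer.save_model("/content/gemma-3n-finevideo-adapter")
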
"""Test the model with a video of snowboarding."""

!wget https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4

model = trainer.model  # the trainer's model carries the LoRA adapter

"""Strip the audio and downsample the video, just like we did for training."""

audio_path = "/content/test_audio.wav"
subprocess.run([
    "ffmpeg", "-i", "/content/IMG_8137.mp4",
    "-q:a", "0", "-map", "a",
    f"{audio_path}",
    "-y"
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

frames = downsample_video("/content/IMG_8137.mp4")

# repeat the chat template used during training
text = "Based on the video, predict the category of it."
message = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": text}
        ],
    },
]
for frame in frames:
    image, timestamp = frame
    message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
    timestamp = str(timestamp).replace(".", "_")
    image.save(f"test_frame_{timestamp}.png")
    message[0]["content"].append({"type": "image", "url": f"test_frame_{timestamp}.png"})

message[0]["content"].append({"type": "audio", "audio": f"{audio_path}"})

message  # display the constructed message (notebook cell output)

inputs = processor.apply_chat_template(
    message,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    padding=True,
).to(model.device).to(model.dtype)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)

"""Thanks a lot for reading! Keep training the model further with more data or unfreeze the layers for better performance 💗"""