# image_embed.py
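"""Two-pass caption-embedding pipeline built on a SigLIP2 text encoder.

Pass 1: each available GPU encodes its shard of the caption list and writes
raw embeddings to RAW_DIR as a pickle file.
Pass 2 (only when mean_shift is True): the global mean over all raw
embeddings is computed, subtracted from every embedding, and the shifted
embeddings are written to SHIFTED_DIR.
"""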
import os
import json
import pickle
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import torch.multiprocessing as mp
# Paths and settings
INPUT_JSON = "Pretrain.json"
mean_shift = True # Enable full-pass mean shifting
CKPT = "/root/autodl-tmp/model/siglip2"
BATCH_SIZE = 512
LOAD_LIMIT = None # Limit number of items, or None for all
# Output directories
RAW_DIR = "raw_embeds"
SHIFTED_DIR = "shifted_embeds"
# Create output directories if they don't exist
os.makedirs(RAW_DIR, exist_ok=True)
os.makedirs(SHIFTED_DIR, exist_ok=True)
# 1. Load data
with open(INPUT_JSON, "r", encoding="utf-8") as f:
    items = json.load(f)
if LOAD_LIMIT is not None:
    items = items[:LOAD_LIMIT]
# 2. Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(CKPT)
# 3. Split data among GPUs
num_gpus = torch.cuda.device_count()
assert num_gpus > 0, "No CUDA devices found; this script requires at least one GPU."
chunks = np.array_split(items, num_gpus)  # near-equal shards, one per GPU
# Function to compute raw embeddings (no shift)
def compute_raw_embeddings(device, data_chunk, gpu_id):
    device = torch.device(device)
    model = AutoModel.from_pretrained(CKPT).to(device).eval()
    results = []  # To store raw embeddings
    for i in tqdm(range(0, len(data_chunk), BATCH_SIZE), desc=f"Device {gpu_id} Raw Batches"):
        batch = data_chunk[i:i + BATCH_SIZE]
        ids = [it['id'] for it in batch]
        captions = [it.get('caption', '') for it in batch]
        inputs = tokenizer(
            captions,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            embs = model.get_text_features(**inputs)
        embs_np = embs.cpu().numpy()
        for idx, item_id in enumerate(ids):
            results.append({'id': item_id, 'embed': embs_np[idx]})
    # Save raw embeddings
    raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    with open(raw_file, 'wb') as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(results)} raw embeddings to {raw_file}")
# Function to apply mean shift and save final embeddings
def apply_mean_shift_and_save(global_mean, gpu_id):
    raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    out_file = os.path.join(SHIFTED_DIR, f"embeds_device_{gpu_id}.pkl")
    with open(raw_file, 'rb') as f:
        data = pickle.load(f)
    # Subtract global mean
    for item in data:
        item['embed'] = item['embed'] - global_mean
    # Save shifted embeddings
    with open(out_file, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(data)} shifted embeddings to {out_file}")
# Main entry
def main():
    # The CUDA runtime does not support the fork start method in subprocesses,
    # so force spawn before launching the per-GPU workers.
    mp.set_start_method("spawn", force=True)
    # 1st pass: compute raw embeddings in parallel
    procs = []
    for i in range(num_gpus):
        p = mp.Process(target=compute_raw_embeddings, args=(f"cuda:{i}", chunks[i], i))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

    if mean_shift:
        # Load all raw embeddings to compute the global mean
        all_embeds = []
        for i in range(num_gpus):
            raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{i}.pkl")
            with open(raw_file, 'rb') as f:
                data = pickle.load(f)
            all_embeds.extend([item['embed'] for item in data])
        all_embeds = np.stack(all_embeds, axis=0)
        global_mean = np.mean(all_embeds, axis=0)
        print("Computed global mean of shape", global_mean.shape)
        # 2nd pass: subtract mean and save shifted embeddings
        for i in range(num_gpus):
            apply_mean_shift_and_save(global_mean, i)
if __name__ == "__main__":
    main()
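# Downstream usage (a minimal sketch; nothing below runs as part of this
# script, and the consumer-side names are assumptions): merge the per-device
# pickles into a single id -> embedding lookup.
#
#   import glob
#   lookup = {}
#   for path in sorted(glob.glob(os.path.join(SHIFTED_DIR, "embeds_device_*.pkl"))):
#       with open(path, "rb") as f:
#           for record in pickle.load(f):
#               lookup[record["id"]] = record["embed"]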