smol-256 / handler.py
swarecito's picture
Upload folder using huggingface_hub
8dede70 verified
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
import base64
from PIL import Image
import io
import os
class EndpointHandler:
def __init__(self, path=""):
# Le token est automatiquement disponible pour les endpoints privés/protégés
token = os.getenv("HUGGING_FACE_HUB_TOKEN")
# On spécifie quel modèle charger via une variable d'environnement
# Si elle n'est pas définie, on prend une valeur par défaut
model_id = os.getenv("MODEL_ID", "HuggingFaceTB/SmolVLM2-256M-Video-Instruct")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
# Charger le processeur et le modèle DEPUIS L'ID DU MODÈLE ORIGINAL
self.processor = AutoProcessor.from_pretrained(model_id, token=token)
self.model = AutoModelForImageTextToText.from_pretrained(
model_id,
torch_dtype=self.dtype,
token=token
).to(self.device)
print(f"✅ Modèle {model_id} chargé avec succès sur {self.device}")
print("✅ Modèle et processeur chargés avec succès sur le device:", self.device)
def __call__(self, data: dict) -> dict:
"""
Cette fonction est appelée pour chaque requête API.
`data` est le JSON envoyé dans la requête.
"""
# Extraire les entrées du JSON de la requête
inputs = data.pop("inputs", data)
parameters = data.pop("parameters", {})
# Le payload attendu est une liste de messages, comme dans notre API
# Exemple: {"inputs": [{"role": "user", "content": [{"type": "text", "text": "prompt"}, {"type": "video", "data": "base64_string"}]}]}
messages = inputs
# Le modèle attend un chemin de fichier, nous devons donc décoder la vidéo base64
# et la sauvegarder temporairement.
video_content = messages[0]['content'][1]['data']
video_data = base64.b64decode(video_content)
# Sauvegarder le fichier vidéo temporairement
temp_video_path = "/tmp/temp_video.mp4"
with open(temp_video_path, "wb") as f:
f.write(video_data)
# Mettre à jour le message pour pointer vers le chemin du fichier
messages[0]['content'][1] = {"type": "video", "path": temp_video_path}
# Préparer les entrées pour le modèle
inputs = self.processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(self.device, dtype=self.dtype)
# Exécuter l'inférence
with torch.no_grad():
generated_ids = self.model.generate(**inputs, **parameters)
generated_texts = self.processor.batch_decode(
generated_ids,
skip_special_tokens=True,
)
# Nettoyer le fichier temporaire
os.remove(temp_video_path)
# Retourner le résultat
return {"generated_text": generated_texts[0]}