import torch from transformers import AutoProcessor, AutoModelForImageTextToText import base64 from PIL import Image import io import os class EndpointHandler: def __init__(self, path=""): # Le token est automatiquement disponible pour les endpoints privés/protégés token = os.getenv("HUGGING_FACE_HUB_TOKEN") # On spécifie quel modèle charger via une variable d'environnement # Si elle n'est pas définie, on prend une valeur par défaut model_id = os.getenv("MODEL_ID", "HuggingFaceTB/SmolVLM2-256M-Video-Instruct") self.device = "cuda" if torch.cuda.is_available() else "cpu" self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32 # Charger le processeur et le modèle DEPUIS L'ID DU MODÈLE ORIGINAL self.processor = AutoProcessor.from_pretrained(model_id, token=token) self.model = AutoModelForImageTextToText.from_pretrained( model_id, torch_dtype=self.dtype, token=token ).to(self.device) print(f"✅ Modèle {model_id} chargé avec succès sur {self.device}") print("✅ Modèle et processeur chargés avec succès sur le device:", self.device) def __call__(self, data: dict) -> dict: """ Cette fonction est appelée pour chaque requête API. `data` est le JSON envoyé dans la requête. """ # Extraire les entrées du JSON de la requête inputs = data.pop("inputs", data) parameters = data.pop("parameters", {}) # Le payload attendu est une liste de messages, comme dans notre API # Exemple: {"inputs": [{"role": "user", "content": [{"type": "text", "text": "prompt"}, {"type": "video", "data": "base64_string"}]}]} messages = inputs # Le modèle attend un chemin de fichier, nous devons donc décoder la vidéo base64 # et la sauvegarder temporairement. video_content = messages[0]['content'][1]['data'] video_data = base64.b64decode(video_content) # Sauvegarder le fichier vidéo temporairement temp_video_path = "/tmp/temp_video.mp4" with open(temp_video_path, "wb") as f: f.write(video_data) # Mettre à jour le message pour pointer vers le chemin du fichier messages[0]['content'][1] = {"type": "video", "path": temp_video_path} # Préparer les entrées pour le modèle inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(self.device, dtype=self.dtype) # Exécuter l'inférence with torch.no_grad(): generated_ids = self.model.generate(**inputs, **parameters) generated_texts = self.processor.batch_decode( generated_ids, skip_special_tokens=True, ) # Nettoyer le fichier temporaire os.remove(temp_video_path) # Retourner le résultat return {"generated_text": generated_texts[0]}