File size: 3,195 Bytes
0666409
 
 
 
 
 
 
 
 
 
8dede70
0666409
8dede70
 
 
0666409
 
 
 
8dede70
 
0666409
8dede70
0666409
 
 
 
8dede70
 
0666409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
import base64
from PIL import Image
import io
import os

class EndpointHandler:
    def __init__(self, path=""):
        # Le token est automatiquement disponible pour les endpoints privés/protégés
        token = os.getenv("HUGGING_FACE_HUB_TOKEN")
        # On spécifie quel modèle charger via une variable d'environnement
        # Si elle n'est pas définie, on prend une valeur par défaut
        model_id = os.getenv("MODEL_ID", "HuggingFaceTB/SmolVLM2-256M-Video-Instruct")
        
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        # Charger le processeur et le modèle DEPUIS L'ID DU MODÈLE ORIGINAL
        self.processor = AutoProcessor.from_pretrained(model_id, token=token)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            torch_dtype=self.dtype,
            token=token
        ).to(self.device)

        print(f"✅ Modèle {model_id} chargé avec succès sur {self.device}")

        print("✅ Modèle et processeur chargés avec succès sur le device:", self.device)

    def __call__(self, data: dict) -> dict:
        """
        Cette fonction est appelée pour chaque requête API.
        `data` est le JSON envoyé dans la requête.
        """
        # Extraire les entrées du JSON de la requête
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        
        # Le payload attendu est une liste de messages, comme dans notre API
        # Exemple: {"inputs": [{"role": "user", "content": [{"type": "text", "text": "prompt"}, {"type": "video", "data": "base64_string"}]}]}
        
        messages = inputs
        
        # Le modèle attend un chemin de fichier, nous devons donc décoder la vidéo base64
        # et la sauvegarder temporairement.
        video_content = messages[0]['content'][1]['data']
        video_data = base64.b64decode(video_content)
        
        # Sauvegarder le fichier vidéo temporairement
        temp_video_path = "/tmp/temp_video.mp4"
        with open(temp_video_path, "wb") as f:
            f.write(video_data)
            
        # Mettre à jour le message pour pointer vers le chemin du fichier
        messages[0]['content'][1] = {"type": "video", "path": temp_video_path}

        # Préparer les entrées pour le modèle
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.device, dtype=self.dtype)

        # Exécuter l'inférence
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, **parameters)
            generated_texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )
        
        # Nettoyer le fichier temporaire
        os.remove(temp_video_path)

        # Retourner le résultat
        return {"generated_text": generated_texts[0]}