from typing import Dict, List, Any
import torch
import json
import os
import glob
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize handler with robust file discovery
        """
        logger.info(f"Loading model from {path}")
        try:
            # Log directory contents to understand structure
            if os.path.exists(path):
                contents = os.listdir(path)
                logger.info(f"Repository contents: {contents}")
                # Look for model files in subdirectories
                for item in contents:
                    item_path = os.path.join(path, item)
                    if os.path.isdir(item_path):
                        sub_contents = os.listdir(item_path)
                        logger.info(f"Directory {item}: {sub_contents}")

            # Try to find the actual model path
            model_path = self._find_model_path(path)
            logger.info(f"Using model path: {model_path}")

            # Load tokenizer - try multiple approaches
            self.tokenizer = self._load_tokenizer(model_path, path)
            logger.info("Tokenizer loaded successfully")

            # Load model
            self.model = self._load_model(model_path, path)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to initialize: {str(e)}")
            raise e
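
    # How the model path is resolved (summary of _find_model_path below,
    # example layouts are illustrative):
    #   <repo>/config.json                    -> use <repo> directly
    #   <repo>/models/huggingface/config.json -> use that subdirectory
    #   anything else                         -> first directory found by os.walk
    #                                            that contains a config.json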
    def _find_model_path(self, base_path: str) -> str:
        """Find the actual path containing model files"""
        # Check if config.json is in base path
        if os.path.exists(os.path.join(base_path, "config.json")):
            return base_path

        # Check models/huggingface subdirectory
        hf_path = os.path.join(base_path, "models", "huggingface")
        if os.path.exists(hf_path) and os.path.exists(os.path.join(hf_path, "config.json")):
            return hf_path

        # Check for any subdirectory with config.json
        for root, dirs, files in os.walk(base_path):
            if "config.json" in files:
                return root

        # Fallback to base path
        return base_path
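
    # Loading strategy (summary of the two loaders below): try the resolved model
    # path first, then the repository root; the tokenizer additionally falls back
    # to the public Qwen/Qwen2-7B-Instruct tokenizer from the Hub as a last resort.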
    def _load_tokenizer(self, model_path: str, base_path: str):
        """Load tokenizer with fallback methods"""
        try:
            # Try direct loading from model path
            logger.info(f"Trying to load tokenizer from {model_path}")
            return AutoTokenizer.from_pretrained(
                model_path,
                trust_remote_code=True,
                local_files_only=True
            )
        except Exception as e1:
            logger.warning(f"Failed to load from {model_path}: {e1}")
            try:
                # Try loading from base path
                logger.info(f"Trying to load tokenizer from {base_path}")
                return AutoTokenizer.from_pretrained(
                    base_path,
                    trust_remote_code=True,
                    local_files_only=True
                )
            except Exception as e2:
                logger.warning(f"Failed to load from {base_path}: {e2}")
                try:
                    # Try loading from Hugging Face Hub as fallback
                    logger.info("Using fallback tokenizer from Qwen2-7B-Instruct")
                    tokenizer = AutoTokenizer.from_pretrained(
                        "Qwen/Qwen2-7B-Instruct",
                        trust_remote_code=True
                    )
                    # Set special tokens
                    tokenizer.pad_token = tokenizer.eos_token
                    return tokenizer
                except Exception as e3:
                    logger.error(f"All tokenizer loading methods failed: {e3}")
                    raise e3

    def _load_model(self, model_path: str, base_path: str):
        """Load model with fallback methods"""
        try:
            # Try direct loading from model path
            logger.info(f"Trying to load model from {model_path}")
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                local_files_only=True,
                low_cpu_mem_usage=True
            )
        except Exception as e1:
            logger.warning(f"Failed to load from {model_path}: {e1}")
            try:
                # Try loading from base path
                logger.info(f"Trying to load model from {base_path}")
                model = AutoModelForCausalLM.from_pretrained(
                    base_path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    local_files_only=True,
                    low_cpu_mem_usage=True
                )
            except Exception as e2:
                logger.error(f"Model loading failed from both paths: {e2}")
                raise e2

        model.eval()
        return model
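
    # Example request body this handler accepts (illustrative values only):
    #   {"inputs": "Explain beam search.", "parameters": {"max_new_tokens": 256, "temperature": 0.7}}
    # and the shape of the list it returns:
    #   [{"generated_text": "...", "generated_tokens": 123, "finish_reason": "eos_token"}]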
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests
        """
        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})
            if not inputs:
                return [{"error": "No input provided", "generated_text": ""}]

            # Generation parameters with safety limits
            max_new_tokens = min(parameters.get("max_new_tokens", 512), 1024)
            temperature = max(0.1, min(parameters.get("temperature", 0.7), 2.0))
            top_p = max(0.1, min(parameters.get("top_p", 0.9), 1.0))
            do_sample = parameters.get("do_sample", True)

            # Format input for Qwen chat template
            if inputs.startswith("<|im_start|>"):
                formatted_input = inputs
            else:
                formatted_input = f"<|im_start|>user\n{inputs}<|im_end|>\n<|im_start|>assistant\n"

            # Tokenize with error handling
            try:
                input_ids = self.tokenizer.encode(
                    formatted_input,
                    return_tensors="pt",
                    truncation=True,
                    max_length=3072
                )
            except Exception as e:
                logger.error(f"Tokenization failed: {e}")
                return [{"error": f"Tokenization failed: {str(e)}", "generated_text": ""}]

            if input_ids.size(1) == 0:
                return [{"error": "Empty input after tokenization", "generated_text": ""}]

            # Move to model device
            input_ids = input_ids.to(next(self.model.parameters()).device)

            # Generate with error handling
            try:
                with torch.no_grad():
                    outputs = self.model.generate(
                        input_ids,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        do_sample=do_sample,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                        use_cache=True,
                        num_return_sequences=1
                    )
            except Exception as e:
                logger.error(f"Generation failed: {e}")
                return [{"error": f"Generation failed: {str(e)}", "generated_text": ""}]

            # Decode response
            try:
                generated_ids = outputs[0][input_ids.size(1):]
                response = self.tokenizer.decode(
                    generated_ids,
                    skip_special_tokens=True
                ).strip()
                # Clean up response
                response = response.replace("<|im_end|>", "").strip()
                return [{
                    "generated_text": response,
                    "generated_tokens": len(generated_ids),
                    "finish_reason": "eos_token" if self.tokenizer.eos_token_id in generated_ids else "length"
                }]
            except Exception as e:
                logger.error(f"Decoding failed: {e}")
                return [{"error": f"Decoding failed: {str(e)}", "generated_text": ""}]
        except Exception as e:
            logger.error(f"Inference error: {str(e)}")
            return [{"error": f"Inference failed: {str(e)}", "generated_text": ""}]
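

# ------------------------------------------------------------------------------
# Local smoke test (illustrative only). A minimal sketch for exercising the
# handler outside of an Inference Endpoints deployment, assuming the model files
# are available in a local directory; the path, prompt, and parameter values
# below are placeholders, not part of the deployed handler.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # assumption: model files live in the current directory
    output = handler({
        "inputs": "Write a haiku about autumn.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.7},
    })
    print(output)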