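"""Custom Hugging Face Inference Endpoints handler.

Loads the model with a tiered strategy (8-bit quantization, then FP16, then
FP32 on CPU) and serves chat-style ("messages") or single-turn
("inputs"/"prompt") generation requests.
"""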
from typing import Dict, List, Any
import torch
import json
import os
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
def __init__(self, path: str = ""):
"""
Initialize handler using CTransformers format for memory efficiency
"""
logger.info(f"Loading model from {path}")
try:
# Use CTransformers format for lower memory usage
ctransformers_path = os.path.join(path, "models", "ctransformers")
if not os.path.exists(ctransformers_path):
logger.warning(f"CTransformers path not found: {ctransformers_path}")
logger.info("Falling back to HuggingFace format")
ctransformers_path = path
logger.info(f"Using model path: {ctransformers_path}")
# Load components using the working handler approach
self.tokenizer = self._load_tokenizer(ctransformers_path)
self.model = self._load_model(ctransformers_path)
logger.info("Model and tokenizer loaded successfully")
except Exception as e:
logger.error(f"Failed to initialize: {str(e)}")
            raise
def _load_tokenizer(self, model_path: str):
"""Load tokenizer using AutoTokenizer"""
logger.info("Loading tokenizer...")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
use_fast=True,
)
# Ensure special tokens are set
if not hasattr(tokenizer, 'pad_token') or tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
logger.info("Tokenizer loaded successfully")
return tokenizer
def _load_model(self, model_path: str):
"""Load model using AutoModelForCausalLM with memory optimization"""
logger.info("Loading model with memory optimization...")
from transformers import AutoModelForCausalLM
# Check GPU availability
if torch.cuda.is_available():
logger.info(f"CUDA available: {torch.cuda.get_device_name()}")
logger.info(f"GPU memory total: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f}GB")
else:
logger.warning("CUDA not available, using CPU")
# Memory optimization settings
device_map = "auto" if torch.cuda.is_available() else None
gpu_mem = os.environ.get("GPU_MAX_MEM", "10GiB") # Conservative for 12GB limit
cpu_mem = os.environ.get("CPU_MAX_MEM", "24GiB")
max_memory = {0: gpu_mem, "cpu": cpu_mem} if torch.cuda.is_available() else None
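        # NOTE: max_memory maps device identifiers (GPU index 0, "cpu") to the budgets
        # accelerate uses when placing weights; the "10GiB"/"24GiB" defaults can be
        # overridden via the GPU_MAX_MEM and CPU_MAX_MEM environment variables above.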
# Offload folder for memory management
offload_folder = os.environ.get("OFFLOAD_FOLDER", "/tmp/hf-offload")
try:
os.makedirs(offload_folder, exist_ok=True)
except OSError:
offload_folder = "/tmp/hf-offload"
os.makedirs(offload_folder, exist_ok=True)
# Try to load with quantization first, fallback without if it fails
model = None
quantization_config = None
# Attempt 1: Try with 8-bit quantization (if bitsandbytes is available)
if torch.cuda.is_available():
try:
# Check if bitsandbytes is available
import bitsandbytes
from transformers import BitsAndBytesConfig
logger.info("bitsandbytes available, attempting 8-bit quantization...")
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
device_map=device_map,
quantization_config=bnb_config,
low_cpu_mem_usage=True,
offload_folder=offload_folder,
max_memory=max_memory,
)
logger.info("Successfully loaded with 8-bit quantization")
quantization_config = "8-bit"
except ImportError as e:
logger.info(f"bitsandbytes not available ({str(e)}), skipping quantization...")
model = None
except Exception as e:
logger.warning(f"8-bit quantization failed: {str(e)}")
logger.info("Falling back to FP16 without quantization...")
model = None
# Attempt 2: Fallback to FP16 without quantization
if model is None:
try:
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
device_map=device_map,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
quantization_config=None, # Disable model's built-in quantization
low_cpu_mem_usage=True,
offload_folder=offload_folder if device_map == "auto" else None,
max_memory=max_memory,
)
logger.info("Successfully loaded with FP16 (no quantization)")
quantization_config = "fp16"
except Exception as e:
logger.warning(f"FP16 loading failed: {str(e)}")
logger.info("Falling back to FP32 CPU loading...")
model = None
# Attempt 3: Final fallback to CPU FP32
if model is None:
try:
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.float32,
quantization_config=None, # Disable model's built-in quantization
low_cpu_mem_usage=True,
)
logger.info("Successfully loaded with FP32 on CPU")
quantization_config = "fp32_cpu"
except Exception as e:
logger.error(f"All loading attempts failed: {str(e)}")
                raise
if model is None:
raise RuntimeError("Failed to load model with any configuration")
model.eval()
        # Determine the usable context window; tokenizer.model_max_length can be a
        # very large sentinel value when unset, so clamp to a sane default in that case.
        self.max_context = getattr(model.config, "max_position_embeddings", None) or getattr(self.tokenizer, "model_max_length", 4096)
        if self.max_context is None or self.max_context >= int(1e30):
            self.max_context = 4096
# Set token IDs
self.pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
self.eos_token_id = self.tokenizer.eos_token_id
logger.info(f"Model loaded successfully with {quantization_config} configuration")
return model
def _build_prompt(self, data: Dict[str, Any]) -> str:
"""Build prompt using chat template or direct input"""
# Accept either "messages" (chat) or "inputs"/"prompt" (single-turn)
if "messages" in data and isinstance(data["messages"], list):
return self.tokenizer.apply_chat_template(
data["messages"],
tokenize=False,
add_generation_prompt=True
)
user_text = data.get("inputs") or data.get("prompt") or ""
if isinstance(user_text, str):
messages = [{"role": "user", "content": user_text}]
return self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
return str(user_text)
def _prepare_inputs(self, prompt: str, max_new_tokens: int, params: Dict[str, Any]) -> Dict[str, torch.Tensor]:
"""Prepare inputs with proper tokenization"""
# Keep room for generation
max_input_tokens = int(params.get("max_input_tokens", max(self.max_context - max_new_tokens - 8, 256)))
model_inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=max_input_tokens,
)
if torch.cuda.is_available():
model_inputs = {k: v.to(self.model.device) for k, v in model_inputs.items()}
return model_inputs
def _stopping(self, params: Dict[str, Any]):
"""Create stopping criteria"""
from transformers import StoppingCriteria, StoppingCriteriaList
class StopOnSequences(StoppingCriteria):
def __init__(self, stop_sequences: List[List[int]]):
super().__init__()
self.stop_sequences = [torch.tensor(x, dtype=torch.long) for x in stop_sequences if len(x) > 0]
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
if input_ids.shape[0] == 0 or not self.stop_sequences:
return False
generated = input_ids[0]
for seq in self.stop_sequences:
if generated.shape[0] >= seq.shape[0] and torch.equal(generated[-seq.shape[0]:], seq.to(generated.device)):
return True
return False
stop = params.get("stop", [])
if isinstance(stop, str):
stop = [stop]
if not isinstance(stop, list):
stop = []
stop_ids = [self.tokenizer.encode(s, add_special_tokens=False) for s in stop]
criteria = []
if stop_ids:
criteria.append(StopOnSequences(stop_ids))
return StoppingCriteriaList(criteria)
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Handle inference requests with proper error handling"""
try:
params = data.get("parameters", {}) or {}
# Set seed if provided
seed = params.get("seed")
if seed is not None:
try:
torch.manual_seed(int(seed))
except (ValueError, TypeError):
pass
# Generation parameters
max_new_tokens = int(params.get("max_new_tokens", 512))
temperature = float(params.get("temperature", 0.2))
top_p = float(params.get("top_p", 0.9))
top_k = int(params.get("top_k", 50))
repetition_penalty = float(params.get("repetition_penalty", 1.05))
num_beams = int(params.get("num_beams", 1))
do_sample = bool(params.get("do_sample", temperature > 0 and num_beams == 1))
# Build prompt
prompt = self._build_prompt(data)
model_inputs = self._prepare_inputs(prompt, max_new_tokens, params)
input_length = model_inputs["input_ids"].shape[-1]
# Generation kwargs
gen_kwargs = dict(
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=max(0.0, temperature),
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
num_beams=num_beams,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
stopping_criteria=self._stopping(params),
)
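            # Note: when do_sample is False (e.g. temperature set to 0), generate()
            # falls back to greedy or beam search and the sampling knobs above
            # (temperature, top_p, top_k) are ignored.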
# Generate
with torch.no_grad():
output_ids = self.model.generate(**model_inputs, **gen_kwargs)
# Slice off the prompt
gen_ids = output_ids[0][input_length:]
text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
# Apply text-side stop strings if provided
stop = params.get("stop", [])
if isinstance(stop, str):
stop = [stop]
for s in stop or []:
idx = text.find(s)
if idx != -1:
text = text[:idx]
break
# Token accounting
prompt_tokens = int(input_length)
completion_tokens = int(gen_ids.shape[-1])
total_tokens = prompt_tokens + completion_tokens
return {
"generated_text": text,
"input_tokens": prompt_tokens,
"generated_tokens": completion_tokens,
"total_tokens": total_tokens,
"params": {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
"num_beams": num_beams,
"do_sample": do_sample,
},
}
except Exception as e:
logger.error(f"Generation error: {str(e)}")
return {
"error": f"Generation failed: {str(e)}",
"generated_text": "",
"input_tokens": 0,
"generated_tokens": 0,
"total_tokens": 0
}
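
# ---------------------------------------------------------------------------
# Local smoke test: an illustrative sketch only, not part of the Inference
# Endpoints contract. It assumes the model weights are reachable via the
# hypothetical MODEL_DIR environment variable (defaulting to the current
# directory); adjust the path and payloads for your deployment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=os.environ.get("MODEL_DIR", "."))

    # Single-turn request using the "inputs" field.
    single_turn = {
        "inputs": "Explain what a context window is in one sentence.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.2},
    }
    print(json.dumps(handler(single_turn), indent=2))

    # Chat-style request using the "messages" field with a stop sequence.
    chat = {
        "messages": [{"role": "user", "content": "Write a haiku about GPUs."}],
        "parameters": {"max_new_tokens": 48, "stop": ["\n\n"]},
    }
    print(json.dumps(handler(chat), indent=2))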