from typing import Dict, List, Any

import torch
import json
import os
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize handler using CTransformers format for memory efficiency
        """
        logger.info(f"Loading model from {path}")

        try:
            # Use CTransformers format for lower memory usage
            ctransformers_path = os.path.join(path, "models", "ctransformers")

            if not os.path.exists(ctransformers_path):
                logger.warning(f"CTransformers path not found: {ctransformers_path}")
                logger.info("Falling back to HuggingFace format")
                ctransformers_path = path

            logger.info(f"Using model path: {ctransformers_path}")

            # Load components using the working handler approach
            self.tokenizer = self._load_tokenizer(ctransformers_path)
            self.model = self._load_model(ctransformers_path)

            logger.info("Model and tokenizer loaded successfully")

        except Exception as e:
            logger.error(f"Failed to initialize: {str(e)}")
            raise e

    def _load_tokenizer(self, model_path: str):
        """Load tokenizer using AutoTokenizer"""
        logger.info("Loading tokenizer...")

        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            use_fast=True,
        )

        # Ensure special tokens are set
        if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

        logger.info("Tokenizer loaded successfully")
        return tokenizer

    def _load_model(self, model_path: str):
        """Load model using AutoModelForCausalLM with memory optimization"""
        logger.info("Loading model with memory optimization...")

        from transformers import AutoModelForCausalLM

        # Check GPU availability
        if torch.cuda.is_available():
            logger.info(f"CUDA available: {torch.cuda.get_device_name()}")
            logger.info(f"GPU memory total: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f}GB")
        else:
            logger.warning("CUDA not available, using CPU")

        # Memory optimization settings
        device_map = "auto" if torch.cuda.is_available() else None
        gpu_mem = os.environ.get("GPU_MAX_MEM", "10GiB")  # Conservative for 12GB limit
        cpu_mem = os.environ.get("CPU_MAX_MEM", "24GiB")
        max_memory = {0: gpu_mem, "cpu": cpu_mem} if torch.cuda.is_available() else None

        # Offload folder for memory management
        offload_folder = os.environ.get("OFFLOAD_FOLDER", "/tmp/hf-offload")
        try:
            os.makedirs(offload_folder, exist_ok=True)
        except OSError:
            offload_folder = "/tmp/hf-offload"
            os.makedirs(offload_folder, exist_ok=True)

        # Try to load with quantization first, fall back to unquantized loading if that fails
        model = None
        quantization_config = None

        # Attempt 1: 8-bit quantization (if bitsandbytes is available)
        if torch.cuda.is_available():
            try:
                # Check if bitsandbytes is available
                import bitsandbytes
                from transformers import BitsAndBytesConfig

                logger.info("bitsandbytes available, attempting 8-bit quantization...")
                bnb_config = BitsAndBytesConfig(load_in_8bit=True)

                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    device_map=device_map,
                    quantization_config=bnb_config,
                    low_cpu_mem_usage=True,
                    offload_folder=offload_folder,
                    max_memory=max_memory,
                )
                logger.info("Successfully loaded with 8-bit quantization")
                quantization_config = "8-bit"

            except ImportError as e:
                logger.info(f"bitsandbytes not available ({str(e)}), skipping quantization...")
                model = None
            except Exception as e:
                logger.warning(f"8-bit quantization failed: {str(e)}")
                logger.info("Falling back to FP16 without quantization...")
                model = None

        # Attempt 2: Fallback to FP16 without quantization
        if model is None:
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    device_map=device_map,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    quantization_config=None,  # Disable model's built-in quantization
                    low_cpu_mem_usage=True,
                    offload_folder=offload_folder if device_map == "auto" else None,
                    max_memory=max_memory,
                )
                logger.info("Successfully loaded with FP16 (no quantization)")
                quantization_config = "fp16"

            except Exception as e:
                logger.warning(f"FP16 loading failed: {str(e)}")
                logger.info("Falling back to FP32 CPU loading...")
                model = None

        # Attempt 3: Final fallback to CPU FP32
        if model is None:
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.float32,
                    quantization_config=None,  # Disable model's built-in quantization
                    low_cpu_mem_usage=True,
                )
                logger.info("Successfully loaded with FP32 on CPU")
                quantization_config = "fp32_cpu"

            except Exception as e:
                logger.error(f"All loading attempts failed: {str(e)}")
                raise e

        if model is None:
            raise RuntimeError("Failed to load model with any configuration")

        model.eval()

        # Set context window
        self.max_context = getattr(model.config, "max_position_embeddings", None) or getattr(self.tokenizer, "model_max_length", 4096)
        if self.max_context is None or self.max_context == int(1e30):
            self.max_context = 4096

        # Set token IDs
        self.pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

        logger.info(f"Model loaded successfully with {quantization_config} configuration")
        return model

    def _build_prompt(self, data: Dict[str, Any]) -> str:
        """Build prompt using chat template or direct input"""
        # Accept either "messages" (chat) or "inputs"/"prompt" (single-turn)
        if "messages" in data and isinstance(data["messages"], list):
            return self.tokenizer.apply_chat_template(
                data["messages"], tokenize=False, add_generation_prompt=True
            )

        user_text = data.get("inputs") or data.get("prompt") or ""
        if isinstance(user_text, str):
            messages = [{"role": "user", "content": user_text}]
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        return str(user_text)

    def _prepare_inputs(self, prompt: str, max_new_tokens: int, params: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Prepare inputs with proper tokenization"""
        # Keep room for generation
        max_input_tokens = int(params.get("max_input_tokens", max(self.max_context - max_new_tokens - 8, 256)))

        model_inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_tokens,
        )

        if torch.cuda.is_available():
            model_inputs = {k: v.to(self.model.device) for k, v in model_inputs.items()}

        return model_inputs

    def _stopping(self, params: Dict[str, Any]):
        """Create stopping criteria"""
        from transformers import StoppingCriteria, StoppingCriteriaList

        class StopOnSequences(StoppingCriteria):
            def __init__(self, stop_sequences: List[List[int]]):
                super().__init__()
                self.stop_sequences = [torch.tensor(x, dtype=torch.long) for x in stop_sequences if len(x) > 0]

            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                if input_ids.shape[0] == 0 or not self.stop_sequences:
                    return False
                generated = input_ids[0]
                for seq in self.stop_sequences:
                    if generated.shape[0] >= seq.shape[0] and torch.equal(generated[-seq.shape[0]:], seq.to(generated.device)):
                        return True
                return False

        stop = params.get("stop", [])
        if isinstance(stop, str):
            stop = [stop]
        if not isinstance(stop, list):
            stop = []

        stop_ids = [self.tokenizer.encode(s, add_special_tokens=False) for s in stop]

        criteria = []
        if stop_ids:
            criteria.append(StopOnSequences(stop_ids))

        return StoppingCriteriaList(criteria)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle inference requests with proper error handling"""
        try:
            params = data.get("parameters", {}) or {}

            # Set seed if provided
            seed = params.get("seed")
            if seed is not None:
                try:
                    torch.manual_seed(int(seed))
                except (ValueError, TypeError):
                    pass

            # Generation parameters
            max_new_tokens = int(params.get("max_new_tokens", 512))
            temperature = float(params.get("temperature", 0.2))
            top_p = float(params.get("top_p", 0.9))
            top_k = int(params.get("top_k", 50))
            repetition_penalty = float(params.get("repetition_penalty", 1.05))
            num_beams = int(params.get("num_beams", 1))
            do_sample = bool(params.get("do_sample", temperature > 0 and num_beams == 1))

            # Build prompt
            prompt = self._build_prompt(data)
            model_inputs = self._prepare_inputs(prompt, max_new_tokens, params)
            input_length = model_inputs["input_ids"].shape[-1]

            # Generation kwargs
            gen_kwargs = dict(
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=max(0.0, temperature),
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                num_beams=num_beams,
                eos_token_id=self.eos_token_id,
                pad_token_id=self.pad_token_id,
                stopping_criteria=self._stopping(params),
            )

            # Generate
            with torch.no_grad():
                output_ids = self.model.generate(**model_inputs, **gen_kwargs)

            # Slice off the prompt
            gen_ids = output_ids[0][input_length:]
            text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)

            # Apply text-side stop strings if provided
            stop = params.get("stop", [])
            if isinstance(stop, str):
                stop = [stop]
            for s in stop or []:
                idx = text.find(s)
                if idx != -1:
                    text = text[:idx]
                    break

            # Token accounting
            prompt_tokens = int(input_length)
            completion_tokens = int(gen_ids.shape[-1])
            total_tokens = prompt_tokens + completion_tokens

            return {
                "generated_text": text,
                "input_tokens": prompt_tokens,
                "generated_tokens": completion_tokens,
                "total_tokens": total_tokens,
                "params": {
                    "max_new_tokens": max_new_tokens,
                    "temperature": temperature,
                    "top_p": top_p,
                    "top_k": top_k,
                    "repetition_penalty": repetition_penalty,
                    "num_beams": num_beams,
                    "do_sample": do_sample,
                },
            }

        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            return {
                "error": f"Generation failed: {str(e)}",
                "generated_text": "",
                "input_tokens": 0,
                "generated_tokens": 0,
                "total_tokens": 0,
            }
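
# Minimal local smoke test sketch, not part of the deployed endpoint contract.
# It assumes the model/tokenizer files are reachable from the given path
# (here "." as a placeholder) and that one of the loading fallbacks above
# succeeds on this machine; the sample prompt and parameters are illustrative.
# The payload shape mirrors what __call__ accepts: "inputs" or "messages"
# plus an optional "parameters" dict.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # placeholder path, adjust locally
    payload = {
        "inputs": "Write a haiku about inference endpoints.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.2, "stop": ["\n\n"]},
    }
    result = handler(payload)
    print(json.dumps(result, indent=2))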