import logging
import os
from typing import Dict, List, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize handler with robust file discovery
        """
        logger.info(f"Loading model from {path}")

        try:
            # Log directory contents to understand structure
            if os.path.exists(path):
                contents = os.listdir(path)
                logger.info(f"Repository contents: {contents}")

                # Look for model files in subdirectories
                for item in contents:
                    item_path = os.path.join(path, item)
                    if os.path.isdir(item_path):
                        sub_contents = os.listdir(item_path)
                        logger.info(f"Directory {item}: {sub_contents}")

            # Try to find the actual model path
            model_path = self._find_model_path(path)
            logger.info(f"Using model path: {model_path}")

            # Load tokenizer - try multiple approaches
            self.tokenizer = self._load_tokenizer(model_path, path)
            # Ensure a pad token exists; generate() is passed pad_token_id explicitly
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info("Tokenizer loaded successfully")

            # Load model
            self.model = self._load_model(model_path, path)
            logger.info("Model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to initialize: {str(e)}")
            raise e

    def _find_model_path(self, base_path: str) -> str:
        """Find the actual path containing model files"""
        # Check if config.json is in base path
        if os.path.exists(os.path.join(base_path, "config.json")):
            return base_path

        # Check models/huggingface subdirectory
        hf_path = os.path.join(base_path, "models", "huggingface")
        if os.path.exists(hf_path) and os.path.exists(os.path.join(hf_path, "config.json")):
            return hf_path

        # Check for any subdirectory with config.json
        for root, dirs, files in os.walk(base_path):
            if "config.json" in files:
                return root

        # Fallback to base path
        return base_path

    def _load_tokenizer(self, model_path: str, base_path: str):
        """Load tokenizer with fallback methods"""
        try:
            # Try direct loading from model path
            logger.info(f"Trying to load tokenizer from {model_path}")
            return AutoTokenizer.from_pretrained(
                model_path,
                trust_remote_code=True,
                local_files_only=True
            )
        except Exception as e1:
            logger.warning(f"Failed to load from {model_path}: {e1}")
            try:
                # Try loading from base path
                logger.info(f"Trying to load tokenizer from {base_path}")
                return AutoTokenizer.from_pretrained(
                    base_path,
                    trust_remote_code=True,
                    local_files_only=True
                )
            except Exception as e2:
                logger.warning(f"Failed to load from {base_path}: {e2}")
                try:
                    # Try loading from Hugging Face Hub as fallback
                    logger.info("Using fallback tokenizer from Qwen2-7B-Instruct")
                    tokenizer = AutoTokenizer.from_pretrained(
                        "Qwen/Qwen2-7B-Instruct",
                        trust_remote_code=True
                    )
                    # Set special tokens
                    tokenizer.pad_token = tokenizer.eos_token
                    return tokenizer
                except Exception as e3:
                    logger.error(f"All tokenizer loading methods failed: {e3}")
                    raise e3
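    # Note: _load_model (below) relies on bitsandbytes 8-bit quantization through the
    # load_in_8bit kwarg. On newer transformers releases that kwarg is deprecated in
    # favor of an explicit quantization config; an equivalent sketch (not verified
    # against this repository's pinned dependencies) would look like:
    #
    #     from transformers import BitsAndBytesConfig
    #     model = AutoModelForCausalLM.from_pretrained(
    #         model_path,
    #         quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    #         device_map="auto",
    #     )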
    def _load_model(self, model_path: str, base_path: str):
        """Load model with 8-bit quantization to fit memory limits"""
        try:
            # Try direct loading from model path with 8-bit quantization
            logger.info(f"Trying to load model from {model_path} with 8-bit quantization")
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                load_in_8bit=True,  # Use 8-bit quantization
                device_map="auto",
                trust_remote_code=True,
                local_files_only=True,
                low_cpu_mem_usage=True
            )
        except Exception as e1:
            logger.warning(f"Failed to load from {model_path}: {e1}")
            try:
                # Try loading from base path with 8-bit quantization
                logger.info(f"Trying to load model from {base_path} with 8-bit quantization")
                model = AutoModelForCausalLM.from_pretrained(
                    base_path,
                    load_in_8bit=True,  # Use 8-bit quantization
                    device_map="auto",
                    trust_remote_code=True,
                    local_files_only=True,
                    low_cpu_mem_usage=True
                )
            except Exception as e2:
                logger.error(f"Model loading failed from both paths: {e2}")
                raise e2

        model.eval()
        return model

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests
        """
        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            if not inputs:
                return [{"error": "No input provided", "generated_text": ""}]

            # Generation parameters with safety limits
            max_new_tokens = min(parameters.get("max_new_tokens", 512), 1024)
            temperature = max(0.1, min(parameters.get("temperature", 0.7), 2.0))
            top_p = max(0.1, min(parameters.get("top_p", 0.9), 1.0))
            do_sample = parameters.get("do_sample", True)

            # Format input for Qwen chat template
            if inputs.startswith("<|im_start|>"):
                formatted_input = inputs
            else:
                formatted_input = f"<|im_start|>user\n{inputs}<|im_end|>\n<|im_start|>assistant\n"

            # Tokenize with error handling
            try:
                input_ids = self.tokenizer.encode(
                    formatted_input,
                    return_tensors="pt",
                    truncation=True,
                    max_length=3072
                )
            except Exception as e:
                logger.error(f"Tokenization failed: {e}")
                return [{"error": f"Tokenization failed: {str(e)}", "generated_text": ""}]

            if input_ids.size(1) == 0:
                return [{"error": "Empty input after tokenization", "generated_text": ""}]

            # Move to model device
            input_ids = input_ids.to(next(self.model.parameters()).device)

            # Generate with error handling
            try:
                with torch.no_grad():
                    outputs = self.model.generate(
                        input_ids,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        do_sample=do_sample,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                        use_cache=True,
                        num_return_sequences=1
                    )
            except Exception as e:
                logger.error(f"Generation failed: {e}")
                return [{"error": f"Generation failed: {str(e)}", "generated_text": ""}]

            # Decode response
            try:
                generated_ids = outputs[0][input_ids.size(1):]
                response = self.tokenizer.decode(
                    generated_ids,
                    skip_special_tokens=True
                ).strip()

                # Clean up response
                response = response.replace("<|im_end|>", "").strip()

                return [{
                    "generated_text": response,
                    "generated_tokens": len(generated_ids),
                    "finish_reason": "eos_token" if self.tokenizer.eos_token_id in generated_ids else "length"
                }]
            except Exception as e:
                logger.error(f"Decoding failed: {e}")
                return [{"error": f"Decoding failed: {str(e)}", "generated_text": ""}]

        except Exception as e:
            logger.error(f"Inference error: {str(e)}")
            return [{"error": f"Inference failed: {str(e)}", "generated_text": ""}]
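
# A minimal local smoke test, not part of the Inference Endpoints contract.
# It assumes model weights are already available under the hypothetical
# directory "./model"; adjust the path and payload to your setup before running.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    sample_request = {
        "inputs": "Write a haiku about autumn.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.7},
    }
    print(handler(sample_request))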