|
|
from typing import Dict, List, Any |
|
|
import torch |
|
|
import json |
|
|
import os |
|
|
import glob |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import logging |
|
|
|
|
|
|
|
|
# Configure root logging once at import time so endpoint logs reach stdout.
logging.basicConfig(level=logging.INFO)


# Module-level logger named after this module (standard `logging` convention).
logger = logging.getLogger(__name__)
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a Qwen2-style chat model.

    On construction it discovers where the model files actually live inside
    the deployed repository snapshot, loads the tokenizer and model with
    several fallbacks, and then serves generation requests via ``__call__``.
    """

    # Hub model used when no local tokenizer files can be loaded.
    _FALLBACK_TOKENIZER = "Qwen/Qwen2-7B-Instruct"

    def __init__(self, path: str = ""):
        """Initialize handler with robust file discovery.

        Args:
            path: Root directory of the deployed model repository.

        Raises:
            Exception: re-raised from any failed loading step after logging.
        """
        logger.info("Loading model from %s", path)
        try:
            if os.path.exists(path):
                contents = os.listdir(path)
                logger.info("Repository contents: %s", contents)
                # Log one level of sub-directories to help debug bad layouts.
                for item in contents:
                    item_path = os.path.join(path, item)
                    if os.path.isdir(item_path):
                        logger.info("Directory %s: %s", item, os.listdir(item_path))

            model_path = self._find_model_path(path)
            logger.info("Using model path: %s", model_path)

            self.tokenizer = self._load_tokenizer(model_path, path)
            logger.info("Tokenizer loaded successfully")

            self.model = self._load_model(model_path, path)
            logger.info("Model loaded successfully")

        except Exception:
            # Bare `raise` preserves the original traceback
            # (`raise e` would re-anchor it here).
            logger.exception("Failed to initialize")
            raise

    def _find_model_path(self, base_path: str) -> str:
        """Locate the directory that actually contains the model files.

        Checks, in order: ``base_path`` itself, the conventional
        ``models/huggingface`` sub-directory, then a recursive walk for the
        first directory holding a ``config.json``. Falls back to
        ``base_path`` so callers always get a usable path.
        """
        if os.path.exists(os.path.join(base_path, "config.json")):
            return base_path

        hf_path = os.path.join(base_path, "models", "huggingface")
        if os.path.exists(os.path.join(hf_path, "config.json")):
            return hf_path

        # os.walk is top-down, so the shallowest match wins.
        for root, _dirs, files in os.walk(base_path):
            if "config.json" in files:
                return root

        return base_path

    @staticmethod
    def _ensure_pad_token(tokenizer):
        """Guarantee a pad token so `generate()` gets a valid pad_token_id.

        Many causal-LM tokenizers ship without a pad token; reusing EOS is
        the standard workaround for single-sequence generation.
        """
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def _load_tokenizer(self, model_path: str, base_path: str):
        """Load tokenizer with fallback methods.

        Tries ``model_path``, then ``base_path`` (local files only), then
        downloads the known-compatible fallback tokenizer from the Hub.

        Raises:
            Exception: from the Hub fallback if every attempt fails.
        """
        for candidate in (model_path, base_path):
            try:
                logger.info("Trying to load tokenizer from %s", candidate)
                tokenizer = AutoTokenizer.from_pretrained(
                    candidate,
                    trust_remote_code=True,
                    local_files_only=True,
                )
                return self._ensure_pad_token(tokenizer)
            except Exception as err:
                logger.warning("Failed to load from %s: %s", candidate, err)

        try:
            # Last resort: fetch a compatible tokenizer from the Hub
            # (intentionally NOT local_files_only).
            logger.info("Using fallback tokenizer from Qwen2-7B-Instruct")
            tokenizer = AutoTokenizer.from_pretrained(
                self._FALLBACK_TOKENIZER,
                trust_remote_code=True,
            )
            return self._ensure_pad_token(tokenizer)
        except Exception:
            logger.exception("All tokenizer loading methods failed")
            raise

    def _load_model(self, model_path: str, base_path: str):
        """Load model with fallback methods.

        Tries ``model_path`` then ``base_path`` (local files only) and
        returns the model in eval mode.

        Raises:
            Exception: the last loading error if both paths fail.
        """
        last_error = None
        for candidate in (model_path, base_path):
            try:
                logger.info("Trying to load model from %s", candidate)
                model = AutoModelForCausalLM.from_pretrained(
                    candidate,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    local_files_only=True,
                    low_cpu_mem_usage=True,
                )
                # Inference-only endpoint: disable dropout etc.
                model.eval()
                return model
            except Exception as err:
                last_error = err
                logger.warning("Failed to load from %s: %s", candidate, err)

        logger.error("Model loading failed from both paths: %s", last_error)
        raise last_error

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle an inference request.

        Args:
            data: Request payload ``{"inputs": str, "parameters": {...}}``.

        Returns:
            A one-element list containing ``generated_text`` plus token-count
            and finish-reason metadata, or an ``error`` entry. This method
            never raises; all failures are reported in the response.
        """
        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            if not inputs:
                return [{"error": "No input provided", "generated_text": ""}]

            # Clamp user-supplied generation knobs to safe ranges.
            max_new_tokens = min(parameters.get("max_new_tokens", 512), 1024)
            temperature = max(0.1, min(parameters.get("temperature", 0.7), 2.0))
            top_p = max(0.1, min(parameters.get("top_p", 0.9), 1.0))
            do_sample = parameters.get("do_sample", True)

            # Wrap plain prompts in the ChatML template the model expects;
            # already-templated inputs pass through untouched.
            if inputs.startswith("<|im_start|>"):
                formatted_input = inputs
            else:
                formatted_input = (
                    f"<|im_start|>user\n{inputs}<|im_end|>\n<|im_start|>assistant\n"
                )

            try:
                input_ids = self.tokenizer.encode(
                    formatted_input,
                    return_tensors="pt",
                    truncation=True,
                    max_length=3072,
                )
            except Exception as e:
                logger.error("Tokenization failed: %s", e)
                return [{"error": f"Tokenization failed: {str(e)}", "generated_text": ""}]

            if input_ids.size(1) == 0:
                return [{"error": "Empty input after tokenization", "generated_text": ""}]

            # Move inputs onto whatever device device_map placed the model on.
            input_ids = input_ids.to(next(self.model.parameters()).device)

            # Fall back to EOS if the tokenizer defines no pad token:
            # generate() cannot use pad_token_id=None for open-ended output.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            generate_kwargs = {
                "max_new_tokens": max_new_tokens,
                "do_sample": do_sample,
                "pad_token_id": pad_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
                "use_cache": True,
                "num_return_sequences": 1,
            }
            if do_sample:
                # Sampling knobs are only meaningful (and warning-free)
                # when sampling is enabled.
                generate_kwargs["temperature"] = temperature
                generate_kwargs["top_p"] = top_p

            try:
                with torch.no_grad():
                    outputs = self.model.generate(input_ids, **generate_kwargs)
            except Exception as e:
                logger.error("Generation failed: %s", e)
                return [{"error": f"Generation failed: {str(e)}", "generated_text": ""}]

            try:
                # Keep only the newly generated continuation.
                generated_ids = outputs[0][input_ids.size(1):]
                response = self.tokenizer.decode(
                    generated_ids,
                    skip_special_tokens=True,
                ).strip()
                # Strip any ChatML terminator the decoder left behind.
                response = response.replace("<|im_end|>", "").strip()

                return [{
                    "generated_text": response,
                    "generated_tokens": len(generated_ids),
                    "finish_reason": (
                        "eos_token"
                        if self.tokenizer.eos_token_id in generated_ids
                        else "length"
                    ),
                }]
            except Exception as e:
                logger.error("Decoding failed: %s", e)
                return [{"error": f"Decoding failed: {str(e)}", "generated_text": ""}]

        except Exception as e:
            logger.error("Inference error: %s", str(e))
            return [{"error": f"Inference failed: {str(e)}", "generated_text": ""}]