TomBombadyl committed on
Commit 09729dc · verified · 1 parent: 0b620bc

Create handler.py

Files changed (1):
  1. handler.py +123 -0
handler.py ADDED
@@ -0,0 +1,123 @@
+from typing import Dict, List, Any, Optional
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class EndpointHandler:
+    def __init__(self, path: str = ""):
+        """
+        Initialize the handler for Qwen2.5-Coder-7B-Instruct-Omni1.1
+        Optimized for Isaac Sim robotics code generation
+        """
+        logger.info(f"Loading model from {path}")
+
+        # Load tokenizer with proper configuration
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            path,
+            trust_remote_code=True,
+            use_fast=False  # Use slow tokenizer to avoid tokenizer.json issues
+        )
+
+        # Set pad token if not present
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # Load model with optimizations for inference
+        self.model = AutoModelForCausalLM.from_pretrained(
+            path,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True,
+            low_cpu_mem_usage=True,
+            attn_implementation="flash_attention_2" if torch.cuda.is_available() else "eager"
+        )
+
+        # Set model to evaluation mode
+        self.model.eval()
+
+        logger.info("Model loaded successfully")
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        Handle inference requests
+
+        Expected input format:
+        {
+            "inputs": "Your prompt here",
+            "parameters": {
+                "max_new_tokens": 512,
+                "temperature": 0.7,
+                "top_p": 0.9,
+                "do_sample": true,
+                "repetition_penalty": 1.1
+            }
+        }
+        """
+        try:
+            # Extract inputs and parameters
+            inputs = data.get("inputs", "")
+            parameters = data.get("parameters", {})
+
+            # Default generation parameters optimized for code generation
+            max_new_tokens = parameters.get("max_new_tokens", 512)
+            temperature = parameters.get("temperature", 0.7)
+            top_p = parameters.get("top_p", 0.9)
+            do_sample = parameters.get("do_sample", True)
+            repetition_penalty = parameters.get("repetition_penalty", 1.1)
+
+            # Format input with the proper chat template for Qwen2.5
+            if not inputs.startswith("<|im_start|>"):
+                formatted_input = f"<|im_start|>user\n{inputs}<|im_end|>\n<|im_start|>assistant\n"
+            else:
+                formatted_input = inputs
+
+            # Tokenize input
+            input_ids = self.tokenizer.encode(
+                formatted_input,
+                return_tensors="pt",
+                truncation=True,
+                max_length=2048  # Leave room for generation
+            ).to(self.model.device)
+
+            # Generate response
+            with torch.no_grad():
+                output_ids = self.model.generate(
+                    input_ids,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    do_sample=do_sample,
+                    repetition_penalty=repetition_penalty,
+                    pad_token_id=self.tokenizer.pad_token_id,
+                    eos_token_id=self.tokenizer.eos_token_id,
+                    use_cache=True
+                )
+
+            # Decode only the new tokens (response)
+            response_ids = output_ids[0][input_ids.shape[1]:]
+            response_text = self.tokenizer.decode(
+                response_ids,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True
+            )
+
+            # Clean up response
+            response_text = response_text.strip()
+
+            # Return in expected format
+            return [{
+                "generated_text": response_text,
+                "generated_tokens": len(response_ids),
+                "finish_reason": "stop" if self.tokenizer.eos_token_id in response_ids else "length"
+            }]
+
+        except Exception as e:
+            logger.error(f"Error during inference: {str(e)}")
+            return [{
+                "error": f"Inference failed: {str(e)}",
+                "generated_text": ""
+            }]
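
For reference, below is a minimal local smoke test of this handler, written as a sketch rather than as part of the commit. It assumes the repository has been cloned to a local directory (the ./Qwen2.5-Coder-7B-Instruct-Omni1.1 path is a placeholder) and that a CUDA GPU with enough memory for the fp16 weights is available. On a deployed Inference Endpoint, the Hugging Face inference toolkit constructs EndpointHandler itself and calls it with the same payload shape documented in the __call__ docstring.

# Hypothetical smoke test; not part of handler.py.
# The checkpoint path is a placeholder; point it at a local copy of the model repo.
from handler import EndpointHandler

handler = EndpointHandler(path="./Qwen2.5-Coder-7B-Instruct-Omni1.1")

payload = {
    "inputs": "Write a Python function that spawns a cube prim in Isaac Sim.",
    "parameters": {
        "max_new_tokens": 256,
        "temperature": 0.2,
        "top_p": 0.9,
        "do_sample": True,
        "repetition_penalty": 1.1,
    },
}

result = handler(payload)
print(result[0].get("generated_text", ""))
print("finish_reason:", result[0].get("finish_reason"))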