TomBombadyl committed
Commit b9607cb · verified · 1 Parent(s): 7c85042

Update handler.py

Files changed (1)
  1. handler.py +73 -63
handler.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, List, Any, Optional
+from typing import Dict, List, Any
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import logging
@@ -11,113 +11,123 @@ class EndpointHandler:
     def __init__(self, path: str = ""):
         """
         Initialize the handler for Qwen2.5-Coder-7B-Instruct-Omni1.1
-        Optimized for Isaac Sim robotics code generation
+        Simple and robust implementation
         """
         logger.info(f"Loading model from {path}")

-        # Load tokenizer with proper configuration
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            path,
-            trust_remote_code=True,
-            use_fast=False  # Use slow tokenizer to avoid tokenizer.json issues
-        )
-
-        # Set pad token if not present
-        if self.tokenizer.pad_token is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        # Load model with optimizations for inference
-        self.model = AutoModelForCausalLM.from_pretrained(
-            path,
-            torch_dtype=torch.float16,
-            device_map="auto",
-            trust_remote_code=True,
-            low_cpu_mem_usage=True,
-            attn_implementation="flash_attention_2" if torch.cuda.is_available() else "eager"
-        )
-
-        # Set model to evaluation mode
-        self.model.eval()
-
-        logger.info("Model loaded successfully")
+        try:
+            # Load tokenizer - most robust approach
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                path,
+                trust_remote_code=True,
+                use_fast=False,
+                padding_side="left"
+            )
+
+            # Ensure we have proper tokens
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            if self.tokenizer.chat_template is None:
+                # Set a basic chat template for Qwen
+                self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
+
+            # Load model
+            self.model = AutoModelForCausalLM.from_pretrained(
+                path,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+                use_cache=True
+            )
+
+            self.model.eval()
+            logger.info("Model and tokenizer loaded successfully")
+
+        except Exception as e:
+            logger.error(f"Error loading model: {str(e)}")
+            raise e

     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
         Handle inference requests
-
-        Expected input format:
-        {
-            "inputs": "Your prompt here",
-            "parameters": {
-                "max_new_tokens": 512,
-                "temperature": 0.7,
-                "top_p": 0.9,
-                "do_sample": true,
-                "repetition_penalty": 1.1
-            }
-        }
         """
         try:
             # Extract inputs and parameters
             inputs = data.get("inputs", "")
             parameters = data.get("parameters", {})

-            # Default generation parameters optimized for code generation
+            if not inputs:
+                return [{"error": "No input provided", "generated_text": ""}]
+
+            # Generation parameters
             max_new_tokens = parameters.get("max_new_tokens", 512)
             temperature = parameters.get("temperature", 0.7)
             top_p = parameters.get("top_p", 0.9)
             do_sample = parameters.get("do_sample", True)
             repetition_penalty = parameters.get("repetition_penalty", 1.1)

-            # Format input with proper chat template for Qwen2.5
-            if not inputs.startswith("<|im_start|>"):
-                formatted_input = f"<|im_start|>user\n{inputs}<|im_end|>\n<|im_start|>assistant"
-            else:
+            # Prepare input - handle both raw text and pre-formatted chat
+            if inputs.startswith("<|im_start|>"):
+                # Already formatted
                 formatted_input = inputs
+            else:
+                # Format as chat
+                formatted_input = f"<|im_start|>user\n{inputs}<|im_end|>\n<|im_start|>assistant\n"

-            # Tokenize input
+            # Tokenize
             input_ids = self.tokenizer.encode(
                 formatted_input,
                 return_tensors="pt",
+                add_special_tokens=False,
                 truncation=True,
-                max_length=2048  # Leave room for generation
-            ).to(self.model.device)
+                max_length=4096 - max_new_tokens  # Leave room for generation
+            )
+
+            if input_ids.size(1) == 0:
+                return [{"error": "Input tokenization failed", "generated_text": ""}]
+
+            input_ids = input_ids.to(self.model.device)

-            # Generate response
+            # Generate
             with torch.no_grad():
-                output_ids = self.model.generate(
+                outputs = self.model.generate(
                     input_ids,
                     max_new_tokens=max_new_tokens,
                     temperature=temperature,
                     top_p=top_p,
                     do_sample=do_sample,
                     repetition_penalty=repetition_penalty,
-                    pad_token_id=self.tokenizer.pad_token_id,
+                    pad_token_id=self.tokenizer.eos_token_id,
                     eos_token_id=self.tokenizer.eos_token_id,
-                    use_cache=True
+                    use_cache=True,
+                    num_return_sequences=1
                 )

-            # Decode only the new tokens (response)
-            response_ids = output_ids[0][input_ids.shape[1]:]
-            response_text = self.tokenizer.decode(
-                response_ids,
+            # Decode response (only new tokens)
+            generated_ids = outputs[0][input_ids.size(1):]
+            response = self.tokenizer.decode(
+                generated_ids,
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=True
             )

             # Clean up response
-            response_text = response_text.strip()
+            response = response.strip()
+
+            # Remove any remaining special tokens manually
+            response = response.replace("<|im_end|>", "").strip()

-            # Return in expected format
             return [{
-                "generated_text": response_text,
-                "generated_tokens": len(response_ids),
-                "finish_reason": "stop" if self.tokenizer.eos_token_id in response_ids else "length"
+                "generated_text": response,
+                "generated_tokens": len(generated_ids),
+                "finish_reason": "eos_token" if self.tokenizer.eos_token_id in generated_ids else "length"
             }]

         except Exception as e:
-            logger.error(f"Error during inference: {str(e)}")
+            logger.error(f"Generation error: {str(e)}")
             return [{
-                "error": f"Inference failed: {str(e)}",
+                "error": f"Generation failed: {str(e)}",
                 "generated_text": ""
-            }]
+            }]
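For reference, a minimal local smoke test of the updated handler could look like the sketch below. The request payload mirrors what __call__ reads (an "inputs" string plus optional "parameters"); the local model path and the prompt text are assumptions for illustration, not part of this commit. On a deployed Inference Endpoint the serving toolkit constructs EndpointHandler and passes the request body for you.

# Sketch only: assumes the model repository has been downloaded to ./model
# and that handler.py is importable from the current working directory.
from handler import EndpointHandler

handler = EndpointHandler(path="./model")  # hypothetical local checkout of the model repo

request = {
    "inputs": "Write a Python function that reverses a string.",
    "parameters": {
        "max_new_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "repetition_penalty": 1.1
    }
}

result = handler(request)
# Prints the generated code on success, or the error message the handler returns.
print(result[0].get("generated_text") or result[0].get("error"))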