TomBombadyl committed
Commit 33d45b3 · verified · 1 Parent(s): 3cfb14a

Update handler.py

Files changed (1)
  1. handler.py +176 -119
handler.py CHANGED
@@ -2,8 +2,8 @@ from typing import Dict, List, Any
 import torch
 import json
 import os
-from transformers import PreTrainedTokenizerFast, PreTrainedModel
-from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
+import glob
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging
 
 # Set up logging
@@ -13,100 +13,135 @@ logger = logging.getLogger(__name__)
 class EndpointHandler:
     def __init__(self, path: str = ""):
         """
-        Initialize handler with manual model loading to bypass auto-detection
+        Initialize handler with robust file discovery
         """
         logger.info(f"Loading model from {path}")
 
         try:
-            # Manual config loading and creation
-            config_path = os.path.join(path, "config.json")
-
-            if os.path.exists(config_path):
-                with open(config_path, 'r') as f:
-                    config_dict = json.load(f)
-                logger.info(f"Loaded config: {config_dict.get('model_type', 'UNKNOWN')}")
-
-                # Create Qwen2Config manually
-                config = Qwen2Config(**config_dict)
-            else:
-                logger.warning("No config.json found, using default Qwen2Config")
-                config = Qwen2Config()
-
-            # Load tokenizer manually without auto-detection
-            logger.info("Loading tokenizer manually...")
-            tokenizer_path = os.path.join(path, "tokenizer.json")
-
-            if os.path.exists(tokenizer_path):
-                # Load tokenizer from tokenizer.json directly
-                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
-            else:
-                # Try loading from vocab files
-                vocab_path = os.path.join(path, "vocab.json")
-                merges_path = os.path.join(path, "merges.txt")
-
-                if os.path.exists(vocab_path):
-                    self.tokenizer = PreTrainedTokenizerFast(
-                        tokenizer_file=None,
-                        vocab_file=vocab_path,
-                        merges_file=merges_path if os.path.exists(merges_path) else None
-                    )
-                else:
-                    # Fallback: create basic tokenizer
-                    from transformers import AutoTokenizer
-                    logger.warning("Using fallback tokenizer loading...")
-                    self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
-
-            # Set special tokens
-            if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = "<|endoftext|>"
-                self.tokenizer.pad_token_id = 151643
-
-            if not hasattr(self.tokenizer, 'eos_token') or self.tokenizer.eos_token is None:
-                self.tokenizer.eos_token = "<|endoftext|>"
-                self.tokenizer.eos_token_id = 151643
-
+            # Log directory contents to understand structure
+            if os.path.exists(path):
+                contents = os.listdir(path)
+                logger.info(f"Repository contents: {contents}")
+
+                # Look for model files in subdirectories
+                for item in contents:
+                    item_path = os.path.join(path, item)
+                    if os.path.isdir(item_path):
+                        sub_contents = os.listdir(item_path)
+                        logger.info(f"Directory {item}: {sub_contents}")
+
+            # Try to find the actual model path
+            model_path = self._find_model_path(path)
+            logger.info(f"Using model path: {model_path}")
+
+            # Load tokenizer - try multiple approaches
+            self.tokenizer = self._load_tokenizer(model_path, path)
             logger.info("Tokenizer loaded successfully")
 
-            # Load model manually with the config
-            logger.info("Loading model manually...")
-            self.model = Qwen2ForCausalLM(config)
-
-            # Load state dict manually
-            safetensors_files = [f for f in os.listdir(path) if f.endswith('.safetensors')]
-
-            if safetensors_files:
-                logger.info(f"Loading weights from {len(safetensors_files)} safetensors files")
-                from safetensors.torch import load_file
-
-                state_dict = {}
-                for file in sorted(safetensors_files):
-                    file_path = os.path.join(path, file)
-                    partial_state_dict = load_file(file_path)
-                    state_dict.update(partial_state_dict)
-
-                # Load the state dict
-                missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
-
-                if missing_keys:
-                    logger.warning(f"Missing keys: {missing_keys[:5]}...")  # Show first 5
-                if unexpected_keys:
-                    logger.warning(f"Unexpected keys: {unexpected_keys[:5]}...")  # Show first 5
-            else:
-                logger.error("No safetensors files found!")
-                raise FileNotFoundError("No model weights found")
-
-            # Move to GPU and set to eval mode
-            self.model = self.model.half()  # Convert to float16
-            if torch.cuda.is_available():
-                self.model = self.model.cuda()
-            self.model.eval()
-
+            # Load model
+            self.model = self._load_model(model_path, path)
             logger.info("Model loaded successfully")
 
         except Exception as e:
-            logger.error(f"Failed to load model: {str(e)}")
+            logger.error(f"Failed to initialize: {str(e)}")
             raise e
 
+    def _find_model_path(self, base_path: str) -> str:
+        """Find the actual path containing model files"""
+
+        # Check if config.json is in base path
+        if os.path.exists(os.path.join(base_path, "config.json")):
+            return base_path
+
+        # Check models/huggingface subdirectory
+        hf_path = os.path.join(base_path, "models", "huggingface")
+        if os.path.exists(hf_path) and os.path.exists(os.path.join(hf_path, "config.json")):
+            return hf_path
+
+        # Check for any subdirectory with config.json
+        for root, dirs, files in os.walk(base_path):
+            if "config.json" in files:
+                return root
+
+        # Fallback to base path
+        return base_path
+
+    def _load_tokenizer(self, model_path: str, base_path: str):
+        """Load tokenizer with fallback methods"""
+
+        try:
+            # Try direct loading from model path
+            logger.info(f"Trying to load tokenizer from {model_path}")
+            return AutoTokenizer.from_pretrained(
+                model_path,
+                trust_remote_code=True,
+                local_files_only=True
+            )
+        except Exception as e1:
+            logger.warning(f"Failed to load from {model_path}: {e1}")
+
+            try:
+                # Try loading from base path
+                logger.info(f"Trying to load tokenizer from {base_path}")
+                return AutoTokenizer.from_pretrained(
+                    base_path,
+                    trust_remote_code=True,
+                    local_files_only=True
+                )
+            except Exception as e2:
+                logger.warning(f"Failed to load from {base_path}: {e2}")
+
+                try:
+                    # Try loading from Hugging Face Hub as fallback
+                    logger.info("Using fallback tokenizer from Qwen2-7B-Instruct")
+                    tokenizer = AutoTokenizer.from_pretrained(
+                        "Qwen/Qwen2-7B-Instruct",
+                        trust_remote_code=True
+                    )
+
+                    # Set special tokens
+                    tokenizer.pad_token = tokenizer.eos_token
+                    return tokenizer
+
+                except Exception as e3:
+                    logger.error(f"All tokenizer loading methods failed: {e3}")
+                    raise e3
+
+    def _load_model(self, model_path: str, base_path: str):
+        """Load model with fallback methods"""
+
+        try:
+            # Try direct loading from model path
+            logger.info(f"Trying to load model from {model_path}")
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                trust_remote_code=True,
+                local_files_only=True,
+                low_cpu_mem_usage=True
+            )
+        except Exception as e1:
+            logger.warning(f"Failed to load from {model_path}: {e1}")
+
+            try:
+                # Try loading from base path
+                logger.info(f"Trying to load model from {base_path}")
+                model = AutoModelForCausalLM.from_pretrained(
+                    base_path,
+                    torch_dtype=torch.float16,
+                    device_map="auto",
+                    trust_remote_code=True,
+                    local_files_only=True,
+                    low_cpu_mem_usage=True
+                )
+            except Exception as e2:
+                logger.error(f"Model loading failed from both paths: {e2}")
+                raise e2
+
+        model.eval()
+        return model
+
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
         Handle inference requests
@@ -118,53 +153,75 @@
             if not inputs:
                 return [{"error": "No input provided", "generated_text": ""}]
 
-            # Generation parameters
+            # Generation parameters with safety limits
             max_new_tokens = min(parameters.get("max_new_tokens", 512), 1024)
             temperature = max(0.1, min(parameters.get("temperature", 0.7), 2.0))
             top_p = max(0.1, min(parameters.get("top_p", 0.9), 1.0))
             do_sample = parameters.get("do_sample", True)
 
-            # Format input
+            # Format input for Qwen chat template
            if inputs.startswith("<|im_start|>"):
                 formatted_input = inputs
             else:
                 formatted_input = f"<|im_start|>user\n{inputs}<|im_end|>\n<|im_start|>assistant\n"
 
-            # Tokenize
-            input_ids = self.tokenizer.encode(
-                formatted_input,
-                return_tensors="pt",
-                truncation=True,
-                max_length=3072
-            )
-
-            if torch.cuda.is_available():
-                input_ids = input_ids.cuda()
-
-            # Generate
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    input_ids,
-                    max_new_tokens=max_new_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    do_sample=do_sample,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                    eos_token_id=self.tokenizer.eos_token_id,
-                    use_cache=True
-                )
+            # Tokenize with error handling
+            try:
+                input_ids = self.tokenizer.encode(
+                    formatted_input,
+                    return_tensors="pt",
+                    truncation=True,
+                    max_length=3072
+                )
+            except Exception as e:
+                logger.error(f"Tokenization failed: {e}")
+                return [{"error": f"Tokenization failed: {str(e)}", "generated_text": ""}]
+
+            if input_ids.size(1) == 0:
+                return [{"error": "Empty input after tokenization", "generated_text": ""}]
+
+            # Move to model device
+            input_ids = input_ids.to(next(self.model.parameters()).device)
+
+            # Generate with error handling
+            try:
+                with torch.no_grad():
+                    outputs = self.model.generate(
+                        input_ids,
+                        max_new_tokens=max_new_tokens,
+                        temperature=temperature,
+                        top_p=top_p,
+                        do_sample=do_sample,
+                        pad_token_id=self.tokenizer.pad_token_id,
+                        eos_token_id=self.tokenizer.eos_token_id,
+                        use_cache=True,
+                        num_return_sequences=1
+                    )
+            except Exception as e:
+                logger.error(f"Generation failed: {e}")
+                return [{"error": f"Generation failed: {str(e)}", "generated_text": ""}]
 
             # Decode response
-            generated_ids = outputs[0][input_ids.size(1):]
-            response = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
-            response = response.replace("<|im_end|>", "").strip()
-
-            return [{
-                "generated_text": response,
-                "generated_tokens": len(generated_ids),
-                "finish_reason": "eos_token" if self.tokenizer.eos_token_id in generated_ids else "length"
-            }]
+            try:
+                generated_ids = outputs[0][input_ids.size(1):]
+                response = self.tokenizer.decode(
+                    generated_ids,
+                    skip_special_tokens=True
+                ).strip()
+
+                # Clean up response
+                response = response.replace("<|im_end|>", "").strip()
+
+                return [{
+                    "generated_text": response,
+                    "generated_tokens": len(generated_ids),
+                    "finish_reason": "eos_token" if self.tokenizer.eos_token_id in generated_ids else "length"
+                }]
+
+            except Exception as e:
+                logger.error(f"Decoding failed: {e}")
+                return [{"error": f"Decoding failed: {str(e)}", "generated_text": ""}]
 
         except Exception as e:
-            logger.error(f"Generation error: {str(e)}")
-            return [{"error": f"Generation failed: {str(e)}", "generated_text": ""}]
+            logger.error(f"Inference error: {str(e)}")
+            return [{"error": f"Inference failed: {str(e)}", "generated_text": ""}]