TomBombadyl committed
Commit 6a0c90b · verified · 1 Parent(s): 1eff8e4

Update handler.py

Files changed (1)
  1. handler.py +154 -230
handler.py CHANGED
@@ -3,9 +3,7 @@ import torch
 import json
 import os
 import glob
-import tempfile
-from transformers import PreTrainedTokenizerFast, PreTrainedModel
-from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging

 # Set up logging
@@ -15,228 +13,131 @@ logger = logging.getLogger(__name__)
 class EndpointHandler:
     def __init__(self, path: str = ""):
         """
-        Manual model loading that completely bypasses auto-detection
+        Initialize handler with robust file discovery
         """
         logger.info(f"Loading model from {path}")

         try:
-            # Set cache directories to temp to avoid memory issues
-            os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
-            os.environ['HF_HOME'] = '/tmp/hf_home'
-            os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-
-            # Clear any existing cache
-            self._clear_cache()
-
-            # Find the actual model files
-            model_path = self._discover_model_files(path)
-            logger.info(f"Model files found at: {model_path}")
-
-            # Load config manually
-            config = self._load_config_manually(model_path)
-            logger.info(f"Config loaded: {config.model_type}")
-
-            # Load tokenizer manually
-            self.tokenizer = self._load_tokenizer_manually(model_path)
+            # Log directory contents to understand structure
+            if os.path.exists(path):
+                contents = os.listdir(path)
+                logger.info(f"Repository contents: {contents}")
+
+                # Look for model files in subdirectories
+                for item in contents:
+                    item_path = os.path.join(path, item)
+                    if os.path.isdir(item_path):
+                        sub_contents = os.listdir(item_path)
+                        logger.info(f"Directory {item}: {sub_contents}")
+
+            # Try to find the actual model path
+            model_path = self._find_model_path(path)
+            logger.info(f"Using model path: {model_path}")
+
+            # Load tokenizer - try multiple approaches
+            self.tokenizer = self._load_tokenizer(model_path, path)
             logger.info("Tokenizer loaded successfully")

-            # Create model architecture manually
-            self.model = self._create_model_manually(config, model_path)
+            # Load model
+            self.model = self._load_model(model_path, path)
             logger.info("Model loaded successfully")

         except Exception as e:
             logger.error(f"Failed to initialize: {str(e)}")
             raise e

-    def _clear_cache(self):
-        """Clear any cached model data to free memory"""
-        try:
-            import shutil
-            cache_dirs = ['/tmp/transformers_cache', '/tmp/hf_home']
-            for cache_dir in cache_dirs:
-                if os.path.exists(cache_dir):
-                    shutil.rmtree(cache_dir)
-                    logger.info(f"Cleared cache: {cache_dir}")
-        except Exception as e:
-            logger.warning(f"Could not clear cache: {e}")
-
-    def _discover_model_files(self, base_path: str) -> str:
-        """Find where the actual model files are located"""
-
-        logger.info(f"Searching for model files in: {base_path}")
-
-        # List all contents
-        if os.path.exists(base_path):
-            contents = os.listdir(base_path)
-            logger.info(f"Base directory contents: {contents}")
-
-            # Check for config.json in base path
-            if "config.json" in contents:
-                logger.info("Found config.json in base directory")
-                return base_path
-
-            # Check models subdirectories
-            for item in contents:
-                if os.path.isdir(os.path.join(base_path, item)):
-                    sub_path = os.path.join(base_path, item)
-                    sub_contents = os.listdir(sub_path)
-                    logger.info(f"Subdirectory {item}: {sub_contents}")
-
-                    if "config.json" in sub_contents:
-                        logger.info(f"Found config.json in {item} subdirectory")
-                        return sub_path
-
-        # Search recursively
-        for root, dirs, files in os.walk(base_path):
-            if "config.json" in files:
-                logger.info(f"Found config.json in {root}")
-                return root
-
-        raise FileNotFoundError(f"No config.json found in {base_path} or subdirectories")
-
-    def _load_config_manually(self, model_path: str) -> Qwen2Config:
-        """Load and create config manually"""
-
-        config_path = os.path.join(model_path, "config.json")
-        logger.info(f"Loading config from: {config_path}")
-
-        with open(config_path, 'r') as f:
-            config_dict = json.load(f)
-
-        logger.info(f"Config keys: {list(config_dict.keys())}")
-        logger.info(f"Model type: {config_dict.get('model_type', 'NOT_FOUND')}")
-
-        # Ensure model_type is set correctly
-        if 'model_type' not in config_dict:
-            config_dict['model_type'] = 'qwen2'
-            logger.info("Set model_type to 'qwen2'")
-
-        # Create config object
-        config = Qwen2Config(**config_dict)
-        return config
-
-    def _load_tokenizer_manually(self, model_path: str) -> PreTrainedTokenizerFast:
-        """Load tokenizer without auto-detection"""
-
-        # Look for tokenizer files
-        tokenizer_files = []
-        for file in os.listdir(model_path):
-            if file in ['tokenizer.json', 'tokenizer_config.json', 'vocab.json']:
-                tokenizer_files.append(file)
-
-        logger.info(f"Found tokenizer files: {tokenizer_files}")
-
-        if 'tokenizer.json' in tokenizer_files:
-            # Load from tokenizer.json
-            tokenizer_path = os.path.join(model_path, 'tokenizer.json')
-            logger.info(f"Loading tokenizer from {tokenizer_path}")
-
-            tokenizer = PreTrainedTokenizerFast(
-                tokenizer_file=tokenizer_path,
-                unk_token="<|endoftext|>",
-                bos_token="<|endoftext|>",
-                eos_token="<|endoftext|>"
-            )
-        else:
-            # Fallback: create basic tokenizer
-            logger.warning("No tokenizer.json found, creating basic tokenizer")
-            from transformers import AutoTokenizer
-
-            # Try to load from the model path with local_files_only
-            try:
-                tokenizer = AutoTokenizer.from_pretrained(
-                    model_path,
-                    trust_remote_code=True,
-                    local_files_only=True,
-                    cache_dir='/tmp/tokenizer_cache'  # Use temp cache
-                )
-            except Exception as e:
-                logger.error(f"Failed to load tokenizer: {e}")
-                raise e
-
-        # Set special tokens
-        if not hasattr(tokenizer, 'pad_token') or tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-            tokenizer.pad_token_id = tokenizer.eos_token_id
-
-        return tokenizer
-
-    def _create_model_manually(self, config: Qwen2Config, model_path: str) -> Qwen2ForCausalLM:
-        """Create model architecture and load weights manually"""
-
-        logger.info("Creating Qwen2ForCausalLM with config")
-        model = Qwen2ForCausalLM(config)
-
-        # Find safetensors files
-        safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors"))
-        logger.info(f"Found {len(safetensors_files)} safetensors files")
-
-        if not safetensors_files:
-            raise FileNotFoundError("No safetensors files found")
-
-        # Load weights manually with memory optimization
-        from safetensors.torch import load_file
-
-        # Convert to half precision before loading weights to save memory
-        model = model.half()
-        logger.info("Converted model to half precision")
-
-        # Load weights in chunks to avoid memory spikes
-        state_dict = {}
-        total_files = len(safetensors_files)
-
-        for i, file in enumerate(sorted(safetensors_files)):
-            logger.info(f"Loading weights from file {i+1}/{total_files}: {os.path.basename(file)}")
-
-            try:
-                # Load partial weights
-                partial_state_dict = load_file(file)
-
-                # Convert to half precision immediately
-                partial_state_dict = {k: v.half() for k, v in partial_state_dict.items()}
-
-                # Update state dict
-                state_dict.update(partial_state_dict)
-
-                # Clear partial dict to free memory
-                del partial_state_dict
-
-                # Force garbage collection
-                import gc
-                gc.collect()
-
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-
-                logger.info(f"Loaded file {i+1}/{total_files}, current memory usage: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
-
-            except Exception as e:
-                logger.error(f"Failed to load file {file}: {e}")
-                raise e
-
-        logger.info(f"Total state dict keys: {len(state_dict)}")
-
-        # Load weights into model
-        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
-
-        if missing_keys:
-            logger.warning(f"Missing keys: {len(missing_keys)} keys missing")
-            logger.warning(f"First few missing: {missing_keys[:5]}")
-
-        if unexpected_keys:
-            logger.warning(f"Unexpected keys: {len(unexpected_keys)} unexpected keys")
-            logger.warning(f"First few unexpected: {unexpected_keys[:5]}")
-
-        # Clear state dict to free memory
-        del state_dict
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-        # Move to GPU if available
-        if torch.cuda.is_available():
-            model = model.cuda()
-            logger.info(f"Model moved to GPU, final memory usage: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
-
+    def _find_model_path(self, base_path: str) -> str:
+        """Find the actual path containing model files"""
+
+        # Check if config.json is in base path
+        if os.path.exists(os.path.join(base_path, "config.json")):
+            return base_path
+
+        # Check models/huggingface subdirectory
+        hf_path = os.path.join(base_path, "models", "huggingface")
+        if os.path.exists(hf_path) and os.path.exists(os.path.join(hf_path, "config.json")):
+            return hf_path
+
+        # Check for any subdirectory with config.json
+        for root, dirs, files in os.walk(base_path):
+            if "config.json" in files:
+                return root
+
+        # Fallback to base path
+        return base_path
+
+    def _load_tokenizer(self, model_path: str, base_path: str):
+        """Load tokenizer with fallback methods"""
+
+        try:
+            # Try direct loading from model path
+            logger.info(f"Trying to load tokenizer from {model_path}")
+            return AutoTokenizer.from_pretrained(
+                model_path,
+                trust_remote_code=True,
+                local_files_only=True
+            )
+        except Exception as e1:
+            logger.warning(f"Failed to load from {model_path}: {e1}")
+
+        try:
+            # Try loading from base path
+            logger.info(f"Trying to load tokenizer from {base_path}")
+            return AutoTokenizer.from_pretrained(
+                base_path,
+                trust_remote_code=True,
+                local_files_only=True
+            )
+        except Exception as e2:
+            logger.warning(f"Failed to load from {base_path}: {e2}")
+
+        try:
+            # Try loading from Hugging Face Hub as fallback
+            logger.info("Using fallback tokenizer from Qwen2-7B-Instruct")
+            tokenizer = AutoTokenizer.from_pretrained(
+                "Qwen/Qwen2-7B-Instruct",
+                trust_remote_code=True
+            )
+
+            # Set special tokens
+            tokenizer.pad_token = tokenizer.eos_token
+            return tokenizer
+
+        except Exception as e3:
+            logger.error(f"All tokenizer loading methods failed: {e3}")
+            raise e3
+
+    def _load_model(self, model_path: str, base_path: str):
+        """Load model with fallback methods"""
+
+        try:
+            # Try direct loading from model path
+            logger.info(f"Trying to load model from {model_path}")
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                trust_remote_code=True,
+                local_files_only=True,
+                low_cpu_mem_usage=True
+            )
+        except Exception as e1:
+            logger.warning(f"Failed to load from {model_path}: {e1}")
+
+            try:
+                # Try loading from base path
+                logger.info(f"Trying to load model from {base_path}")
+                model = AutoModelForCausalLM.from_pretrained(
+                    base_path,
+                    torch_dtype=torch.float16,
+                    device_map="auto",
+                    trust_remote_code=True,
+                    local_files_only=True,
+                    low_cpu_mem_usage=True
+                )
+            except Exception as e2:
+                logger.error(f"Model loading failed from both paths: {e2}")
+                raise e2

         model.eval()
         return model
@@ -252,7 +153,7 @@ class EndpointHandler:
             if not inputs:
                 return [{"error": "No input provided", "generated_text": ""}]

-            # Generation parameters
+            # Generation parameters with safety limits
             max_new_tokens = min(parameters.get("max_new_tokens", 512), 1024)
             temperature = max(0.1, min(parameters.get("temperature", 0.7), 2.0))
             top_p = max(0.1, min(parameters.get("top_p", 0.9), 1.0))
@@ -264,40 +165,63 @@
             else:
                 formatted_input = f"<|im_start|>user\n{inputs}<|im_end|>\n<|im_start|>assistant\n"

-            # Tokenize
-            input_ids = self.tokenizer.encode(
-                formatted_input,
-                return_tensors="pt",
-                truncation=True,
-                max_length=3072
-            )
-
-            input_ids = input_ids.to(self.model.device)
-
-            # Generate
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    input_ids,
-                    max_new_tokens=max_new_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    do_sample=do_sample,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                    eos_token_id=self.tokenizer.eos_token_id,
-                    use_cache=True
-                )
-
-            # Decode response
-            generated_ids = outputs[0][input_ids.size(1):]
-            response = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
-            response = response.replace("<|im_end|>", "").strip()
-
-            return [{
-                "generated_text": response,
-                "generated_tokens": len(generated_ids),
-                "finish_reason": "eos_token" if self.tokenizer.eos_token_id in generated_ids else "length"
-            }]
-
+            # Tokenize with error handling
+            try:
+                input_ids = self.tokenizer.encode(
+                    formatted_input,
+                    return_tensors="pt",
+                    truncation=True,
+                    max_length=3072
+                )
+            except Exception as e:
+                logger.error(f"Tokenization failed: {e}")
+                return [{"error": f"Tokenization failed: {str(e)}", "generated_text": ""}]
+
+            if input_ids.size(1) == 0:
+                return [{"error": "Empty input after tokenization", "generated_text": ""}]
+
+            # Move to model device
+            input_ids = input_ids.to(next(self.model.parameters()).device)
+
+            # Generate with error handling
+            try:
+                with torch.no_grad():
+                    outputs = self.model.generate(
+                        input_ids,
+                        max_new_tokens=max_new_tokens,
+                        temperature=temperature,
+                        top_p=top_p,
+                        do_sample=do_sample,
+                        pad_token_id=self.tokenizer.pad_token_id,
+                        eos_token_id=self.tokenizer.eos_token_id,
+                        use_cache=True,
+                        num_return_sequences=1
+                    )
+            except Exception as e:
+                logger.error(f"Generation failed: {e}")
+                return [{"error": f"Generation failed: {str(e)}", "generated_text": ""}]
+
+            # Decode response
+            try:
+                generated_ids = outputs[0][input_ids.size(1):]
+                response = self.tokenizer.decode(
+                    generated_ids,
+                    skip_special_tokens=True
+                ).strip()
+
+                # Clean up response
+                response = response.replace("<|im_end|>", "").strip()
+
+                return [{
+                    "generated_text": response,
+                    "generated_tokens": len(generated_ids),
+                    "finish_reason": "eos_token" if self.tokenizer.eos_token_id in generated_ids else "length"
+                }]
+
+            except Exception as e:
+                logger.error(f"Decoding failed: {e}")
+                return [{"error": f"Decoding failed: {str(e)}", "generated_text": ""}]
+
         except Exception as e:
-            logger.error(f"Generation error: {str(e)}")
-            return [{"error": f"Generation failed: {str(e)}", "generated_text": ""}]
+            logger.error(f"Inference error: {str(e)}")
+            return [{"error": f"Inference failed: {str(e)}", "generated_text": ""}]