TomBombadyl committed (verified)
Commit 0467b75 · Parent: 0711c19

Update handler.py

Files changed (1):
  handler.py  +200 -202
handler.py CHANGED
@@ -11,23 +11,24 @@ logger = logging.getLogger(__name__)
 class EndpointHandler:
     def __init__(self, path: str = ""):
         """
-        Initialize handler that completely bypasses HF auto-detection
         """
         logger.info(f"Loading model from {path}")

         try:
-            # Set environment variables to avoid auto-detection issues
-            os.environ['TRANSFORMERS_OFFLINE'] = '1'
-            os.environ['HF_DATASETS_OFFLINE'] = '1'
-            os.environ['HF_HUB_OFFLINE'] = '1'

-            # Find model files
-            model_path = self._discover_model_files(path)
-            logger.info(f"Model files found at: {model_path}")

-            # Load components manually
-            self.tokenizer = self._load_tokenizer_manual(model_path)
-            self.model = self._load_model_manual(model_path)

             logger.info("Model and tokenizer loaded successfully")

@@ -35,239 +36,236 @@ class EndpointHandler:
             logger.error(f"Failed to initialize: {str(e)}")
             raise e

-    def _discover_model_files(self, base_path: str) -> str:
-        """Find where the actual model files are located"""
-
-        logger.info(f"Searching for model files in: {base_path}")
-
-        # List all contents
-        if os.path.exists(base_path):
-            contents = os.listdir(base_path)
-            logger.info(f"Base directory contents: {contents}")
-
-            # Check for config.json in base path
-            if "config.json" in contents:
-                logger.info("Found config.json in base directory")
-                return base_path
-
-            # Check models subdirectories
-            for item in contents:
-                if os.path.isdir(os.path.join(base_path, item)):
-                    sub_path = os.path.join(base_path, item)
-                    sub_contents = os.listdir(sub_path)
-                    logger.info(f"Subdirectory {item}: {sub_contents}")
-
-                    if "config.json" in sub_contents:
-                        logger.info(f"Found config.json in {item} subdirectory")
-                        return sub_path
-
-        # Search recursively
-        for root, dirs, files in os.walk(base_path):
-            if "config.json" in files:
-                logger.info(f"Found config.json in {root}")
-                return root
-
-        raise FileNotFoundError(f"No config.json found in {base_path} or subdirectories")
-
-    def _load_tokenizer_manual(self, model_path: str):
-        """Load tokenizer completely manually"""
-
-        logger.info("Loading tokenizer manually...")
-
-        # Check what tokenizer files exist
-        tokenizer_files = []
-        for file in os.listdir(model_path):
-            if file in ['tokenizer.json', 'tokenizer_config.json', 'vocab.json']:
-                tokenizer_files.append(file)
-
-        logger.info(f"Found tokenizer files: {tokenizer_files}")
-
-        if 'tokenizer.json' in tokenizer_files:
-            # Load from tokenizer.json directly
-            from transformers import PreTrainedTokenizerFast
-            tokenizer_path = os.path.join(model_path, 'tokenizer.json')
-            logger.info(f"Loading tokenizer from {tokenizer_path}")
-
-            tokenizer = PreTrainedTokenizerFast(
-                tokenizer_file=tokenizer_path,
-                unk_token="<|endoftext|>",
-                bos_token="<|endoftext|>",
-                eos_token="<|endoftext|>"
-            )
-        else:
-            # Create a basic tokenizer
-            logger.warning("No tokenizer.json found, creating basic tokenizer")
-            from transformers import PreTrainedTokenizerFast
-
-            # Create minimal tokenizer
-            tokenizer = PreTrainedTokenizerFast(
-                tokenizer_file=None,
-                vocab_size=151936,  # Qwen2 default vocab size
-                unk_token="<|endoftext|>",
-                bos_token="<|endoftext|>",
-                eos_token="<|endoftext|>",
-                pad_token="<|endoftext|>"
-            )
-
-        # Set special tokens
         if not hasattr(tokenizer, 'pad_token') or tokenizer.pad_token is None:
-            tokenizer.pad_token = "<|endoftext|>"
-            tokenizer.pad_token_id = 151643
-
-        if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None:
-            tokenizer.eos_token = "<|endoftext|>"
-            tokenizer.eos_token_id = 151643

         return tokenizer

-    def _load_model_manual(self, model_path: str):
-        """Load model completely manually with memory optimization"""
-
-        logger.info("Loading model manually...")
-
-        # Check GPU availability and memory
         if torch.cuda.is_available():
             logger.info(f"CUDA available: {torch.cuda.get_device_name()}")
-            logger.info(f"GPU memory before loading: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
             logger.info(f"GPU memory total: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f}GB")
         else:
             logger.warning("CUDA not available, using CPU")

-        # Load config manually
-        config_path = os.path.join(model_path, "config.json")
-        with open(config_path, 'r') as f:
-            config_dict = json.load(f)
-
-        logger.info(f"Config loaded: {config_dict.get('model_type', 'UNKNOWN')}")
-
-        # Create model architecture manually
-        from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
-
-        # Ensure model_type is set correctly
-        if 'model_type' not in config_dict:
-            config_dict['model_type'] = 'qwen2'
-            logger.info("Set model_type to 'qwen2'")
-
-        # Create config object
-        config = Qwen2Config(**config_dict)
-
-        # Create model
-        model = Qwen2ForCausalLM(config)
-        logger.info("Model architecture created")

         if torch.cuda.is_available():
-            logger.info(f"GPU memory after model creation: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")

-        # Load weights manually from safetensors with memory optimization
-        import glob
-        safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors"))
-        logger.info(f"Found {len(safetensors_files)} safetensors files")

-        if safetensors_files:
-            from safetensors.torch import load_file
-
-            # Load weights directly into model without accumulating in state_dict
-            for i, file in enumerate(sorted(safetensors_files)):
-                logger.info(f"Loading weights from file {i+1}/{len(safetensors_files)}: {os.path.basename(file)}")
-
-                # Load partial weights
-                partial_state_dict = load_file(file)
-
-                if torch.cuda.is_available():
-                    logger.info(f"GPU memory after loading file {i+1}: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
-
-                # Load this partial state dict directly into the model
-                missing_keys, unexpected_keys = model.load_state_dict(partial_state_dict, strict=False)
-
-                # Clear partial dict immediately to free memory
-                del partial_state_dict
-
-                # Force garbage collection
-                import gc
-                gc.collect()
-
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-                    logger.info(f"GPU memory after cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")

-        # Convert to half precision and move to GPU
-        logger.info("Converting model to half precision...")
-        model = model.half()

-        if torch.cuda.is_available():
-            logger.info(f"GPU memory after half precision: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
-            model = model.cuda()
-            logger.info(f"GPU memory after moving to GPU: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")

-        model.eval()
-        logger.info("Model loaded successfully and set to eval mode")
-        return model

-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-        """
-        Handle inference requests
-        """
         try:
-            inputs = data.get("inputs", "")
-            parameters = data.get("parameters", {})

-            if not inputs:
-                return [{"error": "No input provided", "generated_text": ""}]

             # Generation parameters
-            max_new_tokens = min(parameters.get("max_new_tokens", 512), 1024)
-            temperature = max(0.1, min(parameters.get("temperature", 0.7), 2.0))
-            top_p = max(0.1, min(parameters.get("top_p", 0.9), 1.0))
-            do_sample = parameters.get("do_sample", True)

-            # Format input for Qwen chat template
-            if inputs.startswith("<|im_start|>"):
-                formatted_input = inputs
-            else:
-                formatted_input = f"<|im_start|>user\n{inputs}<|im_end|>\n<|im_start|>assistant\n"

-            # Tokenize
-            input_ids = self.tokenizer.encode(
-                formatted_input,
-                return_tensors="pt",
-                truncation=True,
-                max_length=3072
             )

-            if input_ids.size(1) == 0:
-                return [{"error": "Empty input after tokenization", "generated_text": ""}]
-
-            input_ids = input_ids.to(self.model.device)
-
             # Generate
             with torch.no_grad():
-                outputs = self.model.generate(
-                    input_ids,
-                    max_new_tokens=max_new_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    do_sample=do_sample,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                    eos_token_id=self.tokenizer.eos_token_id,
-                    use_cache=True
-                )

-            # Decode response
-            generated_ids = outputs[0][input_ids.size(1):]
-            response = self.tokenizer.decode(
-                generated_ids,
-                skip_special_tokens=True
-            ).strip()

-            # Clean up response
-            response = response.replace("<|im_end|>", "").strip()

-            return [{
-                "generated_text": response,
-                "generated_tokens": len(generated_ids),
-                "finish_reason": "eos_token" if self.tokenizer.eos_token_id in generated_ids else "length"
-            }]

         except Exception as e:
             logger.error(f"Generation error: {str(e)}")
-            return [{"error": f"Generation failed: {str(e)}", "generated_text": ""}]

 class EndpointHandler:
     def __init__(self, path: str = ""):
         """
+        Initialize handler using CTransformers format for memory efficiency
         """
         logger.info(f"Loading model from {path}")

         try:
+            # Use CTransformers format for lower memory usage
+            ctransformers_path = os.path.join(path, "models", "ctransformers")

+            if not os.path.exists(ctransformers_path):
+                logger.warning(f"CTransformers path not found: {ctransformers_path}")
+                logger.info("Falling back to HuggingFace format")
+                ctransformers_path = path

+            logger.info(f"Using model path: {ctransformers_path}")
+
+            # Load components using the working handler approach
+            self.tokenizer = self._load_tokenizer(ctransformers_path)
+            self.model = self._load_model(ctransformers_path)

             logger.info("Model and tokenizer loaded successfully")

             logger.error(f"Failed to initialize: {str(e)}")
             raise e

+    def _load_tokenizer(self, model_path: str):
+        """Load tokenizer using AutoTokenizer"""
+        logger.info("Loading tokenizer...")

+        from transformers import AutoTokenizer

+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            use_fast=True,
+        )

+        # Ensure special tokens are set
         if not hasattr(tokenizer, 'pad_token') or tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+            tokenizer.pad_token_id = tokenizer.eos_token_id

+        logger.info("Tokenizer loaded successfully")
         return tokenizer

+    def _load_model(self, model_path: str):
+        """Load model using AutoModelForCausalLM with memory optimization"""
+        logger.info("Loading model with memory optimization...")

+        from transformers import AutoModelForCausalLM, BitsAndBytesConfig

+        # Check GPU availability
         if torch.cuda.is_available():
             logger.info(f"CUDA available: {torch.cuda.get_device_name()}")
             logger.info(f"GPU memory total: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f}GB")
         else:
             logger.warning("CUDA not available, using CPU")

+        # Memory optimization settings
+        device_map = "auto" if torch.cuda.is_available() else None
+        gpu_mem = os.environ.get("GPU_MAX_MEM", "10GiB")  # Conservative for 12GB limit
+        cpu_mem = os.environ.get("CPU_MAX_MEM", "24GiB")
+        max_memory = {0: gpu_mem, "cpu": cpu_mem} if torch.cuda.is_available() else None
+
+        # Offload folder for memory management
+        offload_folder = os.environ.get("OFFLOAD_FOLDER", "/tmp/hf-offload")
+        try:
+            os.makedirs(offload_folder, exist_ok=True)
+        except OSError:
+            offload_folder = "/tmp/hf-offload"
+            os.makedirs(offload_folder, exist_ok=True)

+        # 8-bit quantization for memory efficiency
+        bnb_config = BitsAndBytesConfig(load_in_8bit=True)

+        # Load model with all optimizations
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            device_map=device_map,
+            quantization_config=bnb_config,
+            low_cpu_mem_usage=True,
+            offload_folder=offload_folder if device_map == "auto" else None,
+            max_memory=max_memory,
+        )

+        model.eval()

+        # Set context window
+        self.max_context = getattr(model.config, "max_position_embeddings", None) or getattr(self.tokenizer, "model_max_length", 4096)
+        if self.max_context is None or self.max_context == int(1e30):
+            self.max_context = 4096

+        # Set token IDs
+        self.pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+        self.eos_token_id = self.tokenizer.eos_token_id
+
+        logger.info("Model loaded successfully with memory optimization")
+        return model
+
+    def _build_prompt(self, data: Dict[str, Any]) -> str:
+        """Build prompt using chat template or direct input"""
+        # Accept either "messages" (chat) or "inputs"/"prompt" (single-turn)
+        if "messages" in data and isinstance(data["messages"], list):
+            return self.tokenizer.apply_chat_template(
+                data["messages"],
+                tokenize=False,
+                add_generation_prompt=True
+            )
+
+        user_text = data.get("inputs") or data.get("prompt") or ""
+        if isinstance(user_text, str):
+            messages = [{"role": "user", "content": user_text}]
+            return self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+
+        return str(user_text)
+
+    def _prepare_inputs(self, prompt: str, max_new_tokens: int, params: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+        """Prepare inputs with proper tokenization"""
+        # Keep room for generation
+        max_input_tokens = int(params.get("max_input_tokens", max(self.max_context - max_new_tokens - 8, 256)))
+
+        model_inputs = self.tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=max_input_tokens,
+        )

         if torch.cuda.is_available():
+            model_inputs = {k: v.to(self.model.device) for k, v in model_inputs.items()}

+        return model_inputs
+
+    def _stopping(self, params: Dict[str, Any]):
+        """Create stopping criteria"""
+        from transformers import StoppingCriteria, StoppingCriteriaList

+        class StopOnSequences(StoppingCriteria):
+            def __init__(self, stop_sequences: List[List[int]]):
+                super().__init__()
+                self.stop_sequences = [torch.tensor(x, dtype=torch.long) for x in stop_sequences if len(x) > 0]

+            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+                if input_ids.shape[0] == 0 or not self.stop_sequences:
+                    return False
+                generated = input_ids[0]
+                for seq in self.stop_sequences:
+                    if generated.shape[0] >= seq.shape[0] and torch.equal(generated[-seq.shape[0]:], seq.to(generated.device)):
+                        return True
+                return False

+        stop = params.get("stop", [])
+        if isinstance(stop, str):
+            stop = [stop]
+        if not isinstance(stop, list):
+            stop = []

+        stop_ids = [self.tokenizer.encode(s, add_special_tokens=False) for s in stop]
+        criteria = []
+        if stop_ids:
+            criteria.append(StopOnSequences(stop_ids))

+        return StoppingCriteriaList(criteria)

+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """Handle inference requests with proper error handling"""
         try:
+            params = data.get("parameters", {}) or {}

+            # Set seed if provided
+            seed = params.get("seed")
+            if seed is not None:
+                try:
+                    torch.manual_seed(int(seed))
+                except (ValueError, TypeError):
+                    pass

             # Generation parameters
+            max_new_tokens = int(params.get("max_new_tokens", 512))
+            temperature = float(params.get("temperature", 0.2))
+            top_p = float(params.get("top_p", 0.9))
+            top_k = int(params.get("top_k", 50))
+            repetition_penalty = float(params.get("repetition_penalty", 1.05))
+            num_beams = int(params.get("num_beams", 1))
+            do_sample = bool(params.get("do_sample", temperature > 0 and num_beams == 1))

+            # Build prompt
+            prompt = self._build_prompt(data)
+            model_inputs = self._prepare_inputs(prompt, max_new_tokens, params)
+            input_length = model_inputs["input_ids"].shape[-1]

+            # Generation kwargs
+            gen_kwargs = dict(
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=max(0.0, temperature),
+                top_p=top_p,
+                top_k=top_k,
+                repetition_penalty=repetition_penalty,
+                num_beams=num_beams,
+                eos_token_id=self.eos_token_id,
+                pad_token_id=self.pad_token_id,
+                stopping_criteria=self._stopping(params),
             )

             # Generate
             with torch.no_grad():
+                output_ids = self.model.generate(**model_inputs, **gen_kwargs)
+
+            # Slice off the prompt
+            gen_ids = output_ids[0][input_length:]
+            text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)

+            # Apply text-side stop strings if provided
+            stop = params.get("stop", [])
+            if isinstance(stop, str):
+                stop = [stop]
+            for s in stop or []:
+                idx = text.find(s)
+                if idx != -1:
+                    text = text[:idx]
+                    break

+            # Token accounting
+            prompt_tokens = int(input_length)
+            completion_tokens = int(gen_ids.shape[-1])
+            total_tokens = prompt_tokens + completion_tokens

+            return {
+                "generated_text": text,
+                "input_tokens": prompt_tokens,
+                "generated_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "params": {
+                    "max_new_tokens": max_new_tokens,
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "top_k": top_k,
+                    "repetition_penalty": repetition_penalty,
+                    "num_beams": num_beams,
+                    "do_sample": do_sample,
+                },
+            }

         except Exception as e:
             logger.error(f"Generation error: {str(e)}")
+            return {
+                "error": f"Generation failed: {str(e)}",
+                "generated_text": "",
+                "input_tokens": 0,
+                "generated_tokens": 0,
+                "total_tokens": 0
+            }
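
For reference, a minimal local smoke test of the updated handler might look like the sketch below. It is not part of the commit: the MODEL_DIR path is a placeholder, and it assumes the file above is importable as handler.py; the request and response keys mirror the new __call__.

# Hypothetical local test for the new handler (not part of this commit).
from handler import EndpointHandler

MODEL_DIR = "/repository"  # assumption: local path where the model repo is checked out
handler = EndpointHandler(path=MODEL_DIR)

# Single-turn request using "inputs"; a chat-style request could instead pass "messages".
request = {
    "inputs": "Summarize what this endpoint does.",
    "parameters": {
        "max_new_tokens": 128,
        "temperature": 0.2,
        "top_p": 0.9,
        "stop": ["<|im_end|>"],
        "seed": 42,
    },
}

response = handler(request)
print(response["generated_text"])
print(f'{response["generated_tokens"]} generated / {response["total_tokens"]} total tokens')

Note that the new handler returns a single dict rather than the list-of-dicts shape ([{"generated_text": ...}]) returned by the removed version, so clients that index response[0] would need to be updated.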