TomBombadyl committed
Commit 3cfb14a · verified · 1 Parent(s): d9f2f1d

Update handler.py

Files changed (1)
handler.py +89 -86
handler.py CHANGED
@@ -2,7 +2,8 @@ from typing import Dict, List, Any
 import torch
 import json
 import os
-from transformers import Qwen2TokenizerFast, Qwen2ForCausalLM
+from transformers import PreTrainedTokenizerFast, PreTrainedModel
+from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 import logging
 
 # Set up logging
@@ -12,87 +13,105 @@ logger = logging.getLogger(__name__)
 class EndpointHandler:
     def __init__(self, path: str = ""):
         """
-        Initialize the handler for Qwen2.5-Coder-7B-Instruct-Omni1.1
-        Explicitly using Qwen2 classes to bypass auto-detection
+        Initialize handler with manual model loading to bypass auto-detection
         """
         logger.info(f"Loading model from {path}")
 
         try:
-            # Check if config exists and log it
+            # Manual config loading and creation
            config_path = os.path.join(path, "config.json")
+
             if os.path.exists(config_path):
                 with open(config_path, 'r') as f:
-                    config = json.load(f)
-                logger.info(f"Found config with model_type: {config.get('model_type', 'MISSING')}")
+                    config_dict = json.load(f)
+                logger.info(f"Loaded config: {config_dict.get('model_type', 'UNKNOWN')}")
+
+                # Create Qwen2Config manually
+                config = Qwen2Config(**config_dict)
             else:
-                logger.warning("No config.json found in repository")
-
-            # Load tokenizer explicitly as Qwen2
-            logger.info("Loading tokenizer as Qwen2TokenizerFast...")
-            self.tokenizer = Qwen2TokenizerFast.from_pretrained(
-                path,
-                trust_remote_code=True,
-                padding_side="left"
-            )
+                logger.warning("No config.json found, using default Qwen2Config")
+                config = Qwen2Config()
+
+            # Load tokenizer manually without auto-detection
+            logger.info("Loading tokenizer manually...")
+            tokenizer_path = os.path.join(path, "tokenizer.json")
 
-            # Ensure proper tokens
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
+            if os.path.exists(tokenizer_path):
+                # Load tokenizer from tokenizer.json directly
+                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
+            else:
+                # Try loading from vocab files
+                vocab_path = os.path.join(path, "vocab.json")
+                merges_path = os.path.join(path, "merges.txt")
+
+                if os.path.exists(vocab_path):
+                    self.tokenizer = PreTrainedTokenizerFast(
+                        tokenizer_file=None,
+                        vocab_file=vocab_path,
+                        merges_file=merges_path if os.path.exists(merges_path) else None
+                    )
+                else:
+                    # Fallback: create basic tokenizer
+                    from transformers import AutoTokenizer
+                    logger.warning("Using fallback tokenizer loading...")
+                    self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
+
+            # Set special tokens
+            if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = "<|endoftext|>"
+                self.tokenizer.pad_token_id = 151643
+
+            if not hasattr(self.tokenizer, 'eos_token') or self.tokenizer.eos_token is None:
+                self.tokenizer.eos_token = "<|endoftext|>"
+                self.tokenizer.eos_token_id = 151643
 
             logger.info("Tokenizer loaded successfully")
 
-            # Load model explicitly as Qwen2ForCausalLM
-            logger.info("Loading model as Qwen2ForCausalLM...")
-            self.model = Qwen2ForCausalLM.from_pretrained(
-                path,
-                torch_dtype=torch.float16,
-                device_map="auto",
-                trust_remote_code=True,
-                low_cpu_mem_usage=True
-            )
+            # Load model manually with the config
+            logger.info("Loading model manually...")
+            self.model = Qwen2ForCausalLM(config)
 
-            self.model.eval()
-            logger.info("Model loaded successfully")
+            # Load state dict manually
+            safetensors_files = [f for f in os.listdir(path) if f.endswith('.safetensors')]
 
-        except Exception as e:
-            logger.error(f"Error during initialization: {str(e)}")
-            # Try alternative loading method
-            try:
-                logger.info("Attempting alternative loading method...")
-
-                # Use the models subdirectory path that we saw in your repo
-                model_path = os.path.join(path, "models", "huggingface") if os.path.exists(os.path.join(path, "models", "huggingface")) else path
-
-                self.tokenizer = Qwen2TokenizerFast.from_pretrained(
-                    model_path,
-                    trust_remote_code=True,
-                    local_files_only=True
-                )
+            if safetensors_files:
+                logger.info(f"Loading weights from {len(safetensors_files)} safetensors files")
+                from safetensors.torch import load_file
 
-                if self.tokenizer.pad_token is None:
-                    self.tokenizer.pad_token = self.tokenizer.eos_token
+                state_dict = {}
+                for file in sorted(safetensors_files):
+                    file_path = os.path.join(path, file)
+                    partial_state_dict = load_file(file_path)
+                    state_dict.update(partial_state_dict)
 
-                self.model = Qwen2ForCausalLM.from_pretrained(
-                    model_path,
-                    torch_dtype=torch.float16,
-                    device_map="auto",
-                    trust_remote_code=True,
-                    local_files_only=True
-                )
-
-                self.model.eval()
-                logger.info("Alternative loading successful")
+                # Load the state dict
+                missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
 
-            except Exception as e2:
-                logger.error(f"Alternative loading also failed: {str(e2)}")
-                raise e2
+                if missing_keys:
+                    logger.warning(f"Missing keys: {missing_keys[:5]}...")  # Show first 5
+                if unexpected_keys:
+                    logger.warning(f"Unexpected keys: {unexpected_keys[:5]}...")  # Show first 5
+            else:
+                logger.error("No safetensors files found!")
+                raise FileNotFoundError("No model weights found")
+
+            # Move to GPU and set to eval mode
+            self.model = self.model.half()  # Convert to float16
+            if torch.cuda.is_available():
+                self.model = self.model.cuda()
+            self.model.eval()
+
+            logger.info("Model loaded successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to load model: {str(e)}")
+            raise e
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
         Handle inference requests
         """
         try:
-            # Extract inputs and parameters
             inputs = data.get("inputs", "")
             parameters = data.get("parameters", {})
 
@@ -100,13 +119,12 @@ class EndpointHandler:
                 return [{"error": "No input provided", "generated_text": ""}]
 
             # Generation parameters
-            max_new_tokens = min(parameters.get("max_new_tokens", 512), 1024)  # Cap at 1024
-            temperature = max(0.1, min(parameters.get("temperature", 0.7), 2.0))  # Clamp between 0.1 and 2.0
-            top_p = max(0.1, min(parameters.get("top_p", 0.9), 1.0))  # Clamp between 0.1 and 1.0
+            max_new_tokens = min(parameters.get("max_new_tokens", 512), 1024)
+            temperature = max(0.1, min(parameters.get("temperature", 0.7), 2.0))
+            top_p = max(0.1, min(parameters.get("top_p", 0.9), 1.0))
             do_sample = parameters.get("do_sample", True)
-            repetition_penalty = max(1.0, min(parameters.get("repetition_penalty", 1.1), 2.0))
 
-            # Format input with Qwen chat template
+            # Format input
             if inputs.startswith("<|im_start|>"):
                 formatted_input = inputs
             else:
@@ -116,15 +134,12 @@
             input_ids = self.tokenizer.encode(
                 formatted_input,
                 return_tensors="pt",
-                add_special_tokens=False,
                 truncation=True,
-                max_length=3072  # Leave room for generation
+                max_length=3072
             )
 
-            if input_ids.size(1) == 0:
-                return [{"error": "Input tokenization failed", "generated_text": ""}]
-
-            input_ids = input_ids.to(self.model.device)
+            if torch.cuda.is_available():
+                input_ids = input_ids.cuda()
 
             # Generate
             with torch.no_grad():
@@ -134,23 +149,14 @@
                     temperature=temperature,
                     top_p=top_p,
                     do_sample=do_sample,
-                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id,
-                    use_cache=True,
-                    num_return_sequences=1
+                    use_cache=True
                 )
 
-            # Decode response (only new tokens)
+            # Decode response
             generated_ids = outputs[0][input_ids.size(1):]
-            response = self.tokenizer.decode(
-                generated_ids,
-                skip_special_tokens=True,
-                clean_up_tokenization_spaces=True
-            )
-
-            # Clean up response
-            response = response.strip()
+            response = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
             response = response.replace("<|im_end|>", "").strip()
 
             return [{
@@ -161,7 +167,4 @@ class EndpointHandler:
 
         except Exception as e:
             logger.error(f"Generation error: {str(e)}")
-            return [{
-                "error": f"Generation failed: {str(e)}",
-                "generated_text": ""
-            }]
+            return [{"error": f"Generation failed: {str(e)}", "generated_text": ""}]
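
For reference, a minimal sketch of how a custom handler like this can be exercised locally before deploying it to an endpoint. The ./repo path and the prompt are assumptions, not part of the commit; the payload shape ({"inputs": ..., "parameters": {...}}) mirrors what __call__ reads above.

# Local smoke test (illustrative sketch, not part of this commit).
# Assumes the model repository (config.json, tokenizer files, *.safetensors)
# has been downloaded to ./repo; the prompt text is an arbitrary example.
from handler import EndpointHandler

handler = EndpointHandler(path="./repo")
result = handler({
    "inputs": "Write a Python function that reverses a string.",
    "parameters": {"max_new_tokens": 256, "temperature": 0.7, "top_p": 0.9},
})
print(result[0])  # one dict per request: generated text, or an "error" key on failure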