TomBombadyl committed
Commit d9f2f1d · verified · 1 Parent(s): b9607cb

Update handler.py

Files changed (1): handler.py (+61 −27)
handler.py CHANGED
@@ -1,6 +1,8 @@
 from typing import Dict, List, Any
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+import json
+import os
+from transformers import Qwen2TokenizerFast, Qwen2ForCausalLM
 import logging

 # Set up logging
@@ -11,43 +13,79 @@ class EndpointHandler:
     def __init__(self, path: str = ""):
         """
         Initialize the handler for Qwen2.5-Coder-7B-Instruct-Omni1.1
-        Simple and robust implementation
+        Explicitly using Qwen2 classes to bypass auto-detection
         """
         logger.info(f"Loading model from {path}")

         try:
-            # Load tokenizer - most robust approach
-            self.tokenizer = AutoTokenizer.from_pretrained(
+            # Check if config exists and log it
+            config_path = os.path.join(path, "config.json")
+            if os.path.exists(config_path):
+                with open(config_path, 'r') as f:
+                    config = json.load(f)
+                logger.info(f"Found config with model_type: {config.get('model_type', 'MISSING')}")
+            else:
+                logger.warning("No config.json found in repository")
+
+            # Load tokenizer explicitly as Qwen2
+            logger.info("Loading tokenizer as Qwen2TokenizerFast...")
+            self.tokenizer = Qwen2TokenizerFast.from_pretrained(
                 path,
                 trust_remote_code=True,
-                use_fast=False,
                 padding_side="left"
             )

-            # Ensure we have proper tokens
+            # Ensure proper tokens
             if self.tokenizer.pad_token is None:
                 self.tokenizer.pad_token = self.tokenizer.eos_token

-            if self.tokenizer.chat_template is None:
-                # Set a basic chat template for Qwen
-                self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
+            logger.info("Tokenizer loaded successfully")

-            # Load model
-            self.model = AutoModelForCausalLM.from_pretrained(
+            # Load model explicitly as Qwen2ForCausalLM
+            logger.info("Loading model as Qwen2ForCausalLM...")
+            self.model = Qwen2ForCausalLM.from_pretrained(
                 path,
                 torch_dtype=torch.float16,
                 device_map="auto",
                 trust_remote_code=True,
-                low_cpu_mem_usage=True,
-                use_cache=True
+                low_cpu_mem_usage=True
             )

             self.model.eval()
-            logger.info("Model and tokenizer loaded successfully")
+            logger.info("Model loaded successfully")

         except Exception as e:
-            logger.error(f"Error loading model: {str(e)}")
-            raise e
+            logger.error(f"Error during initialization: {str(e)}")
+            # Try alternative loading method
+            try:
+                logger.info("Attempting alternative loading method...")
+
+                # Use the models subdirectory path that we saw in your repo
+                model_path = os.path.join(path, "models", "huggingface") if os.path.exists(os.path.join(path, "models", "huggingface")) else path
+
+                self.tokenizer = Qwen2TokenizerFast.from_pretrained(
+                    model_path,
+                    trust_remote_code=True,
+                    local_files_only=True
+                )
+
+                if self.tokenizer.pad_token is None:
+                    self.tokenizer.pad_token = self.tokenizer.eos_token
+
+                self.model = Qwen2ForCausalLM.from_pretrained(
+                    model_path,
+                    torch_dtype=torch.float16,
+                    device_map="auto",
+                    trust_remote_code=True,
+                    local_files_only=True
+                )
+
+                self.model.eval()
+                logger.info("Alternative loading successful")
+
+            except Exception as e2:
+                logger.error(f"Alternative loading also failed: {str(e2)}")
+                raise e2

     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
@@ -62,18 +100,16 @@ class EndpointHandler:
             return [{"error": "No input provided", "generated_text": ""}]

         # Generation parameters
-        max_new_tokens = parameters.get("max_new_tokens", 512)
-        temperature = parameters.get("temperature", 0.7)
-        top_p = parameters.get("top_p", 0.9)
+        max_new_tokens = min(parameters.get("max_new_tokens", 512), 1024)  # Cap at 1024
+        temperature = max(0.1, min(parameters.get("temperature", 0.7), 2.0))  # Clamp between 0.1 and 2.0
+        top_p = max(0.1, min(parameters.get("top_p", 0.9), 1.0))  # Clamp between 0.1 and 1.0
         do_sample = parameters.get("do_sample", True)
-        repetition_penalty = parameters.get("repetition_penalty", 1.1)
+        repetition_penalty = max(1.0, min(parameters.get("repetition_penalty", 1.1), 2.0))

-        # Prepare input - handle both raw text and pre-formatted chat
+        # Format input with Qwen chat template
         if inputs.startswith("<|im_start|>"):
-            # Already formatted
             formatted_input = inputs
         else:
-            # Format as chat
             formatted_input = f"<|im_start|>user\n{inputs}<|im_end|>\n<|im_start|>assistant\n"

         # Tokenize
@@ -82,7 +118,7 @@
             return_tensors="pt",
             add_special_tokens=False,
             truncation=True,
-            max_length=4096 - max_new_tokens  # Leave room for generation
+            max_length=3072  # Leave room for generation
         )

         if input_ids.size(1) == 0:
@@ -99,7 +135,7 @@
             top_p=top_p,
             do_sample=do_sample,
             repetition_penalty=repetition_penalty,
-            pad_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
             eos_token_id=self.tokenizer.eos_token_id,
             use_cache=True,
             num_return_sequences=1
@@ -115,8 +151,6 @@

         # Clean up response
         response = response.strip()
-
-        # Remove any remaining special tokens manually
         response = response.replace("<|im_end|>", "").strip()

         return [{
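
For context, a minimal local smoke test for the updated handler. This is a sketch and not part of the commit: it assumes the handler follows the usual custom-handler contract that the unchanged lines above rely on (a dict with "inputs" and optional "parameters", returning a list with a "generated_text" entry), that handler.py is importable from the working directory, and that the repository files live at MODEL_DIR (the "/repository" value below is an assumed mount point; adjust it to your environment).

# Minimal smoke test (sketch); MODEL_DIR and the import path are assumptions.
from handler import EndpointHandler

MODEL_DIR = "/repository"  # assumed location of the model repo files

handler = EndpointHandler(path=MODEL_DIR)

# Payload mirrors the keys __call__ reads; the parameter values stay within
# the caps/clamps introduced in this commit (max_new_tokens <= 1024, etc.).
payload = {
    "inputs": "Write a Python function that reverses a string.",
    "parameters": {
        "max_new_tokens": 256,      # handler caps this at 1024
        "temperature": 0.7,         # clamped to [0.1, 2.0]
        "top_p": 0.9,               # clamped to [0.1, 1.0]
        "do_sample": True,
        "repetition_penalty": 1.1,  # clamped to [1.0, 2.0]
    },
}

result = handler(payload)
print(result[0].get("generated_text", result[0].get("error", "")))

The same JSON body can be sent to the deployed endpoint over HTTP; the handler also accepts input already wrapped in <|im_start|> chat markers and passes it through unchanged.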