Update files
- chat_with_models.py +315 -0
- lm_eval.sh +11 -0
chat_with_models.py
ADDED
@@ -0,0 +1,315 @@
#!/usr/bin/env python3
"""
Interactive chat script for any model with automatic chat template support.
Usage: python chat_with_models.py <model_folder_name> [--assistant]
"""

import os
import sys
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer, StoppingCriteria, StoppingCriteriaList
import warnings
import argparse

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

class StopSequenceCriteria(StoppingCriteria):
    def __init__(self, tokenizer, stop_sequences, prompt_length):
        self.tokenizer = tokenizer
        self.stop_sequences = stop_sequences
        self.prompt_length = prompt_length
        self.triggered_stop_sequence = None

    def __call__(self, input_ids, scores, **kwargs):
        # Only check the newly generated part (after the prompt)
        if input_ids.shape[1] <= self.prompt_length:
            return False

        # Decode only the newly generated tokens
        new_tokens = input_ids[0][self.prompt_length:]
        new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        # Check if any stop sequence appears in the newly generated text
        for stop_seq in self.stop_sequences:
            if stop_seq in new_text:
                # Record which sequence fired so generate_response can strip it later
                self.triggered_stop_sequence = stop_seq
                return True
        return False

class ModelChatter:
    def __init__(self, model_folder, force_assistant_template=False):
        self.model_folder = model_folder
        self.hf_path = os.path.join(model_folder, 'hf')
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.conversation_history = []
        self.force_assistant_template = force_assistant_template

    def load_model(self):
        """Load the model and tokenizer."""
        try:
            print(f"🔄 Loading {self.model_folder}...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.hf_path)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Handle chat template assignment
            if self.force_assistant_template:
                print("📝 Forcing User:/Assistant: chat template...")
                custom_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for msg in messages %}{% if msg.role=='user' %}{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}User: {{ msg.content }}{% elif msg.role=='assistant' %}{{ '\\n\\nAssistant: ' }}{{ msg.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\\n\\nAssistant: ' }}{% endif %}"""
                self.tokenizer.chat_template = custom_template
                print("✅ User:/Assistant: chat template forced")
            elif not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
                print("📝 No chat template found, assigning custom template...")
                custom_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for msg in messages %}{% if msg.role=='user' %}{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}Instruction: {{ msg.content }}{% elif msg.role=='assistant' %}{{ '\\n\\nAnswer:' }}{{ msg.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\\n\\nAnswer:' }}{% endif %}"""
                self.tokenizer.chat_template = custom_template
                print("✅ Custom chat template assigned")
            else:
                print("✅ Model has existing chat template")

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.hf_path,
                device_map=None,
                torch_dtype=torch.float16,
                trust_remote_code=True
            )

            # Move to appropriate device
            if torch.cuda.is_available():
                self.model.to("cuda:0")
                device = "cuda:0"
            elif torch.backends.mps.is_available():
                self.model.to("mps")
                device = "mps"
            else:
                self.model.to("cpu")
                device = "cpu"

            print(f" 📱 Using device: {device}")

            # Create pipeline
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device_map="auto",
                torch_dtype=torch.float16
            )

            print(f" ✅ {self.model_folder} loaded successfully")
            return True

        except Exception as e:
            print(f" ❌ Failed to load {self.model_folder}: {str(e)}")
            return False

    def format_chat_prompt(self, user_message):
        """Format the conversation history and new user message using the chat template."""
        # Add the new user message to conversation history
        self.conversation_history.append({"role": "user", "content": user_message})

        # Format using the tokenizer's chat template
        try:
            formatted_prompt = self.tokenizer.apply_chat_template(
                self.conversation_history,
                tokenize=False,
                add_generation_prompt=True
            )
            return formatted_prompt
        except Exception as e:
            print(f"❌ Error formatting chat prompt: {str(e)}")
            return None

    def generate_response(self, user_message, max_length=512):
        """Generate a response to the user message."""
        try:
            # Format the chat prompt
            formatted_prompt = self.format_chat_prompt(user_message)
            if formatted_prompt is None:
                return "❌ Failed to format chat prompt"

            # Generate response with streaming
            print("🤖 Response: ", end="", flush=True)

            # Use the model directly for streaming with TextStreamer
            inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
            if torch.cuda.is_available():
                inputs = {k: v.to("cuda:0") for k, v in inputs.items()}
            elif torch.backends.mps.is_available():
                inputs = {k: v.to("mps") for k, v in inputs.items()}

            # Create a streamer that prints tokens as they're generated
            streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

            # Define stop sequences
            stop_sequences = ["Question:", "Instruction:", "Answer:", "User:"]

            # Create stopping criteria
            prompt_length = inputs['input_ids'].shape[1]
            stopping_criteria = StopSequenceCriteria(self.tokenizer, stop_sequences, prompt_length)

            # Generate with streaming
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    streamer=streamer,
                    eos_token_id=self.tokenizer.eos_token_id,
                    stopping_criteria=StoppingCriteriaList([stopping_criteria])
                )

            # Decode the full response for conversation history
            generated_text = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

            # Strip the stop sequence if one was triggered
            if stopping_criteria.triggered_stop_sequence:
                stop_seq = stopping_criteria.triggered_stop_sequence
                original_text = generated_text
                if generated_text.endswith(stop_seq):
                    generated_text = generated_text[:-len(stop_seq)].rstrip()
                elif stop_seq in generated_text:
                    # Find the last occurrence and remove it and everything after
                    last_pos = generated_text.rfind(stop_seq)
                    if last_pos != -1:
                        generated_text = generated_text[:last_pos].rstrip()

                # Debug output (only show if text was actually modified)
                if generated_text != original_text:
                    print(f"\n🔍 Stripped stop sequence '{stop_seq}' from response")

            # Add the assistant's response to conversation history
            self.conversation_history.append({"role": "assistant", "content": generated_text})

            # Return empty string since TextStreamer already printed the response
            return ""

        except Exception as e:
            return f"❌ Generation failed: {str(e)}"

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation_history = []
        print("🔄 Conversation history cleared!")

    def show_conversation_history(self):
        """Display the current conversation history."""
        if not self.conversation_history:
            print("📝 No conversation history yet.")
            return

        print("\n📝 Conversation History:")
        print("=" * 50)
        for i, message in enumerate(self.conversation_history):
            role = message["role"].capitalize()
            content = message["content"]
            print(f"{role}: {content}")
            if i < len(self.conversation_history) - 1:
                print("-" * 30)
        print("=" * 50)

    def interactive_chat(self):
        """Main interactive chat loop."""
        print(f"\n💬 Chatting with {self.model_folder}")
        print("Commands:")
        print(" - Type your message to chat")
        print(" - Type 'quit' or 'exit' to end")
        print(" - Type 'help' for this message")
        print(" - Type 'reset' to clear conversation history")
        print(" - Type 'history' to show conversation history")
        print(" - Type 'clear' to clear screen")
        print("\n💡 Start chatting! (Works with any model)")

        while True:
            try:
                user_input = input("\n👤 You: ").strip()

                if not user_input:
                    continue

                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("👋 Goodbye!")
                    break

                elif user_input.lower() == 'help':
                    print(f"\n💬 Chatting with {self.model_folder}")
                    print("Commands:")
                    print(" - Type your message to chat")
                    print(" - Type 'quit' or 'exit' to end")
                    print(" - Type 'help' for this message")
                    print(" - Type 'reset' to clear conversation history")
                    print(" - Type 'history' to show conversation history")
                    print(" - Type 'clear' to clear screen")
                    print(" - Works with any model (auto-assigns chat template)")

                elif user_input.lower() == 'reset':
                    self.reset_conversation()

                elif user_input.lower() == 'history':
                    self.show_conversation_history()

                elif user_input.lower() == 'clear':
                    os.system('clear' if os.name == 'posix' else 'cls')

                else:
                    # Generate and display response
                    print(f"\n🤖 {self.model_folder}:")
                    response = self.generate_response(user_input)
                    # No need to print response again - TextStreamer already handled it

            except KeyboardInterrupt:
                print("\n\n👋 Goodbye!")
                break
            except Exception as e:
                print(f"❌ Error: {str(e)}")

def main():
    parser = argparse.ArgumentParser(description="Interactive chat script for any model")
    parser.add_argument("model_folder", help="Name of the model folder")
    parser.add_argument("--assistant", action="store_true",
                        help="Force User:/Assistant: chat template even if model has its own")

    args = parser.parse_args()

    model_folder = args.model_folder
    force_assistant_template = args.assistant

    # Check if model folder exists
    if not os.path.exists(model_folder):
        print(f"❌ Model folder '{model_folder}' not found!")
        sys.exit(1)

    # Check if hf subdirectory exists
    hf_path = os.path.join(model_folder, 'hf')
    if not os.path.exists(hf_path):
        print(f"❌ No 'hf' subdirectory found in '{model_folder}'!")
        sys.exit(1)

    print("🚀 Model Chat Script")
    print("=" * 50)
    if force_assistant_template:
        print("🔧 Forcing User:/Assistant: chat template")
        print("=" * 50)

    chatter = ModelChatter(model_folder, force_assistant_template)

    # Load the model (this will also handle chat template assignment if needed)
    if not chatter.load_model():
        print("❌ Failed to load model. Exiting.")
        sys.exit(1)

    print(f"✅ Model '{model_folder}' loaded successfully")

    # Start interactive chat
    chatter.interactive_chat()

if __name__ == "__main__":
    main()
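Example invocation (a minimal sketch, not part of the commit; it assumes the ipt_fineinstructions_all_exp_chat checkpoint directory referenced in lm_eval.sh sits next to the script and contains the expected hf/ subfolder):

# use the tokenizer's own chat template, falling back to Instruction:/Answer: if none is set
python chat_with_models.py ipt_fineinstructions_all_exp_chat

# force the User:/Assistant: template regardless of the tokenizer's template
python chat_with_models.py ipt_fineinstructions_all_exp_chat --assistant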
lm_eval.sh
ADDED
@@ -0,0 +1,11 @@
#!/bin/bash

TASKS="longbench"

# Chat checkpoint, evaluated with its chat template applied and few-shot examples as multi-turn
lm_eval --model vllm --model_args pretrained=./ipt_fineinstructions_all_exp_chat/hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.7 --tasks $TASKS --device cuda:0 --apply_chat_template --batch_size auto --trust_remote_code --confirm_run_unsafe_code --fewshot_as_multiturn --output_path ./output/out.json --limit 10
# Remaining runs evaluate each checkpoint without a chat template
lm_eval --model vllm --model_args pretrained=./ipt_fineinstructions_all_exp_chat/hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.7 --tasks $TASKS --device cuda:0 --batch_size auto --trust_remote_code --confirm_run_unsafe_code --output_path ./output/out.json --limit 10
lm_eval --model vllm --model_args pretrained=./ipt_synthetic_all_exp/hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.7 --tasks $TASKS --device cuda:0 --batch_size auto --trust_remote_code --confirm_run_unsafe_code --output_path ./output/out.json --limit 10
lm_eval --model vllm --model_args pretrained=./ipt_actual_all_exp/hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.7 --tasks $TASKS --device cuda:0 --batch_size auto --trust_remote_code --confirm_run_unsafe_code --output_path ./output/out.json --limit 10
lm_eval --model vllm --model_args pretrained=./ipt_fineinstructions_all_exp/hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.7 --tasks $TASKS --device cuda:0 --batch_size auto --trust_remote_code --confirm_run_unsafe_code --output_path ./output/out.json --limit 10