AjayP13 committed
Commit f69ab14 · 1 Parent(s): dbe1cc1

Update files

Files changed (2)
  1. chat_with_models.py +315 -0
  2. lm_eval.sh +11 -0
chat_with_models.py ADDED
@@ -0,0 +1,315 @@
+ #!/usr/bin/env python3
+ """
+ Interactive chat script for any model with automatic chat template support.
+ Usage: python chat_with_models.py <model_folder_name> [--assistant]
+ """
+
+ import os
+ import sys
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer, StoppingCriteria, StoppingCriteriaList
+ import warnings
+ import argparse
+
+ # Suppress warnings for cleaner output
+ warnings.filterwarnings("ignore")
+
+ class StopSequenceCriteria(StoppingCriteria):
+     def __init__(self, tokenizer, stop_sequences, prompt_length):
+         self.tokenizer = tokenizer
+         self.stop_sequences = stop_sequences
+         self.prompt_length = prompt_length
+         self.triggered_stop_sequence = None
+
+     def __call__(self, input_ids, scores, **kwargs):
+         # Only check the newly generated part (after the prompt)
+         if input_ids.shape[1] <= self.prompt_length:
+             return False
+
+         # Decode only the newly generated tokens
+         new_tokens = input_ids[0][self.prompt_length:]
+         new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+         # Check if any stop sequence appears in the newly generated text
+         for stop_seq in self.stop_sequences:
+             if stop_seq in new_text:
+                 # Record which stop sequence fired so it can be stripped from the final response
+                 self.triggered_stop_sequence = stop_seq
+                 return True
+         return False
+
+ class ModelChatter:
+     def __init__(self, model_folder, force_assistant_template=False):
+         self.model_folder = model_folder
+         self.hf_path = os.path.join(model_folder, 'hf')
+         self.model = None
+         self.tokenizer = None
+         self.pipeline = None
+         self.conversation_history = []
+         self.force_assistant_template = force_assistant_template
+
+     def load_model(self):
+         """Load the model and tokenizer."""
+         try:
+             print(f"🔄 Loading {self.model_folder}...")
+
+             # Load tokenizer
+             self.tokenizer = AutoTokenizer.from_pretrained(self.hf_path)
+             if self.tokenizer.pad_token is None:
+                 self.tokenizer.pad_token = self.tokenizer.eos_token
+
+             # Handle chat template assignment
+             if self.force_assistant_template:
+                 print(f"📝 Forcing User: Assistant: chat template...")
+                 custom_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for msg in messages %}{% if msg.role=='user' %}{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}User: {{ msg.content }}{% elif msg.role=='assistant' %}{{ '\\n\\nAssistant: ' }}{{ msg.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\\n\\nAssistant: ' }}{% endif %}"""
+                 self.tokenizer.chat_template = custom_template
+                 print(f"✅ User: Assistant: chat template forced")
+             elif not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
+                 print(f"📝 No chat template found, assigning custom template...")
+                 custom_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for msg in messages %}{% if msg.role=='user' %}{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}Instruction: {{ msg.content }}{% elif msg.role=='assistant' %}{{ '\\n\\nAnswer:' }}{{ msg.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\\n\\nAnswer:' }}{% endif %}"""
+                 self.tokenizer.chat_template = custom_template
+                 print(f"✅ Custom chat template assigned")
+             else:
+                 print(f"✅ Model has existing chat template")
+
+             # Load model
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 self.hf_path,
+                 device_map=None,
+                 torch_dtype=torch.float16,
+                 trust_remote_code=True
+             )
+
+             # Move to appropriate device
+             if torch.cuda.is_available():
+                 self.model.to("cuda:0")
+                 device = "cuda:0"
+             elif torch.backends.mps.is_available():
+                 self.model.to("mps")
+                 device = "mps"
+             else:
+                 self.model.to("cpu")
+                 device = "cpu"
+
+             print(f" 📱 Using device: {device}")
+
+             # Create pipeline
+             self.pipeline = pipeline(
+                 "text-generation",
+                 model=self.model,
+                 tokenizer=self.tokenizer,
+                 device_map="auto",
+                 torch_dtype=torch.float16
+             )
+
+             print(f" ✅ {self.model_folder} loaded successfully")
+             return True
+
+         except Exception as e:
+             print(f" ❌ Failed to load {self.model_folder}: {str(e)}")
+             return False
+
+     def format_chat_prompt(self, user_message):
+         """Format the conversation history and new user message using the chat template."""
+         # Add the new user message to conversation history
+         self.conversation_history.append({"role": "user", "content": user_message})
+
+         # Format using the tokenizer's chat template
+         try:
+             formatted_prompt = self.tokenizer.apply_chat_template(
+                 self.conversation_history,
+                 tokenize=False,
+                 add_generation_prompt=True
+             )
+             return formatted_prompt
+         except Exception as e:
+             print(f"❌ Error formatting chat prompt: {str(e)}")
+             return None
+
+     def generate_response(self, user_message, max_length=512):
+         """Generate a response to the user message."""
+         try:
+             # Format the chat prompt
+             formatted_prompt = self.format_chat_prompt(user_message)
+             if formatted_prompt is None:
+                 return "❌ Failed to format chat prompt"
+
+             # Generate response with streaming
+             print("🤖 Response: ", end="", flush=True)
+
+             # Use the model directly for streaming with TextStreamer
+             inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
+             if torch.cuda.is_available():
+                 inputs = {k: v.to("cuda:0") for k, v in inputs.items()}
+             elif torch.backends.mps.is_available():
+                 inputs = {k: v.to("mps") for k, v in inputs.items()}
+
+             # Create a streamer that prints tokens as they're generated
+             streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+             # Define stop sequences
+             stop_sequences = ["Question:", "Instruction:", "Answer:", "User:"]
+
+             # Create stopping criteria
+             prompt_length = inputs['input_ids'].shape[1]
+             stopping_criteria = StopSequenceCriteria(self.tokenizer, stop_sequences, prompt_length)
+
+             # Generate with streaming
+             with torch.no_grad():
+                 outputs = self.model.generate(
+                     **inputs,
+                     max_new_tokens=max_length,
+                     do_sample=True,
+                     temperature=0.7,
+                     top_p=0.9,
+                     repetition_penalty=1.1,
+                     pad_token_id=self.tokenizer.eos_token_id,
+                     streamer=streamer,
+                     eos_token_id=self.tokenizer.eos_token_id,
+                     stopping_criteria=StoppingCriteriaList([stopping_criteria])
+                 )
+
+             # Decode the full response for conversation history
+             generated_text = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+
+             # Strip the stop sequence if one was triggered
+             if stopping_criteria.triggered_stop_sequence:
+                 stop_seq = stopping_criteria.triggered_stop_sequence
+                 original_text = generated_text
+                 if generated_text.endswith(stop_seq):
+                     generated_text = generated_text[:-len(stop_seq)].rstrip()
+                 elif stop_seq in generated_text:
+                     # Find the last occurrence and remove it and everything after
+                     last_pos = generated_text.rfind(stop_seq)
+                     if last_pos != -1:
+                         generated_text = generated_text[:last_pos].rstrip()
+
+                 # Debug output (only show if text was actually modified)
+                 if generated_text != original_text:
+                     print(f"\n🔍 Stripped stop sequence '{stop_seq}' from response")
+
+             # Add the assistant's response to conversation history
+             self.conversation_history.append({"role": "assistant", "content": generated_text})
+
+             # Return empty string since TextStreamer already printed the response
+             return ""
+
+         except Exception as e:
+             return f"❌ Generation failed: {str(e)}"
+
+     def reset_conversation(self):
+         """Reset the conversation history."""
+         self.conversation_history = []
+         print("🔄 Conversation history cleared!")
+
+     def show_conversation_history(self):
+         """Display the current conversation history."""
+         if not self.conversation_history:
+             print("📝 No conversation history yet.")
+             return
+
+         print("\n📝 Conversation History:")
+         print("=" * 50)
+         for i, message in enumerate(self.conversation_history):
+             role = message["role"].capitalize()
+             content = message["content"]
+             print(f"{role}: {content}")
+             if i < len(self.conversation_history) - 1:
+                 print("-" * 30)
+         print("=" * 50)
+
+     def interactive_chat(self):
+         """Main interactive chat loop."""
+         print(f"\n💬 Chatting with {self.model_folder}")
+         print("Commands:")
+         print(" - Type your message to chat")
+         print(" - Type 'quit' or 'exit' to end")
+         print(" - Type 'help' for this message")
+         print(" - Type 'reset' to clear conversation history")
+         print(" - Type 'history' to show conversation history")
+         print(" - Type 'clear' to clear screen")
+         print("\n💡 Start chatting! (Works with any model)")
+
+         while True:
+             try:
+                 user_input = input("\n👤 You: ").strip()
+
+                 if not user_input:
+                     continue
+
+                 if user_input.lower() in ['quit', 'exit', 'q']:
+                     print("👋 Goodbye!")
+                     break
+
+                 elif user_input.lower() == 'help':
+                     print(f"\n💬 Chatting with {self.model_folder}")
+                     print("Commands:")
+                     print(" - Type your message to chat")
+                     print(" - Type 'quit' or 'exit' to end")
+                     print(" - Type 'help' for this message")
+                     print(" - Type 'reset' to clear conversation history")
+                     print(" - Type 'history' to show conversation history")
+                     print(" - Type 'clear' to clear screen")
+                     print(" - Works with any model (auto-assigns chat template)")
+
+                 elif user_input.lower() == 'reset':
+                     self.reset_conversation()
+
+                 elif user_input.lower() == 'history':
+                     self.show_conversation_history()
+
+                 elif user_input.lower() == 'clear':
+                     os.system('clear' if os.name == 'posix' else 'cls')
+
+                 else:
+                     # Generate and display response
+                     print(f"\n🤖 {self.model_folder}:")
+                     response = self.generate_response(user_input)
+                     # No need to print response again - TextStreamer already handled it
+
+             except KeyboardInterrupt:
+                 print("\n\n👋 Goodbye!")
+                 break
+             except Exception as e:
+                 print(f"❌ Error: {str(e)}")
+
+ def main():
+     parser = argparse.ArgumentParser(description="Interactive chat script for any model")
+     parser.add_argument("model_folder", help="Name of the model folder")
+     parser.add_argument("--assistant", action="store_true",
+                         help="Force User: Assistant: chat template even if model has its own")
+
+     args = parser.parse_args()
+
+     model_folder = args.model_folder
+     force_assistant_template = args.assistant
+
+     # Check if model folder exists
+     if not os.path.exists(model_folder):
+         print(f"❌ Model folder '{model_folder}' not found!")
+         sys.exit(1)
+
+     # Check if hf subdirectory exists
+     hf_path = os.path.join(model_folder, 'hf')
+     if not os.path.exists(hf_path):
+         print(f"❌ No 'hf' subdirectory found in '{model_folder}'!")
+         sys.exit(1)
+
+     print("🚀 Model Chat Script")
+     print("=" * 50)
+     if force_assistant_template:
+         print("🔧 Forcing User: Assistant: chat template")
+         print("=" * 50)
+
+     chatter = ModelChatter(model_folder, force_assistant_template)
+
+     # Load the model (this will also handle chat template assignment if needed)
+     if not chatter.load_model():
+         print("❌ Failed to load model. Exiting.")
+         sys.exit(1)
+
+     print(f"✅ Model '{model_folder}' loaded successfully")
+
+     # Start interactive chat
+     chatter.interactive_chat()
+
+ if __name__ == "__main__":
+     main()
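
For reference, here is a minimal sketch (not part of this commit) of what the fallback Instruction/Answer template assigned in load_model() renders for a short conversation. It uses jinja2 directly instead of a loaded tokenizer; since the template is plain Jinja, apply_chat_template should produce the same string. The FALLBACK_TEMPLATE constant and the example messages are illustrative only.

# Sketch: render the fallback "Instruction/Answer" chat template with jinja2 directly.
# The messages below are hypothetical; the template text is copied from load_model() above.
from jinja2 import Template

FALLBACK_TEMPLATE = (
    "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}"
    "{% for msg in messages %}{% if msg.role=='user' %}"
    "{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}Instruction: {{ msg.content }}"
    "{% elif msg.role=='assistant' %}{{ '\\n\\nAnswer:' }}{{ msg.content }}{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}{{ '\\n\\nAnswer:' }}{% endif %}"
)

messages = [
    {"role": "user", "content": "Name a prime number."},
    {"role": "assistant", "content": " 7 is prime."},
    {"role": "user", "content": "Name another one."},
]

print(Template(FALLBACK_TEMPLATE).render(messages=messages, add_generation_prompt=True))
# Expected output:
# Instruction: Name a prime number.
#
# Answer: 7 is prime.
#
# Instruction: Name another one.
#
# Answer:

Note that the rendered prompt ends with "Answer:", which is also one of the stop sequences in generate_response; the stop check only inspects tokens generated after the prompt, so the prompt itself never triggers it.
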
lm_eval.sh ADDED
@@ -0,0 +1,11 @@
+ #!/bin/bash
+
+ TASKS="longbench"
+
+ lm_eval --model vllm --model_args pretrained=./ipt_fineinstructions_all_exp_chat/hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.7 --tasks $TASKS --device cuda:0 --apply_chat_template --batch_size auto --trust_remote_code --confirm_run_unsafe_code --fewshot_as_multiturn --output_path ./output/out.json --limit 10
+ lm_eval --model vllm --model_args pretrained=./ipt_fineinstructions_all_exp_chat/hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.7 --tasks $TASKS --device cuda:0 --batch_size auto --trust_remote_code --confirm_run_unsafe_code --output_path ./output/out.json --limit 10
+ lm_eval --model vllm --model_args pretrained=./ipt_synthetic_all_exp/hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.7 --tasks $TASKS --device cuda:0 --batch_size auto --trust_remote_code --confirm_run_unsafe_code --output_path ./output/out.json --limit 10
+ lm_eval --model vllm --model_args pretrained=./ipt_actual_all_exp/hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.7 --tasks $TASKS --device cuda:0 --batch_size auto --trust_remote_code --confirm_run_unsafe_code --output_path ./output/out.json --limit 10
+ lm_eval --model vllm --model_args pretrained=./ipt_fineinstructions_all_exp/hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.7 --tasks $TASKS --device cuda:0 --batch_size auto --trust_remote_code --confirm_run_unsafe_code --output_path ./output/out.json --limit 10
+
+