#!/usr/bin/env python3
"""
Interactive chat script for any model with automatic chat template support.
Usage: python chat_with_models.py <model_folder_name> [--assistant]
Use --assistant to force a "User:"/"Assistant:" chat template even if the model defines its own.
"""
import os
import sys
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, StoppingCriteria, StoppingCriteriaList
import warnings
import argparse
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")
class StopSequenceCriteria(StoppingCriteria):
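    """Stopping criterion that halts generation as soon as any configured stop string appears in the newly generated text."""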
def __init__(self, tokenizer, stop_sequences, prompt_length):
self.tokenizer = tokenizer
self.stop_sequences = stop_sequences
self.prompt_length = prompt_length
self.triggered_stop_sequence = None
def __call__(self, input_ids, scores, **kwargs):
# Only check the newly generated part (after the prompt)
if input_ids.shape[1] <= self.prompt_length:
return False
# Decode only the newly generated tokens
new_tokens = input_ids[0][self.prompt_length:]
new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
# Check if any stop sequence appears in the newly generated text
for stop_seq in self.stop_sequences:
            if stop_seq in new_text:
                # Remember which sequence fired so the caller can strip it from the output
                self.triggered_stop_sequence = stop_seq
                return True
return False
class ModelChatter:
def __init__(self, model_folder, force_assistant_template=False):
self.model_folder = model_folder
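        # The converted Hugging Face checkpoint is expected under <model_folder>/hf/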
self.hf_path = os.path.join(model_folder, 'hf')
self.model = None
self.tokenizer = None
self.conversation_history = []
self.force_assistant_template = force_assistant_template
def load_model(self):
"""Load the model and tokenizer."""
try:
print(f"🔄 Loading {self.model_folder}...")
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.hf_path)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Handle chat template assignment
if self.force_assistant_template:
print(f"📝 Forcing User: Assistant: chat template...")
custom_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for msg in messages %}{% if msg.role=='user' %}{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}User: {{ msg.content }}{% elif msg.role=='assistant' %}{{ '\\n\\nAssistant: ' }}{{ msg.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\\n\\nAssistant: ' }}{% endif %}"""
self.tokenizer.chat_template = custom_template
print(f"✅ User: Assistant: chat template forced")
elif not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
print(f"📝 No chat template found, assigning custom template...")
custom_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for msg in messages %}{% if msg.role=='user' %}{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}Instruction: {{ msg.content }}{% elif msg.role=='assistant' %}{{ '\\n\\nAnswer:' }}{{ msg.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\\n\\nAnswer:' }}{% endif %}"""
self.tokenizer.chat_template = custom_template
print(f"✅ Custom chat template assigned")
else:
print(f"✅ Model has existing chat template")
# Load model
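            # device_map=None loads the weights on CPU first; the model is moved to the best available device below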
self.model = AutoModelForCausalLM.from_pretrained(
self.hf_path,
device_map=None,
torch_dtype=torch.float16,
trust_remote_code=True
)
# Move to appropriate device
if torch.cuda.is_available():
self.model.to("cuda:0")
device = "cuda:0"
elif torch.backends.mps.is_available():
self.model.to("mps")
device = "mps"
else:
self.model.to("cpu")
device = "cpu"
print(f" 📱 Using device: {device}")
print(f" ✅ {self.model_folder} loaded successfully")
return True
except Exception as e:
print(f" ❌ Failed to load {self.model_folder}: {str(e)}")
return False
def format_chat_prompt(self, user_message):
"""Format the conversation history and new user message using the chat template."""
# Add the new user message to conversation history
self.conversation_history.append({"role": "user", "content": user_message})
# Format using the tokenizer's chat template
try:
formatted_prompt = self.tokenizer.apply_chat_template(
self.conversation_history,
tokenize=False,
add_generation_prompt=True
)
return formatted_prompt
        except Exception as e:
            print(f"❌ Error formatting chat prompt: {str(e)}")
            # Roll back the user turn so the history does not keep a message that was never answered
            self.conversation_history.pop()
            return None
    def generate_response(self, user_message, max_new_tokens=512):
"""Generate a response to the user message."""
try:
# Format the chat prompt
formatted_prompt = self.format_chat_prompt(user_message)
if formatted_prompt is None:
return "❌ Failed to format chat prompt"
# Generate response with streaming
print("🤖 Response: ", end="", flush=True)
# Use the model directly for streaming with TextStreamer
inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
if torch.cuda.is_available():
inputs = {k: v.to("cuda:0") for k, v in inputs.items()}
elif torch.backends.mps.is_available():
inputs = {k: v.to("mps") for k, v in inputs.items()}
# Create a streamer that prints tokens as they're generated
streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
# Define stop sequences
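            # These role markers indicate the model has begun a new conversational turn, so generation is cut off there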
stop_sequences = ["Question:", "Instruction:", "Answer:", "User:"]
# Create stopping criteria
prompt_length = inputs['input_ids'].shape[1]
stopping_criteria = StopSequenceCriteria(self.tokenizer, stop_sequences, prompt_length)
# Generate with streaming
with torch.no_grad():
outputs = self.model.generate(
**inputs,
                    max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.1,
pad_token_id=self.tokenizer.eos_token_id,
streamer=streamer,
eos_token_id=self.tokenizer.eos_token_id,
stopping_criteria=StoppingCriteriaList([stopping_criteria])
)
# Decode the full response for conversation history
generated_text = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
# Strip the stop sequence if one was triggered
if stopping_criteria.triggered_stop_sequence:
stop_seq = stopping_criteria.triggered_stop_sequence
original_text = generated_text
if generated_text.endswith(stop_seq):
generated_text = generated_text[:-len(stop_seq)].rstrip()
elif stop_seq in generated_text:
# Find the last occurrence and remove it and everything after
last_pos = generated_text.rfind(stop_seq)
if last_pos != -1:
generated_text = generated_text[:last_pos].rstrip()
# Debug output (only show if text was actually modified)
if generated_text != original_text:
print(f"\n🔍 Stripped stop sequence '{stop_seq}' from response")
# Add the assistant's response to conversation history
self.conversation_history.append({"role": "assistant", "content": generated_text})
# Return empty string since TextStreamer already printed the response
return ""
except Exception as e:
return f"❌ Generation failed: {str(e)}"
def reset_conversation(self):
"""Reset the conversation history."""
self.conversation_history = []
print("🔄 Conversation history cleared!")
def show_conversation_history(self):
"""Display the current conversation history."""
if not self.conversation_history:
print("📝 No conversation history yet.")
return
print("\n📝 Conversation History:")
print("=" * 50)
for i, message in enumerate(self.conversation_history):
role = message["role"].capitalize()
content = message["content"]
print(f"{role}: {content}")
if i < len(self.conversation_history) - 1:
print("-" * 30)
print("=" * 50)
    def print_help(self):
        """Print the chat banner and the list of available commands."""
        print(f"\n💬 Chatting with {self.model_folder}")
        print("Commands:")
        print(" - Type your message to chat")
        print(" - Type 'quit' or 'exit' to end")
        print(" - Type 'help' for this message")
        print(" - Type 'reset' to clear conversation history")
        print(" - Type 'history' to show conversation history")
        print(" - Type 'clear' to clear screen")
    def interactive_chat(self):
        """Main interactive chat loop."""
        self.print_help()
        print("\n💡 Start chatting! (Works with any model)")
while True:
try:
user_input = input("\n👤 You: ").strip()
if not user_input:
continue
if user_input.lower() in ['quit', 'exit', 'q']:
print("👋 Goodbye!")
break
                elif user_input.lower() == 'help':
                    self.print_help()
                    print(" - Works with any model (auto-assigns chat template)")
elif user_input.lower() == 'reset':
self.reset_conversation()
elif user_input.lower() == 'history':
self.show_conversation_history()
elif user_input.lower() == 'clear':
os.system('clear' if os.name == 'posix' else 'cls')
else:
# Generate and display response
print(f"\n🤖 {self.model_folder}:")
                    response = self.generate_response(user_input)
                    # TextStreamer already printed the reply; only print errors returned by generate_response
                    if response:
                        print(response)
except KeyboardInterrupt:
print("\n\n👋 Goodbye!")
break
except Exception as e:
print(f"❌ Error: {str(e)}")
def main():
parser = argparse.ArgumentParser(description="Interactive chat script for any model")
parser.add_argument("model_folder", help="Name of the model folder")
parser.add_argument("--assistant", action="store_true",
help="Force User: Assistant: chat template even if model has its own")
args = parser.parse_args()
model_folder = args.model_folder
force_assistant_template = args.assistant
# Check if model folder exists
if not os.path.exists(model_folder):
print(f"❌ Model folder '{model_folder}' not found!")
sys.exit(1)
# Check if hf subdirectory exists
hf_path = os.path.join(model_folder, 'hf')
if not os.path.exists(hf_path):
print(f"❌ No 'hf' subdirectory found in '{model_folder}'!")
sys.exit(1)
print("🚀 Model Chat Script")
print("=" * 50)
if force_assistant_template:
print("🔧 Forcing User: Assistant: chat template")
print("=" * 50)
chatter = ModelChatter(model_folder, force_assistant_template)
# Load the model (this will also handle chat template assignment if needed)
if not chatter.load_model():
print("❌ Failed to load model. Exiting.")
sys.exit(1)
print(f"✅ Model '{model_folder}' loaded successfully")
# Start interactive chat
chatter.interactive_chat()
if __name__ == "__main__":
main()