"""
Interactive chat script for any model with automatic chat template support.

Usage: python chat_with_models.py <model_folder_name> [--assistant]
"""

import os
import sys
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer, StoppingCriteria, StoppingCriteriaList
import warnings
import argparse

warnings.filterwarnings("ignore")

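# Stopping criteria that halts generation as soon as any of the given stop
# sequences appears in the newly generated (post-prompt) text.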
class StopSequenceCriteria(StoppingCriteria):
    def __init__(self, tokenizer, stop_sequences, prompt_length):
        self.tokenizer = tokenizer
        self.stop_sequences = stop_sequences
        self.prompt_length = prompt_length
        self.triggered_stop_sequence = None

    def __call__(self, input_ids, scores, **kwargs):
        if input_ids.shape[1] <= self.prompt_length:
            return False

        # Only inspect tokens generated after the prompt.
        new_tokens = input_ids[0][self.prompt_length:]
        new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        for stop_seq in self.stop_sequences:
            if stop_seq in new_text:
                # Record which sequence fired so the caller can strip it from the output.
                self.triggered_stop_sequence = stop_seq
                return True
        return False


class ModelChatter:
    def __init__(self, model_folder, force_assistant_template=False):
        self.model_folder = model_folder
        self.hf_path = os.path.join(model_folder, 'hf')
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.conversation_history = []
        self.force_assistant_template = force_assistant_template

    def load_model(self):
        """Load the model and tokenizer."""
        try:
            print(f"🔄 Loading {self.model_folder}...")

            self.tokenizer = AutoTokenizer.from_pretrained(self.hf_path)
            if self.tokenizer.pad_token is None:
                # Many causal LMs ship without a pad token; reuse EOS for padding.
                self.tokenizer.pad_token = self.tokenizer.eos_token

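            # Jinja chat templates rendered as plain-text turns: the forced template
            # uses "User:" / "Assistant:" markers, while the fallback (applied only
            # when the tokenizer ships no template) uses "Instruction:" / "Answer:".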
            if self.force_assistant_template:
                print("📝 Forcing User: Assistant: chat template...")
                custom_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for msg in messages %}{% if msg.role=='user' %}{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}User: {{ msg.content }}{% elif msg.role=='assistant' %}{{ '\\n\\nAssistant: ' }}{{ msg.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\\n\\nAssistant: ' }}{% endif %}"""
                self.tokenizer.chat_template = custom_template
                print("✅ User: Assistant: chat template forced")
            elif not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
                print("📝 No chat template found, assigning custom template...")
                custom_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for msg in messages %}{% if msg.role=='user' %}{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}Instruction: {{ msg.content }}{% elif msg.role=='assistant' %}{{ '\\n\\nAnswer:' }}{{ msg.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\\n\\nAnswer:' }}{% endif %}"""
                self.tokenizer.chat_template = custom_template
                print("✅ Custom chat template assigned")
            else:
                print("✅ Model has existing chat template")

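            # Load weights in float16 and move them to the best available backend
            # (CUDA, then Apple MPS, then CPU).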
            self.model = AutoModelForCausalLM.from_pretrained(
                self.hf_path,
                device_map=None,
                torch_dtype=torch.float16,
                trust_remote_code=True
            )

            if torch.cuda.is_available():
                self.model.to("cuda:0")
                device = "cuda:0"
            elif torch.backends.mps.is_available():
                self.model.to("mps")
                device = "mps"
            else:
                self.model.to("cpu")
                device = "cpu"

            print(f" 📱 Using device: {device}")

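            # A text-generation pipeline is kept around for convenience; the chat
            # loop below calls self.model.generate() directly for streaming output.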
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device_map="auto",
                torch_dtype=torch.float16
            )

            print(f" ✅ {self.model_folder} loaded successfully")
            return True

        except Exception as e:
            print(f" ❌ Failed to load {self.model_folder}: {str(e)}")
            return False

    def format_chat_prompt(self, user_message):
        """Format the conversation history and new user message using the chat template."""
        self.conversation_history.append({"role": "user", "content": user_message})

        try:
            formatted_prompt = self.tokenizer.apply_chat_template(
                self.conversation_history,
                tokenize=False,
                add_generation_prompt=True
            )
            return formatted_prompt
        except Exception as e:
            print(f"❌ Error formatting chat prompt: {str(e)}")
            # Don't leave the failed turn in the history.
            self.conversation_history.pop()
            return None

    def generate_response(self, user_message, max_length=512):
        """Generate a response to the user message."""
        try:
            formatted_prompt = self.format_chat_prompt(user_message)
            if formatted_prompt is None:
                return "❌ Failed to format chat prompt"

            print("🤖 Response: ", end="", flush=True)

            inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
            if torch.cuda.is_available():
                inputs = {k: v.to("cuda:0") for k, v in inputs.items()}
            elif torch.backends.mps.is_available():
                inputs = {k: v.to("mps") for k, v in inputs.items()}

            # Stream tokens to stdout as they are generated, skipping the prompt.
            streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

            # Role markers that signal the model has started a new turn on its own.
            stop_sequences = ["Question:", "Instruction:", "Answer:", "User:"]

            prompt_length = inputs['input_ids'].shape[1]
            stopping_criteria = StopSequenceCriteria(self.tokenizer, stop_sequences, prompt_length)

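            # Sample with moderate temperature/top-p and a mild repetition penalty;
            # generation ends at EOS, after max_length new tokens, or when a stop
            # sequence fires.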
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    streamer=streamer,
                    eos_token_id=self.tokenizer.eos_token_id,
                    stopping_criteria=StoppingCriteriaList([stopping_criteria])
                )

            generated_text = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

            # If a stop sequence ended the generation, trim it from the stored reply.
            if stopping_criteria.triggered_stop_sequence:
                stop_seq = stopping_criteria.triggered_stop_sequence
                original_text = generated_text
                if generated_text.endswith(stop_seq):
                    generated_text = generated_text[:-len(stop_seq)].rstrip()
                elif stop_seq in generated_text:
                    last_pos = generated_text.rfind(stop_seq)
                    if last_pos != -1:
                        generated_text = generated_text[:last_pos].rstrip()

                if generated_text != original_text:
                    print(f"\n🔍 Stripped stop sequence '{stop_seq}' from response")

            self.conversation_history.append({"role": "assistant", "content": generated_text})

            # The streamer has already printed the reply, so there is nothing left to return.
            return ""

        except Exception as e:
            return f"❌ Generation failed: {str(e)}"

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation_history = []
        print("🔄 Conversation history cleared!")

    def show_conversation_history(self):
        """Display the current conversation history."""
        if not self.conversation_history:
            print("📝 No conversation history yet.")
            return

        print("\n📝 Conversation History:")
        print("=" * 50)
        for i, message in enumerate(self.conversation_history):
            role = message["role"].capitalize()
            content = message["content"]
            print(f"{role}: {content}")
            if i < len(self.conversation_history) - 1:
                print("-" * 30)
        print("=" * 50)

    def interactive_chat(self):
        """Main interactive chat loop."""
        print(f"\n💬 Chatting with {self.model_folder}")
        print("Commands:")
        print(" - Type your message to chat")
        print(" - Type 'quit' or 'exit' to end")
        print(" - Type 'help' for this message")
        print(" - Type 'reset' to clear conversation history")
        print(" - Type 'history' to show conversation history")
        print(" - Type 'clear' to clear screen")
        print("\n💡 Start chatting! (Works with any model)")

        while True:
            try:
                user_input = input("\n👤 You: ").strip()

                if not user_input:
                    continue

                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("👋 Goodbye!")
                    break

                elif user_input.lower() == 'help':
                    print(f"\n💬 Chatting with {self.model_folder}")
                    print("Commands:")
                    print(" - Type your message to chat")
                    print(" - Type 'quit' or 'exit' to end")
                    print(" - Type 'help' for this message")
                    print(" - Type 'reset' to clear conversation history")
                    print(" - Type 'history' to show conversation history")
                    print(" - Type 'clear' to clear screen")
                    print(" - Works with any model (auto-assigns chat template)")

                elif user_input.lower() == 'reset':
                    self.reset_conversation()

                elif user_input.lower() == 'history':
                    self.show_conversation_history()

                elif user_input.lower() == 'clear':
                    os.system('clear' if os.name == 'posix' else 'cls')

                else:
                    print(f"\n🤖 {self.model_folder}:")
                    response = self.generate_response(user_input)
                    # Successful replies are streamed while generating; a non-empty
                    # return value here is an error message, so show it.
                    if response:
                        print(response)

            except KeyboardInterrupt:
                print("\n\n👋 Goodbye!")
                break
            except Exception as e:
                print(f"❌ Error: {str(e)}")


def main():
    parser = argparse.ArgumentParser(description="Interactive chat script for any model")
    parser.add_argument("model_folder", help="Name of the model folder")
    parser.add_argument("--assistant", action="store_true",
                        help="Force User: Assistant: chat template even if model has its own")

    args = parser.parse_args()

    model_folder = args.model_folder
    force_assistant_template = args.assistant

    if not os.path.exists(model_folder):
        print(f"❌ Model folder '{model_folder}' not found!")
        sys.exit(1)

    hf_path = os.path.join(model_folder, 'hf')
    if not os.path.exists(hf_path):
        print(f"❌ No 'hf' subdirectory found in '{model_folder}'!")
        sys.exit(1)

    print("🚀 Model Chat Script")
    print("=" * 50)
    if force_assistant_template:
        print("🔧 Forcing User: Assistant: chat template")
        print("=" * 50)

    chatter = ModelChatter(model_folder, force_assistant_template)

    if not chatter.load_model():
        print("❌ Failed to load model. Exiting.")
        sys.exit(1)

    print(f"✅ Model '{model_folder}' loaded successfully")

    chatter.interactive_chat()


if __name__ == "__main__":
    main()