#!/usr/bin/env python3
"""
Interactive chat script for any model with automatic chat template support.

Usage: python chat_with_models.py <model_folder> [--assistant]
"""

import argparse
import os
import sys
import warnings

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextStreamer,
    pipeline,
)

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")


class StopSequenceCriteria(StoppingCriteria):
    def __init__(self, tokenizer, stop_sequences, prompt_length):
        self.tokenizer = tokenizer
        self.stop_sequences = stop_sequences
        self.prompt_length = prompt_length
        self.triggered_stop_sequence = None

    def __call__(self, input_ids, scores, **kwargs):
        # Only check the newly generated part (after the prompt)
        if input_ids.shape[1] <= self.prompt_length:
            return False

        # Decode only the newly generated tokens
        new_tokens = input_ids[0][self.prompt_length:]
        new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        # Check if any stop sequence appears in the newly generated text;
        # remember which one fired so the caller can strip it from the output
        for stop_seq in self.stop_sequences:
            if stop_seq in new_text:
                self.triggered_stop_sequence = stop_seq
                return True
        return False


class ModelChatter:
    def __init__(self, model_folder, force_assistant_template=False):
        self.model_folder = model_folder
        self.hf_path = os.path.join(model_folder, 'hf')
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.conversation_history = []
        self.force_assistant_template = force_assistant_template

    def load_model(self):
        """Load the model and tokenizer."""
        try:
            print(f"šŸ”„ Loading {self.model_folder}...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.hf_path)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Handle chat template assignment
            if self.force_assistant_template:
                print("šŸ“ Forcing User: Assistant: chat template...")
                custom_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for msg in messages %}{% if msg.role=='user' %}{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}User: {{ msg.content }}{% elif msg.role=='assistant' %}{{ '\\n\\nAssistant: ' }}{{ msg.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\\n\\nAssistant: ' }}{% endif %}"""
                self.tokenizer.chat_template = custom_template
                print("āœ… User: Assistant: chat template forced")
            elif not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
                print("šŸ“ No chat template found, assigning custom template...")
                custom_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for msg in messages %}{% if msg.role=='user' %}{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}Instruction: {{ msg.content }}{% elif msg.role=='assistant' %}{{ '\\n\\nAnswer:' }}{{ msg.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\\n\\nAnswer:' }}{% endif %}"""
                self.tokenizer.chat_template = custom_template
                print("āœ… Custom chat template assigned")
            else:
                print("āœ… Model has existing chat template")

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.hf_path,
                device_map=None,
                torch_dtype=torch.float16,
                trust_remote_code=True
            )

            # Move to appropriate device
            if torch.cuda.is_available():
                self.model.to("cuda:0")
                device = "cuda:0"
            elif torch.backends.mps.is_available():
                self.model.to("mps")
                device = "mps"
            else:
                self.model.to("cpu")
                device = "cpu"

            print(f" šŸ“± Using device: {device}")

            # Create pipeline. The model is already loaded and placed on a device,
            # so no device_map/dtype arguments are needed here. It is not used by
            # the streaming path below (which calls model.generate directly) but
            # is kept available on the instance.
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer
            )

            print(f" āœ… {self.model_folder} loaded successfully")
            return True

        except Exception as e:
            print(f" āŒ Failed to load {self.model_folder}: {str(e)}")
            return False

    def format_chat_prompt(self, user_message):
        """Format the conversation history and new user message using the chat template."""
        # Add the new user message to conversation history
        self.conversation_history.append({"role": "user", "content": user_message})

        # Format using the tokenizer's chat template
        try:
            formatted_prompt = self.tokenizer.apply_chat_template(
                self.conversation_history,
                tokenize=False,
                add_generation_prompt=True
            )
            return formatted_prompt
        except Exception as e:
            print(f"āŒ Error formatting chat prompt: {str(e)}")
            return None

    def generate_response(self, user_message, max_new_tokens=512):
        """Generate a response to the user message."""
        try:
            # Format the chat prompt
            formatted_prompt = self.format_chat_prompt(user_message)
            if formatted_prompt is None:
                return "āŒ Failed to format chat prompt"

            # Generate response with streaming
            print("šŸ¤– Response: ", end="", flush=True)

            # Use the model directly for streaming with TextStreamer
            inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
            if torch.cuda.is_available():
                inputs = {k: v.to("cuda:0") for k, v in inputs.items()}
            elif torch.backends.mps.is_available():
                inputs = {k: v.to("mps") for k, v in inputs.items()}

            # Create a streamer that prints tokens as they're generated
            streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

            # Define stop sequences
            stop_sequences = ["Question:", "Instruction:", "Answer:", "User:"]

            # Create stopping criteria
            prompt_length = inputs['input_ids'].shape[1]
            stopping_criteria = StopSequenceCriteria(self.tokenizer, stop_sequences, prompt_length)

            # Generate with streaming
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    streamer=streamer,
                    eos_token_id=self.tokenizer.eos_token_id,
                    stopping_criteria=StoppingCriteriaList([stopping_criteria])
                )

            # Decode the full response for conversation history
            generated_text = self.tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )

            # Strip the stop sequence if one was triggered
            if stopping_criteria.triggered_stop_sequence:
                stop_seq = stopping_criteria.triggered_stop_sequence
                original_text = generated_text
                if generated_text.endswith(stop_seq):
                    generated_text = generated_text[:-len(stop_seq)].rstrip()
                elif stop_seq in generated_text:
                    # Find the last occurrence and remove it and everything after
                    last_pos = generated_text.rfind(stop_seq)
                    if last_pos != -1:
                        generated_text = generated_text[:last_pos].rstrip()

                # Debug output (only show if text was actually modified)
                if generated_text != original_text:
                    print(f"\nšŸ” Stripped stop sequence '{stop_seq}' from response")

            # Add the assistant's response to conversation history
            self.conversation_history.append({"role": "assistant", "content": generated_text})

            # Return empty string since TextStreamer already printed the response
            return ""

        except Exception as e:
            return f"āŒ Generation failed: {str(e)}"

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation_history = []
        print("šŸ”„ Conversation history cleared!")

    def show_conversation_history(self):
        """Display the current conversation history."""
        if not self.conversation_history:
            print("šŸ“ No conversation history yet.")
            return

        print("\nšŸ“ Conversation History:")
        print("=" * 50)
        for i, message in enumerate(self.conversation_history):
            role = message["role"].capitalize()
            content = message["content"]
            print(f"{role}: {content}")
            if i < len(self.conversation_history) - 1:
                print("-" * 30)
        print("=" * 50)

    def interactive_chat(self):
        """Main interactive chat loop."""
        print(f"\nšŸ’¬ Chatting with {self.model_folder}")
        print("Commands:")
        print(" - Type your message to chat")
        print(" - Type 'quit' or 'exit' to end")
        print(" - Type 'help' for this message")
        print(" - Type 'reset' to clear conversation history")
        print(" - Type 'history' to show conversation history")
        print(" - Type 'clear' to clear screen")
        print("\nšŸ’” Start chatting! (Works with any model)")

        while True:
            try:
                user_input = input("\nšŸ‘¤ You: ").strip()

                if not user_input:
                    continue

                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("šŸ‘‹ Goodbye!")
                    break
                elif user_input.lower() == 'help':
                    print(f"\nšŸ’¬ Chatting with {self.model_folder}")
                    print("Commands:")
                    print(" - Type your message to chat")
                    print(" - Type 'quit' or 'exit' to end")
                    print(" - Type 'help' for this message")
                    print(" - Type 'reset' to clear conversation history")
                    print(" - Type 'history' to show conversation history")
                    print(" - Type 'clear' to clear screen")
                    print(" - Works with any model (auto-assigns chat template)")
                elif user_input.lower() == 'reset':
                    self.reset_conversation()
                elif user_input.lower() == 'history':
                    self.show_conversation_history()
                elif user_input.lower() == 'clear':
                    os.system('clear' if os.name == 'posix' else 'cls')
                else:
                    # Generate the response; TextStreamer prints tokens as they
                    # arrive, so generate_response only returns text on failure
                    print(f"\nšŸ¤– {self.model_folder}:")
                    response = self.generate_response(user_input)
                    if response:
                        print(response)

            except KeyboardInterrupt:
                print("\n\nšŸ‘‹ Goodbye!")
                break
            except Exception as e:
                print(f"āŒ Error: {str(e)}")


def main():
    parser = argparse.ArgumentParser(description="Interactive chat script for any model")
    parser.add_argument("model_folder", help="Name of the model folder")
    parser.add_argument("--assistant", action="store_true",
                        help="Force User: Assistant: chat template even if model has its own")
    args = parser.parse_args()

    model_folder = args.model_folder
    force_assistant_template = args.assistant

    # Check if model folder exists
    if not os.path.exists(model_folder):
        print(f"āŒ Model folder '{model_folder}' not found!")
        sys.exit(1)

    # Check if hf subdirectory exists
    hf_path = os.path.join(model_folder, 'hf')
    if not os.path.exists(hf_path):
        print(f"āŒ No 'hf' subdirectory found in '{model_folder}'!")
        sys.exit(1)

    print("šŸš€ Model Chat Script")
    print("=" * 50)
    if force_assistant_template:
        print("šŸ”§ Forcing User: Assistant: chat template")
        print("=" * 50)

    chatter = ModelChatter(model_folder, force_assistant_template)

    # Load the model (this will also handle chat template assignment if needed)
    if not chatter.load_model():
        print("āŒ Failed to load model. Exiting.")
        sys.exit(1)

    print(f"āœ… Model '{model_folder}' loaded successfully")

    # Start interactive chat
    chatter.interactive_chat()


if __name__ == "__main__":
    main()