import os
import logging

import tensorflow as tf
from transformers import LlamaForCausalLM
from transformers.utils import cached_file

logger = logging.getLogger(__name__)


class SafeGenerationModel(LlamaForCausalLM):
    """LlamaForCausalLM wrapper that screens prompts and completions
    with a bundled Keras toxicity classifier before returning them."""

    def __init__(self, config):
        super().__init__(config)
        # Resolve the toxicity classifier shipped alongside the checkpoint.
        toxic_path = cached_file(config._name_or_path, "toxic.keras")
        if toxic_path is None or not os.path.exists(toxic_path):
            raise FileNotFoundError(f"Toxicity model not found at {toxic_path}")
        self.toxicity_model = tf.keras.models.load_model(toxic_path)
        self.tokenizer = None
        logger.info("Toxicity model loaded successfully")

    def is_toxic(self, text, threshold=0.6):
        # The classifier is expected to accept raw strings (e.g. via a
        # TextVectorization layer) and emit a single toxicity probability.
        try:
            prob = self.toxicity_model.predict([text], verbose=0)[0][0]
            return prob > threshold
        except Exception as e:
            # Fail open: if the classifier errors, do not block generation.
            logger.error(f"Toxicity check failed: {e}")
            return False

    def generate(self, *args, **kwargs):
        # `input_ids` may arrive positionally or as a keyword argument.
        input_ids = kwargs.get("input_ids")
        if input_ids is None and args:
            input_ids = args[0]

        # Screen the prompt before generating.
        if self.tokenizer is not None and input_ids is not None:
            input_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            if self.is_toxic(input_text):
                return self._safe_response()

        outputs = super().generate(*args, **kwargs)

        # Screen the completion before returning it.
        if self.tokenizer is not None:
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if self.is_toxic(output_text):
                return self._safe_response()
        return outputs

    def _safe_response(self):
        # Canned refusal, encoded so callers still receive token ids.
        safe_text = "I'm unable to respond to that request."
        return self.tokenizer.encode(safe_text, return_tensors="pt").to(self.device)

    def set_tokenizer(self, tokenizer):
        # The tokenizer is needed to decode ids for the toxicity checks.
        self.tokenizer = tokenizer
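

if __name__ == "__main__":
    # Minimal usage sketch. Assumes a checkpoint repository that ships a
    # "toxic.keras" classifier alongside the LLaMA weights; the model id
    # below is a hypothetical placeholder, not a real repository.
    from transformers import AutoTokenizer

    model_id = "your-org/llama-with-toxicity-head"  # hypothetical repo id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = SafeGenerationModel.from_pretrained(model_id)

    # The tokenizer must be attached before generate() can run its checks;
    # without it, generation falls through to the unfiltered parent method.
    model.set_tokenizer(tokenizer)

    prompt = "Write a short poem about the sea."
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    output_ids = model.generate(input_ids=input_ids, max_new_tokens=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))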