import logging
import os

import tensorflow as tf
import torch

from transformers import LlamaForCausalLM
from transformers.utils import cached_file

logger = logging.getLogger(__name__)

class SafeGenerationModel(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)

        # Resolve the Keras toxicity classifier shipped alongside the model weights.
        toxic_path = cached_file(config._name_or_path, "toxic.keras")
        if toxic_path is None or not os.path.exists(toxic_path):
            raise FileNotFoundError(f"Toxicity model not found at {toxic_path}")

        self.toxicity_model = tf.keras.models.load_model(toxic_path)
        # The tokenizer is injected later via set_tokenizer(); toxicity checks
        # in generate() are skipped until it has been set.
        self.tokenizer = None
        logger.info("Toxicity model loaded successfully")

    def is_toxic(self, text, threshold=0.6):
        try:
            # The Keras model is expected to accept raw strings and return a
            # single toxicity probability per input.
            prob = self.toxicity_model.predict([text], verbose=0)[0][0]
            return prob > threshold
        except Exception as e:
            # Fail open: if the classifier errors out, treat the text as safe.
            logger.error(f"Toxicity check failed: {e}")
            return False

    def generate(self, *args, **kwargs):
        # input_ids may arrive positionally or as a keyword argument.
        inputs = kwargs.get("input_ids")
        if inputs is None and args:
            inputs = args[0]

        # Screen the prompt before running generation.
        if self.tokenizer is not None and inputs is not None:
            input_text = self.tokenizer.decode(inputs[0], skip_special_tokens=True)
            if self.is_toxic(input_text):
                return self._safe_response()

        outputs = super().generate(*args, **kwargs)

        # Screen the generated continuation as well.
        if self.tokenizer is not None:
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if self.is_toxic(output_text):
                return self._safe_response()

        return outputs

    def _safe_response(self):
        # Return a canned refusal, encoded with the wired-in tokenizer and
        # moved to the model's device so callers still receive token ids.
        safe_text = "I'm unable to respond to that request. HAHAHA"
        return self.tokenizer.encode(safe_text, return_tensors="pt").to(self.device)

    def set_tokenizer(self, tokenizer):
        # Must be called before generate() for the toxicity checks to run.
        self.tokenizer = tokenizer
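

if __name__ == "__main__":
    # Hypothetical usage sketch, not part of the class above. The repo id is a
    # placeholder for a checkpoint that ships a "toxic.keras" file next to its
    # weights: load the model and tokenizer, wire the tokenizer in with
    # set_tokenizer(), then call generate() as usual.
    from transformers import AutoTokenizer

    model_id = "your-org/safe-llama"  # placeholder repo id
    model = SafeGenerationModel.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model.set_tokenizer(tokenizer)

    prompt = "Tell me about the weather today."
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    output_ids = model.generate(input_ids=input_ids, max_new_tokens=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))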