# 33_xx_44_demo / custom_modeling.py
import logging
import os

import tensorflow as tf
from transformers import LlamaForCausalLM
from transformers.utils import cached_file

logger = logging.getLogger(__name__)

class SafeGenerationModel(LlamaForCausalLM):
    """LlamaForCausalLM that screens prompts and generations with a Keras toxicity classifier."""

    def __init__(self, config):
        super().__init__(config)
        # Resolve the toxicity classifier shipped alongside the model weights.
        # With _raise_exceptions_for_missing_entries=False, cached_file returns
        # None instead of raising, so the existence check below is reachable.
        toxic_path = cached_file(
            config._name_or_path,
            "toxic.keras",
            _raise_exceptions_for_missing_entries=False,
        )
        if toxic_path is None or not os.path.exists(toxic_path):
            raise FileNotFoundError(f"Toxicity model not found at {toxic_path}")
        self.toxicity_model = tf.keras.models.load_model(toxic_path)
        # Set later via set_tokenizer(); generate() skips screening until then.
        self.tokenizer = None
        logger.info("Toxicity model loaded successfully")
    def is_toxic(self, text, threshold=0.6):
        """Return True if the classifier scores `text` above `threshold`."""
        try:
            # The Keras model is expected to accept raw strings (e.g. via a
            # TextVectorization layer) and emit a single probability.
            prob = self.toxicity_model.predict([text], verbose=0)[0][0]
            return prob > threshold
        except Exception:
            # Fail open: if the classifier errors, allow generation to proceed.
            logger.exception("Toxicity check failed")
            return False
    def generate(self, *args, **kwargs):
        # input_ids may arrive positionally or as a keyword argument.
        inputs = kwargs.get("input_ids")
        if inputs is None and args:
            inputs = args[0]
        # Screen the prompt before generating.
        if self.tokenizer is not None and inputs is not None:
            input_text = self.tokenizer.decode(inputs[0], skip_special_tokens=True)
            if self.is_toxic(input_text):
                return self._safe_response()
        # Generate normally, then screen the output.
        outputs = super().generate(*args, **kwargs)
        if self.tokenizer is not None:
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if self.is_toxic(output_text):
                return self._safe_response()
        return outputs
    def _safe_response(self):
        # Replace a blocked generation with a fixed refusal, encoded on the model's device.
        safe_text = "I'm unable to respond to that request. HAHAHA"
        return self.tokenizer.encode(safe_text, return_tensors="pt").to(self.device)
    def set_tokenizer(self, tokenizer):
        # Must be called after loading; generate() needs it to decode and screen text.
        self.tokenizer = tokenizer
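

# Usage sketch (not part of the original file): one way to wire this class up,
# assuming the repo ships both the Llama weights and "toxic.keras". The repo id
# below is a hypothetical placeholder; when loading from the Hub you would
# typically use AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
# instead of instantiating SafeGenerationModel directly.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "Mahesh2841/33_xx_44_demo"  # hypothetical repo id for illustration
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = SafeGenerationModel.from_pretrained(repo_id)
    model.set_tokenizer(tokenizer)  # required so generate() can screen text

    prompt = "Hello there!"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    output_ids = model.generate(input_ids=input_ids, max_new_tokens=40)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))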