import os
import torch
import tensorflow as tf
from transformers import LlamaForCausalLM
from transformers.utils import cached_file
import logging

logger = logging.getLogger(__name__)

class SafeGenerationModel(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        
        # Resolve the Keras toxicity classifier shipped alongside the model
        # weights (downloads from the Hub or resolves a local path).
        toxic_path = cached_file(config._name_or_path, "toxic.keras")
        if toxic_path is None or not os.path.exists(toxic_path):
            raise FileNotFoundError(f"Toxicity model not found at {toxic_path}")

        self.toxicity_model = tf.keras.models.load_model(toxic_path)
        self.tokenizer = None
        logger.info("Toxicity model loaded successfully")

    def is_toxic(self, text, threshold=0.6):
        # The classifier is expected to accept raw strings (e.g. via an
        # embedded TextVectorization layer) and return a toxicity probability.
        try:
            prob = self.toxicity_model.predict([text], verbose=0)[0][0]
            return prob > threshold
        except Exception as e:
            # Fail open: if the classifier errors out, generation proceeds.
            logger.error(f"Toxicity check failed: {str(e)}")
            return False

    def generate(self, *args, **kwargs):
        # input_ids may be passed positionally or as a keyword argument.
        inputs = kwargs.get("input_ids", args[0] if args else None)
        
        # Check input toxicity
        if self.tokenizer and inputs is not None:
            input_text = self.tokenizer.decode(inputs[0], skip_special_tokens=True)
            if self.is_toxic(input_text):
                return self._safe_response()
        
        # Generate response
        outputs = super().generate(*args, **kwargs)
        
        # Check output toxicity
        if self.tokenizer:
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if self.is_toxic(output_text):
                return self._safe_response()
        
        return outputs

    def _safe_response(self):
        # Canned refusal, returned as token ids on the model's device so it
        # matches the shape callers expect from generate().
        safe_text = "I'm unable to respond to that request. HAHAHA"
        return self.tokenizer.encode(safe_text, return_tensors="pt").to(self.device)

    def set_tokenizer(self, tokenizer):
        # Must be called after loading; without a tokenizer the toxicity
        # checks in generate() are silently skipped.
        self.tokenizer = tokenizer
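

if __name__ == "__main__":
    # Usage sketch (assumption: a checkpoint repo that ships both the Llama
    # weights and "toxic.keras"; the path below is a placeholder, not a real
    # model id). It shows that the tokenizer must be attached before calling
    # generate(), otherwise the toxicity checks are skipped.
    from transformers import AutoTokenizer

    model_path = "path/to/safe-generation-checkpoint"  # hypothetical
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = SafeGenerationModel.from_pretrained(model_path)
    model.set_tokenizer(tokenizer)

    prompt = "Tell me about large language models."
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    output_ids = model.generate(input_ids=input_ids, max_new_tokens=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))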