from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import nn
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast


@dataclass
class SelfCorrectiveLlamaOutput(CausalLMOutputWithPast):
    """Causal LM output extended with the per-token hallucination-detector logits."""

    hallucination_logits: torch.FloatTensor = None


class SelfCorrectiveLlama(LlamaForCausalLM):
    """Llama causal LM extended with special correction tokens and a hallucination-detector head."""

    def __init__(self, config):
        super().__init__(config)

        # Number of special correction tokens appended after the original vocabulary.
        self.num_new_tokens = 3
        self.original_vocab_size = config.vocab_size

        # Separate embedding table for the special tokens, so the base model's
        # embed_tokens matrix is left untouched.
        self.new_token_embeddings = nn.Embedding(self.num_new_tokens, config.hidden_size)

        # Initialize the new embeddings to the mean of the existing token embeddings,
        # a common heuristic when adding tokens to a pretrained vocabulary.
        with torch.no_grad():
            original_embeddings = self.model.embed_tokens.weight
            mean_embeddings = original_embeddings.mean(dim=0)
            self.new_token_embeddings.weight.data.copy_(
                mean_embeddings.unsqueeze(0).expand(self.num_new_tokens, -1)
            )

        # SwiGLU-style gated MLP followed by a linear classifier: one "no hallucination"
        # class plus one class per special deletion token.
        intermediate_size = config.intermediate_size
        self.hallucination_gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.hallucination_up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.hallucination_down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
        self.hallucination_detector = nn.Linear(config.hidden_size, self.num_new_tokens + 1)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        hallucination_labels=None,
        **kwargs
    ):
        # Embed regular tokens with the base table; special-token ids (>= original
        # vocab size) are clamped first so they do not index out of range.
        clamped_input_ids = torch.clamp(input_ids, max=self.original_vocab_size - 1)
        inputs_embeds = self.model.embed_tokens(clamped_input_ids)

        # Swap in the dedicated embeddings wherever a special token appears.
        special_token_mask = input_ids >= self.original_vocab_size
        if special_token_mask.any():
            special_ids = input_ids[special_token_mask] - self.original_vocab_size
            special_embeds = self.new_token_embeddings(special_ids)
            inputs_embeds[special_token_mask] = special_embeds

kwargs["inputs_embeds"] = inputs_embeds |
|
transformer_outputs = self.model( |
|
attention_mask=attention_mask, |
|
**kwargs |
|
) |
|
last_hidden = transformer_outputs.last_hidden_state |
|
|
|
|
|
|
|
main_logits = self.lm_head(last_hidden) |
|
|
|
|
|
new_logits = F.linear(last_hidden, self.new_token_embeddings.weight) |
|
|
|
|
|
logits = torch.cat([main_logits, new_logits], dim=-1) |
|
|
|
|
|
        # SwiGLU-style gated MLP feeding the hallucination detector.
        gate_output = self.hallucination_gate_proj(last_hidden)
        up_output = self.hallucination_up_proj(last_hidden)
        gated_hidden = F.silu(gate_output) * up_output
        detector_hidden = self.hallucination_down_proj(gated_hidden)

        # Index 0 scores "no hallucination"; indices 1..num_new_tokens mirror the
        # special deletion tokens.
        all_hallucination_logits = self.hallucination_detector(detector_hidden)

        # Detector scores for the deletion tokens only (the "no hallucination" class is dropped).
        deletion_logits = all_hallucination_logits[..., 1:]

        if hallucination_labels is not None and labels is not None:
            # Training: add the detector's deletion scores to the deletion-token logits
            # only at positions labeled as non-hallucinated or whose target is itself a
            # deletion token; everywhere else the contribution is zeroed out.
            mask_no_hallucination = (hallucination_labels == 0)

            vocab_size = logits.shape[-1]
            mask_is_deletion_token = (labels >= (vocab_size - self.num_new_tokens)) & (labels < vocab_size)

            combined_mask = (mask_no_hallucination | mask_is_deletion_token).unsqueeze(-1)
            to_add = torch.where(
                combined_mask,
                deletion_logits,
                torch.zeros_like(deletion_logits)
            )
            logits[:, :, -self.num_new_tokens:].add_(to_add)
        else:
            # Inference: always bias the deletion-token logits with the detector scores.
            logits[:, :, -self.num_new_tokens:].add_(deletion_logits)

        # Loss is not computed here; labels and hallucination_labels are only used
        # above to gate the logit adjustment.
        return SelfCorrectiveLlamaOutput(
            loss=None,
            logits=logits,
            hallucination_logits=all_hallucination_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=None,
            attentions=transformer_outputs.attentions
        )
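

# Minimal usage sketch (an illustrative assumption, not part of the original model code):
# build a tiny randomly initialized LlamaConfig so the forward pass can be exercised
# without a pretrained checkpoint; all sizes and token ids below are hypothetical.
if __name__ == "__main__":
    from transformers import LlamaConfig

    config = LlamaConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
    )
    model = SelfCorrectiveLlama(config)
    model.eval()

    # Ids at or above config.vocab_size address the three new special-token embeddings.
    input_ids = torch.tensor([[1, 5, config.vocab_size + 1, 7]])
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))

    print(outputs.logits.shape)                # (1, 4, config.vocab_size + 3)
    print(outputs.hallucination_logits.shape)  # (1, 4, num_new_tokens + 1)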