updated forward pass to be more memory efficient
modeling.py CHANGED (+12, -24)
@@ -42,20 +42,12 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
         **kwargs
     ):
         # 1. Manually construct the input embeddings
-
-
-
-        if not special_token_mask.any():
-            inputs_embeds = self.model.embed_tokens(input_ids)
-        else:
-            normal_token_mask = ~special_token_mask
-            normal_ids = input_ids.clone()
-            normal_ids[special_token_mask] = 0
-            normal_embeds = self.model.embed_tokens(normal_ids)
-
-            inputs_embeds = torch.empty_like(normal_embeds)
-            inputs_embeds[normal_token_mask] = normal_embeds[normal_token_mask]
+        clamped_input_ids = torch.clamp(input_ids, max=self.original_vocab_size - 1)
+        inputs_embeds = self.model.embed_tokens(clamped_input_ids)
 
+        # Overwrite the embeddings for our new special tokens
+        special_token_mask = input_ids >= self.original_vocab_size
+        if special_token_mask.any():
             special_ids = input_ids[special_token_mask] - self.original_vocab_size
             special_embeds = self.new_token_embeddings(special_ids)
             inputs_embeds[special_token_mask] = special_embeds
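Note on the hunk above: the removed path allocated a cloned `normal_ids` tensor plus a separate `empty_like` output buffer, while the new path embeds every position once through the base table (after clamping out-of-range ids to a valid index) and then overwrites only the special positions. Below is a minimal standalone sketch of the clamp-then-overwrite pattern; the sizes are made up, and `embed_tokens` / `new_token_embeddings` are stand-ins for the model's `self.model.embed_tokens` / `self.new_token_embeddings`:

```python
import torch
import torch.nn as nn

original_vocab_size, num_new_tokens, hidden = 32000, 2, 16   # made-up sizes
embed_tokens = nn.Embedding(original_vocab_size, hidden)      # stand-in for self.model.embed_tokens
new_token_embeddings = nn.Embedding(num_new_tokens, hidden)   # stand-in for self.new_token_embeddings

input_ids = torch.tensor([[5, 31999, 32000, 32001]])  # last two ids are new special tokens

# Embed everything once; clamping keeps out-of-range ids valid for the base table.
clamped_input_ids = torch.clamp(input_ids, max=original_vocab_size - 1)
inputs_embeds = embed_tokens(clamped_input_ids)

# Overwrite just the special positions with the learned new-token embeddings.
special_token_mask = input_ids >= original_vocab_size
if special_token_mask.any():
    special_ids = input_ids[special_token_mask] - original_vocab_size
    inputs_embeds[special_token_mask] = new_token_embeddings(special_ids)

print(inputs_embeds.shape)  # torch.Size([1, 4, 16])
```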
@@ -89,10 +81,10 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
 
         # 5. Modify the token logits conditionally.
         deletion_logits = all_hallucination_logits[..., 1:] # skip the first token (no hallucination)
-        additional_logits = torch.zeros_like(logits)
 
-        # Conditionally add the deletion logits
+        # Conditionally add the deletion logits.
         if hallucination_labels is not None and labels is not None:
+            # Training case:
             # Condition 1: The hallucination label is 0 (no hallucination)
             mask_no_hallucination = (hallucination_labels == 0)
 
@@ -100,22 +92,18 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
             # Check if the token ID is within the range of the last `num_new_tokens` in the vocab
             vocab_size = logits.shape[-1]
             mask_is_deletion_token = (labels >= (vocab_size - self.num_new_tokens)) & (labels < vocab_size)
-
-            # Combine
-            # We need to align the shapes for broadcasting
+
+            # Combine masks and create the tensor to add.
             combined_mask = (mask_no_hallucination | mask_is_deletion_token).unsqueeze(-1)
-
-            # Use the mask to conditionally apply the deletion logits
-            additional_logits[:, :, -self.num_new_tokens:] = torch.where(
+            to_add = torch.where(
                 combined_mask,
                 deletion_logits,
                 torch.zeros_like(deletion_logits)
             )
+            logits[:, :, -self.num_new_tokens:].add_(to_add)
         else:
             # Inference case: always add the deletion logits to the token logits
-
-
-            logits = logits + additional_logits
+            logits[:, :, -self.num_new_tokens:].add_(deletion_logits)
 
         # 6. Return the custom output object
         return SelfCorrectiveLlamaOutput(
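The logits hunks serve the same goal: the removed code materialized `additional_logits = torch.zeros_like(logits)`, a full batch × seq × vocab buffer, only to fill its last `num_new_tokens` columns, while the new code adds into that slice of `logits` in place. A toy sketch of the new update with invented shapes; the random mask stands in for the label-derived `combined_mask`:

```python
import torch

batch, seq, vocab, num_new_tokens = 2, 4, 10, 3           # invented shapes
logits = torch.randn(batch, seq, vocab)
deletion_logits = torch.randn(batch, seq, num_new_tokens)
combined_mask = torch.rand(batch, seq, 1) > 0.5           # stand-in for the label-derived mask

# Training branch: zero out masked positions, then add into the slice in place.
to_add = torch.where(combined_mask, deletion_logits, torch.zeros_like(deletion_logits))
logits[:, :, -num_new_tokens:].add_(to_add)

# Inference branch (the model takes one branch or the other, not both):
logits[:, :, -num_new_tokens:].add_(deletion_logits)
```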
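Since the commit message claims only a memory win, the removed and the new embedding paths should agree numerically. A rough no-grad sanity check along those lines, with toy sizes, not taken from the repo:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
original_vocab_size, num_new_tokens, hidden = 100, 2, 8   # toy sizes
embed_tokens = nn.Embedding(original_vocab_size, hidden)
new_token_embeddings = nn.Embedding(num_new_tokens, hidden)
input_ids = torch.tensor([[3, 99, 100, 101]])
special_token_mask = input_ids >= original_vocab_size
special_ids = input_ids[special_token_mask] - original_vocab_size

with torch.no_grad():
    # Removed path: clone ids, zero the specials, embed, scatter into an empty buffer.
    normal_ids = input_ids.clone()
    normal_ids[special_token_mask] = 0
    normal_embeds = embed_tokens(normal_ids)
    old = torch.empty_like(normal_embeds)
    old[~special_token_mask] = normal_embeds[~special_token_mask]
    old[special_token_mask] = new_token_embeddings(special_ids)

    # New path: clamp, embed once, overwrite the special positions.
    new = embed_tokens(torch.clamp(input_ids, max=original_vocab_size - 1))
    new[special_token_mask] = new_token_embeddings(special_ids)

    print(torch.allclose(old, new))  # True
```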