MathBite committed on
Commit
1d7496d
·
verified ·
1 Parent(s): 6807253

updated forward pass to be more memory efficient

Browse files
Files changed (1) hide show
  1. modeling.py +12 -24
modeling.py CHANGED
@@ -42,20 +42,12 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
42
  **kwargs
43
  ):
44
  # 1. Manually construct the input embeddings
45
- # This allows us to use a separate embedding layer for our new tokens, saving memory
46
- special_token_mask = input_ids >= self.original_vocab_size
47
-
48
- if not special_token_mask.any():
49
- inputs_embeds = self.model.embed_tokens(input_ids)
50
- else:
51
- normal_token_mask = ~special_token_mask
52
- normal_ids = input_ids.clone()
53
- normal_ids[special_token_mask] = 0
54
- normal_embeds = self.model.embed_tokens(normal_ids)
55
-
56
- inputs_embeds = torch.empty_like(normal_embeds)
57
- inputs_embeds[normal_token_mask] = normal_embeds[normal_token_mask]
58
 
 
 
 
59
  special_ids = input_ids[special_token_mask] - self.original_vocab_size
60
  special_embeds = self.new_token_embeddings(special_ids)
61
  inputs_embeds[special_token_mask] = special_embeds
@@ -89,10 +81,10 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
89
 
90
  # 5. Modify the token logits conditionally.
91
  deletion_logits = all_hallucination_logits[..., 1:] # skip the first token (no hallucination)
92
- additional_logits = torch.zeros_like(logits)
93
 
94
- # Conditionally add the deletion logits if we are in training and labels are provided
95
  if hallucination_labels is not None and labels is not None:
 
96
  # Condition 1: The hallucination label is 0 (no hallucination)
97
  mask_no_hallucination = (hallucination_labels == 0)
98
 
@@ -100,22 +92,18 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
100
  # Check if the token ID is within the range of the last `num_new_tokens` in the vocab
101
  vocab_size = logits.shape[-1]
102
  mask_is_deletion_token = (labels >= (vocab_size - self.num_new_tokens)) & (labels < vocab_size)
103
-
104
- # Combine the masks. The addition happens if either condition is true
105
- # We need to align the shapes for broadcasting
106
  combined_mask = (mask_no_hallucination | mask_is_deletion_token).unsqueeze(-1)
107
-
108
- # Use the mask to conditionally apply the deletion logits
109
- additional_logits[:, :, -self.num_new_tokens:] = torch.where(
110
  combined_mask,
111
  deletion_logits,
112
  torch.zeros_like(deletion_logits)
113
  )
 
114
  else:
115
  # Inference case: always add the deletion logits to the token logits
116
- additional_logits[:, :, -self.num_new_tokens:] = deletion_logits
117
-
118
- logits = logits + additional_logits
119
 
120
  # 6. Return the custom output object
121
  return SelfCorrectiveLlamaOutput(
 
42
  **kwargs
43
  ):
44
  # 1. Manually construct the input embeddings
45
+ clamped_input_ids = torch.clamp(input_ids, max=self.original_vocab_size - 1)
46
+ inputs_embeds = self.model.embed_tokens(clamped_input_ids)
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ # Overwrite the embeddings for our new special tokens
49
+ special_token_mask = input_ids >= self.original_vocab_size
50
+ if special_token_mask.any():
51
  special_ids = input_ids[special_token_mask] - self.original_vocab_size
52
  special_embeds = self.new_token_embeddings(special_ids)
53
  inputs_embeds[special_token_mask] = special_embeds
 
81
 
82
  # 5. Modify the token logits conditionally.
83
  deletion_logits = all_hallucination_logits[..., 1:] # skip the first token (no hallucination)
 
84
 
85
+ # Conditionally add the deletion logits.
86
  if hallucination_labels is not None and labels is not None:
87
+ # Training case:
88
  # Condition 1: The hallucination label is 0 (no hallucination)
89
  mask_no_hallucination = (hallucination_labels == 0)
90
 
 
92
  # Check if the token ID is within the range of the last `num_new_tokens` in the vocab
93
  vocab_size = logits.shape[-1]
94
  mask_is_deletion_token = (labels >= (vocab_size - self.num_new_tokens)) & (labels < vocab_size)
95
+
96
+ # Combine masks and create the tensor to add.
 
97
  combined_mask = (mask_no_hallucination | mask_is_deletion_token).unsqueeze(-1)
98
+ to_add = torch.where(
 
 
99
  combined_mask,
100
  deletion_logits,
101
  torch.zeros_like(deletion_logits)
102
  )
103
+ logits[:, :, -self.num_new_tokens:].add_(to_add)
104
  else:
105
  # Inference case: always add the deletion logits to the token logits
106
+ logits[:, :, -self.num_new_tokens:].add_(deletion_logits)
 
 
107
 
108
  # 6. Return the custom output object
109
  return SelfCorrectiveLlamaOutput(