Update modeling_internlm2_ve.py (#6)
Browse files- Update modeling_internlm2_ve.py (40e5dc42cdbd98aafc20738caadd02a89eb63e3e)
Co-authored-by: Gen Luo <[email protected]>
- modeling_internlm2_ve.py +11 -13
modeling_internlm2_ve.py
CHANGED
|
@@ -689,20 +689,18 @@ class InternLM2DecoderLayer(nn.Module):
|
|
| 689 |
hidden_states = self.ffn_norm(hidden_states)
|
| 690 |
|
| 691 |
if past_key_value is None:
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
# hidden_states[~visual_token_mask] = self.feed_forward(hidden_states[~visual_token_mask].reshape(-1,dim)).reshape(-1)
|
| 704 |
##############################################################################################################
|
| 705 |
-
hidden_states = self.feed_forward(hidden_states)*(1.-visual_token_mask)+ self.feed_forward_ve(hidden_states)*visual_token_mask
|
| 706 |
else:
|
| 707 |
hidden_states = self.feed_forward(hidden_states)
|
| 708 |
|
|
|
|
| 689 |
hidden_states = self.ffn_norm(hidden_states)
|
| 690 |
|
| 691 |
if past_key_value is None:
|
| 692 |
+
##########################################--modified by luogen--##############################################
|
| 693 |
+
if self.training:
|
| 694 |
+
hidden_states = self.feed_forward(hidden_states)*(1.-visual_token_mask)+ self.feed_forward_ve(hidden_states)*visual_token_mask
|
| 695 |
+
else:
|
| 696 |
+
dim=hidden_states.shape[-1]
|
| 697 |
+
visual_token_mask=visual_token_mask.repeat(1,1,dim).bool()
|
| 698 |
+
non_visual_token_mask=~visual_token_mask
|
| 699 |
+
if visual_token_mask.any():
|
| 700 |
+
hidden_states[visual_token_mask] = self.feed_forward_ve(hidden_states[visual_token_mask].reshape(-1,dim)).reshape(-1)
|
| 701 |
+
if (non_visual_token_mask).any():
|
| 702 |
+
hidden_states[non_visual_token_mask] = self.feed_forward(hidden_states[non_visual_token_mask].reshape(-1,dim)).reshape(-1)
|
|
|
|
| 703 |
##############################################################################################################
|
|
|
|
| 704 |
else:
|
| 705 |
hidden_states = self.feed_forward(hidden_states)
|
| 706 |
|