duzx16 committed
Commit: 8127ab6 · 1 parent: fbda120

Support batch training

Browse files: modeling_chatglm.py (+26 -23)
modeling_chatglm.py (CHANGED)
The commit touches three hunks inside class ChatGLMModel(ChatGLMPreTrainedModel). The removed-lines column of the side-by-side view is truncated in this rendering; the recognizable fragments are the previous single-sequence versions of the two helpers (signatures beginning "def get_masks(self," and "def get_position_ids(self,", a "position_ids = torch.arange(context_length, dtype=torch.long, device=device)" line) and a caller that collapsed the batch with "seq = input_ids[0].tolist()". The unchanged and added lines are fully visible and are reproduced below in unified-diff form ("+" marks added lines); leading indentation is reconstructed.
@@ -818,33 +818,37 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return past_key_values

     @staticmethod
+    def get_masks(self, input_ids, device):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
         attention_mask.tril_()
+        for i, context_length in enumerate(context_lengths):
+            attention_mask[i, :, :context_length] = 1
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()

         return attention_mask

+    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[i, context_length:] = mask_positions[i]
+            block_position_ids = [torch.cat((
+                torch.zeros(context_length, dtype=torch.long, device=device),
+                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
+            )) for context_length in context_lengths]
+            block_position_ids = torch.stack(block_position_ids, dim=0)
+            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[context_length:] = mask_positions[i]

         position_ids = position_ids.unsqueeze(0)

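For readers following the change, here is a minimal standalone sketch of what the new batched helpers compute, with the masking and 2D position-id logic copied out of the hunk above so it runs outside the class. The token ids (a stand-in bos_token_id of 130004, the two toy rows) and the gMASK-style path are assumptions for illustration, not values taken from the model config:

import torch

# Hypothetical token id standing in for self.config.bos_token_id.
bos_token_id = 130004

# Two rows of equal length; each row's "context" (prompt) ends at its BOS token,
# at a different position per row, which is exactly what the new helpers handle.
input_ids = torch.tensor([
    [11, 12, 13, bos_token_id, 21, 22],
    [11, 12, bos_token_id, 31, 32, 33],
])
batch_size, seq_length = input_ids.shape
context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]  # [3, 2]

# get_masks: causal (lower-triangular) mask, except that every position may attend
# to the whole context of its own row; True marks positions that get masked out.
attention_mask = torch.ones((batch_size, seq_length, seq_length))
attention_mask.tril_()
for i, context_length in enumerate(context_lengths):
    attention_mask[i, :, :context_length] = 1
attention_mask.unsqueeze_(1)
attention_mask = (attention_mask < 0.5).bool()
print(attention_mask.shape)        # torch.Size([2, 1, 6, 6])

# get_position_ids, 2D case on the gMASK path (the "if not gmask" fix-up is skipped):
# channel 0 counts absolute positions, channel 1 counts generation steps after the context.
position_ids = torch.arange(seq_length, dtype=torch.long).expand(batch_size, seq_length)
block_position_ids = [torch.cat((
    torch.zeros(context_length, dtype=torch.long),
    torch.arange(seq_length - context_length, dtype=torch.long) + 1
)) for context_length in context_lengths]
block_position_ids = torch.stack(block_position_ids, dim=0)
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
print(position_ids.shape)          # torch.Size([2, 2, 6])
print(position_ids[1])             # tensor([[0, 1, 2, 3, 4, 5], [0, 0, 1, 2, 3, 4]])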
@@ -890,16 +894,15 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
         else:
             past_key_values = tuple([None] * len(self.layers))

         if attention_mask is None:
             attention_mask = self.get_masks(
+                input_ids,
                 device=input_ids.device
             )

         if self.pre_seq_len is not None:
+            prefix_attention_mask = torch.ones(1, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
             prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
             attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)

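This hunk also widens the attention mask on the prefix path (self.pre_seq_len is not None, the same branch that calls get_prompt): an always-visible block for the learned prefix is concatenated in front of the regular mask along the key dimension. A minimal shape check for a single sequence, assuming a hypothetical pre_seq_len of 4 and a sequence length of 6:

import torch

pre_seq_len, seq_len = 4, 6                                     # hypothetical sizes for illustration
# Stand-in for the get_masks output of one sequence: (batch, 1, query, key), True = masked.
attention_mask = torch.zeros(1, 1, seq_len, seq_len).bool()

prefix_attention_mask = torch.ones(1, 1, seq_len, pre_seq_len)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()    # all False: the prefix is never masked
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
print(attention_mask.shape)   # torch.Size([1, 1, 6, 10]): keys now cover prefix + tokens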
@@ -908,10 +911,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             mask_token = MASK if MASK in input_ids else gMASK
             use_gmask = False if MASK in input_ids else gMASK

+            mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
             position_ids = self.get_position_ids(
+                input_ids,
+                mask_positions=mask_positions,
                 device=input_ids.device,
                 gmask=use_gmask
             )