Update implementation

Files changed:

- config.json +1 -1
- configuration_chatglm.py +6 -5
- generation_config.json +5 -0
- modeling_chatglm.py +185 -84
- quantization.py +1 -3
- save_model.py +4 -0
- tokenization_chatglm.py +31 -10
    	
config.json
CHANGED

@@ -35,7 +35,7 @@
   "seq_length": 8192,
   "use_cache": true,
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.…
+  "transformers_version": "4.30.2",
   "tie_word_embeddings": false,
   "eos_token_id": 2
 }
    	
configuration_chatglm.py
CHANGED

@@ -20,18 +20,19 @@ class ChatGLMConfig(PretrainedConfig):
         post_layer_norm=True,
         add_bias_linear=False,
         add_qkv_bias=False,
-        interleaved_qkv=False,
         bias_dropout_fusion=True,
-        rotary_percent=1.0,
         multi_query_attention=False,
         multi_query_group_num=1,
         apply_query_key_layer_scaling=True,
         attention_softmax_in_fp32=True,
         fp32_residual_connection=False,
         quantization_bit=0,
+        pre_seq_len=None,
+        prefix_projection=False,
         **kwargs
     ):
         self.num_layers = num_layers
+        self.vocab_size = padded_vocab_size
         self.padded_vocab_size = padded_vocab_size
         self.hidden_size = hidden_size
         self.ffn_hidden_size = ffn_hidden_size
@@ -46,13 +47,13 @@ class ChatGLMConfig(PretrainedConfig):
         self.post_layer_norm = post_layer_norm
         self.add_bias_linear = add_bias_linear
         self.add_qkv_bias = add_qkv_bias
-        self.interleaved_qkv = interleaved_qkv
         self.bias_dropout_fusion = bias_dropout_fusion
-        self.rotary_percent = rotary_percent
         self.multi_query_attention = multi_query_attention
         self.multi_query_group_num = multi_query_group_num
         self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
         self.attention_softmax_in_fp32 = attention_softmax_in_fp32
         self.fp32_residual_connection = fp32_residual_connection
         self.quantization_bit = quantization_bit
-        …
+        self.pre_seq_len = pre_seq_len
+        self.prefix_projection = prefix_projection
+        super().__init__(**kwargs)
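The two new fields, pre_seq_len and prefix_projection, drive the P-Tuning v2 (prefix tuning) path added in modeling_chatglm.py below. As a hedged illustration only (not part of this commit), they would typically be set on a loaded config; the checkpoint id matches the model list in modeling_chatglm.py and the pre_seq_len value is an arbitrary placeholder:

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
config.pre_seq_len = 128           # length of the trainable prefix (placeholder value)
config.prefix_projection = False   # True would route the prefix through the two-layer MLP

model = AutoModel.from_pretrained("THUDM/chatglm2-6b", config=config, trust_remote_code=True)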
    	
generation_config.json
ADDED

@@ -0,0 +1,5 @@
+{
+  "_from_model_config": true,
+  "eos_token_id": 2,
+  "transformers_version": "4.30.2"
+}
    	
modeling_chatglm.py
CHANGED

@@ -35,12 +35,12 @@ if sys.platform != 'darwin':
 
 logger = logging.get_logger(__name__)
 
-_CHECKPOINT_FOR_DOC = "THUDM/…
+_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B"
 _CONFIG_FOR_DOC = "ChatGLM6BConfig"
 
 CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "THUDM/…
-    # See all ChatGLM …
+    "THUDM/chatglm2-6b",
+    # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
 ]
 
 
@@ -56,6 +56,38 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
         return scores
 
 
+class PrefixEncoder(torch.nn.Module):
+    """
+    The torch.nn model to encode the prefix
+    Input shape: (batch-size, prefix-length)
+    Output shape: (batch-size, prefix-length, 2*layers*hidden)
+    """
+
+    def __init__(self, config: ChatGLMConfig):
+        super().__init__()
+        self.prefix_projection = config.prefix_projection
+        if self.prefix_projection:
+            # Use a two-layer MLP to encode the prefix
+            kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
+            self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
+            self.trans = torch.nn.Sequential(
+                torch.nn.Linear(kv_size, config.hidden_size),
+                torch.nn.Tanh(),
+                torch.nn.Linear(config.hidden_size, kv_size)
+            )
+        else:
+            self.embedding = torch.nn.Embedding(config.pre_seq_len,
+                                                config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
+
+    def forward(self, prefix: torch.Tensor):
+        if self.prefix_projection:
+            prefix_tokens = self.embedding(prefix)
+            past_key_values = self.trans(prefix_tokens)
+        else:
+            past_key_values = self.embedding(prefix)
+        return past_key_values
+
+
 def split_tensor_along_last_dim(
         tensor: torch.Tensor,
         num_partitions: int,
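A hedged, stand-alone sketch (not in the commit) of the tensor flow through the non-projection branch of PrefixEncoder, using made-up ChatGLM2-6B-like sizes; only the shapes matter here:

import torch

num_layers, kv_channels, groups, pre_seq_len = 28, 128, 2, 64   # made-up sizes
kv_size = num_layers * kv_channels * groups * 2                 # "* 2": one slot for K and one for V per layer
embedding = torch.nn.Embedding(pre_seq_len, kv_size)

prefix = torch.arange(pre_seq_len).unsqueeze(0)                 # (batch=1, prefix_len)
past_key_values = embedding(prefix)                             # (1, prefix_len, 2 * layers * groups * kv_channels)
print(past_key_values.shape)                                    # torch.Size([1, 64, 14336])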
@@ -87,12 +119,12 @@ def split_tensor_along_last_dim(
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim, original_impl=False, device=None, dtype=None):
         super().__init__()
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device…
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
         self.register_buffer("inv_freq", inv_freq)
         self.dim = dim
         self.original_impl = original_impl
 
-    def …
+    def forward_impl(
             self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
     ):
         """Enhanced Transformer with Rotary Position Embedding.
@@ -118,14 +150,13 @@ class RotaryEmbedding(nn.Module):
         return cache
 
     def forward(self, max_seq_len, offset=0):
-        …
-        …
-        …
-        )
+        return self.forward_impl(
+            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
+        )
 
 
 @torch.jit.script
-def …
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     # x: [sq, b, np, hn]
     sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
     rot_dim = rope_cache.shape[-2] * 2
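For orientation, a hedged stand-alone sketch (not in the commit) of a RoPE cache in the spirit of forward_impl; the real cache construction lives in unchanged code, so the (cos, sin) pairing below is an assumption, but it is consistent with rot_dim = rope_cache.shape[-2] * 2 above:

import torch

def rope_cache(seq_len: int, n_elem: int, base: int = 10000, dtype=torch.float32):
    # One frequency per pair of channels, then one angle per (position, frequency).
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype) / n_elem))
    pos = torch.arange(seq_len, dtype=dtype)
    angles = torch.outer(pos, theta)                                    # (seq_len, n_elem // 2)
    return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)  # (seq_len, n_elem // 2, 2)

print(rope_cache(8192, 64).shape)   # torch.Size([8192, 32, 2]) -> rot_dim = 32 * 2 = 64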
@@ -151,10 +182,12 @@ class RMSNorm(torch.nn.Module):
         self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
         self.eps = eps
 
-    def forward(self, …
-        …
-        …
-        …
+    def forward(self, hidden_states: torch.Tensor):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        return (self.weight * hidden_states).to(input_dtype)
 
 
 class CoreAttention(torch.nn.Module):
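The new RMSNorm.forward computes the statistics in fp32 and casts the result back to the input dtype. A hedged numerical restatement (not in the commit) with made-up sizes:

import torch

x = torch.randn(4, 8, dtype=torch.float16)
weight = torch.ones(8, dtype=torch.float16)
eps = 1e-5

variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
y = (weight * (x * torch.rsqrt(variance + eps))).to(x.dtype)

print(y.to(torch.float32).pow(2).mean(-1))   # each row now has ~unit root-mean-square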
@@ -311,8 +344,6 @@ class SelfAttention(torch.nn.Module):
                                device=device, **_config_to_kwargs(config)
                                )
 
-        self.interleaved_qkv = config.interleaved_qkv
-
     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
         if self.multi_query_attention:
             num_attention_heads = self.num_multi_query_groups_per_partition
@@ -362,40 +393,25 @@ class SelfAttention(torch.nn.Module):
                 + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
             )
         else:
-            …
-            …
-            …
-            …
-            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+            new_tensor_shape = mixed_x_layer.size()[:-1] + \
+                               (self.num_attention_heads_per_partition,
+                                3 * self.hidden_size_per_attention_head)
+            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
             (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
 
-            if not self.interleaved_qkv:
-                query_layer = query_layer.view(
-                    query_layer.size()[:-1] + (
-                        self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
-                ).contiguous()
-                key_layer = key_layer.view(
-                    key_layer.size()[:-1] + (
-                        self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
-                ).contiguous()
-                value_layer = value_layer.view(
-                    value_layer.size()[:-1] + (
-                        self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
-                ).contiguous()
-
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
-            query_layer = …
-            key_layer = …
+            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
+            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
 
         # adjust key and value for inference
+        if kv_cache is not None:
+            cache_k, cache_v = kv_cache
+            key_layer = torch.cat((cache_k, key_layer), dim=0)
+            value_layer = torch.cat((cache_v, value_layer), dim=0)
         if use_cache:
-            if kv_cache is not None:
-                cache_k, cache_v = kv_cache
-                key_layer = torch.cat((cache_k, key_layer), dim=0)
-                value_layer = torch.cat((cache_v, value_layer), dim=0)
             kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None
@@ -582,6 +598,8 @@ class GLMTransformer(torch.nn.Module):
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)
 
+        self.gradient_checkpointing = False
+
     def _get_layer(self, layer_number):
         return self.layers[layer_number]
 
@@ -593,6 +611,13 @@ class GLMTransformer(torch.nn.Module):
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         all_self_attentions = None
         all_hidden_states = () if output_hidden_states else None
         for index in range(self.num_layers):
@@ -600,14 +625,24 @@ class GLMTransformer(torch.nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             layer = self._get_layer(index)
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            …
+            if self.gradient_checkpointing and self.training:
+                layer_ret = torch.utils.checkpoint.checkpoint(
+                    layer,
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    kv_caches[index],
+                    use_cache
+                )
+            else:
+                layer_ret = layer(
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    kv_cache=kv_caches[index],
+                    use_cache=use_cache
+                )
+            hidden_states, kv_cache = layer_ret
            if use_cache:
                 presents = presents + (kv_cache,)
 
@@ -661,7 +696,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         return position_ids
 
     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, …
+        if isinstance(module, GLMTransformer):
             module.gradient_checkpointing = value
 
 
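With gradient_checkpointing plumbed through GLMTransformer and _set_gradient_checkpointing, activation checkpointing can be toggled through the standard PreTrainedModel API. A hedged usage sketch (not part of this commit); model id and training setup are placeholders:

from transformers import AutoModel

model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model.gradient_checkpointing_enable()   # flips GLMTransformer.gradient_checkpointing via _set_gradient_checkpointing
model.train()                           # the checkpointed path only runs when self.training is True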
@@ -704,6 +739,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if device is not None:
             init_kwargs["device"] = device
         self.embedding = init_method(Embedding, config, **init_kwargs)
+        self.num_layers = config.num_layers
+        self.multi_query_group_num = config.multi_query_group_num
+        self.kv_channels = config.kv_channels
 
         # Rotary positional embeddings
         self.seq_length = config.seq_length
@@ -711,18 +749,37 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
 
-        …
-            rotary_dim = int(rotary_dim * config.rotary_percent)
-
-        # partial rotary embeddings, which is better than full rotary
-        # Wang and Komatsuzaki et al
-        # https://github.com/kingoflolz/mesh-transformer-jax/
-        self.rotary_pos_emb = RotaryEmbedding(rotary_dim, original_impl=config.original_rope, device=device,
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                               dtype=config.torch_dtype)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                         dtype=config.torch_dtype, **init_kwargs)
-        self.…
+        self.pre_seq_len = config.pre_seq_len
+        self.prefix_projection = config.prefix_projection
+        if self.pre_seq_len is not None:
+            for param in self.parameters():
+                param.requires_grad = False
+            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+            self.prefix_encoder = PrefixEncoder(config)
+            self.dropout = torch.nn.Dropout(0.1)
+
+    def get_input_embeddings(self):
+        return self.embedding.word_embeddings
+
+    def get_prompt(self, batch_size, device, dtype=torch.half):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
+        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
+        past_key_values = past_key_values.view(
+            batch_size,
+            self.pre_seq_len,
+            self.num_layers * 2,
+            self.multi_query_group_num,
+            self.kv_channels
+        )
+        # seq_len, b, nh, hidden_size
+        past_key_values = self.dropout(past_key_values)
+        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
+        return past_key_values
 
     def forward(
             self,
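A hedged shape walk-through (not in the commit) of what get_prompt's view/permute/split does to the encoded prefix, using made-up sizes; the result is one (key, value)-sized chunk per layer:

import torch

b, p, layers, groups, kvc = 2, 64, 28, 2, 128                   # made-up sizes
flat = torch.zeros(b, p, layers * 2 * groups * kvc)             # shaped like the PrefixEncoder output

pkv = flat.view(b, p, layers * 2, groups, kvc)                  # (2, 64, 56, 2, 128)
pkv = pkv.permute([2, 1, 0, 3, 4])                              # (56, 64, 2, 2, 128)
pkv = pkv.split(2)                                              # 28 chunks along dim 0, one per layer
print(len(pkv), pkv[0].shape)                                   # 28 torch.Size([2, 64, 2, 2, 128])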
@@ -747,8 +804,17 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embedding(input_ids)
 
-        if …
-            …
+        if self.pre_seq_len is not None:
+            if past_key_values is None:
+                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
+                                                  dtype=inputs_embeds.dtype)
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
+                                            attention_mask], dim=-1)
+
+        if full_attention_mask is None:
+            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
 
         # Rotary positional embeddings
         rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
@@ -820,6 +886,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 [position_ids, new_position_id], dim=-1
             )
 
+        model_kwargs["is_first_forward"] = False
         return model_kwargs
 
     def prepare_inputs_for_generation(
@@ -828,20 +895,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             past_key_values: Optional[torch.Tensor] = None,
             attention_mask: Optional[torch.Tensor] = None,
             position_ids: Optional[torch.Tensor] = None,
-            …
+            is_first_forward: bool = True,
             **kwargs
     ) -> dict:
         # only last token for input_ids if past is not None
-        if …
-            …
-            …
+        if position_ids is None:
+            position_ids = self.get_position_ids(input_ids, device=input_ids.device)
+        if not is_first_forward:
             position_ids = position_ids[..., -1:]
             input_ids = input_ids[:, -1:]
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
             "position_ids": position_ids,
-            "attention_mask": attention_mask
+            "attention_mask": attention_mask,
+            "return_last_logit": True
         }
 
     def forward(
@@ -856,6 +924,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             output_attentions: Optional[bool] = None,
             output_hidden_states: Optional[bool] = None,
             return_dict: Optional[bool] = None,
+            return_last_logit: Optional[bool] = False,
     ):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -872,7 +941,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         )
 
         hidden_states = transformer_outputs[0]
-        …
+        if return_last_logit:
+            hidden_states = hidden_states[-1:]
         lm_logits = self.transformer.output_layer(hidden_states)
         lm_logits = lm_logits.transpose(0, 1).contiguous()
 
@@ -927,16 +997,25 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         return response
 
     def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
-        prompt = …
-        for i, (old_query, response) in enumerate(history):
-            prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
-        prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
+        prompt = tokenizer.build_prompt(query, history=history)
         inputs = tokenizer([prompt], return_tensors="pt")
         inputs = inputs.to(self.device)
         return inputs
 
-    …
-    …
+    def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
+        if history:
+            prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
+            input_ids = tokenizer.encode(prompt, add_special_tokens=False)
+            input_ids = input_ids[1:]
+            inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
+        else:
+            prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
+            inputs = tokenizer([prompt], return_tensors="pt")
+        inputs = inputs.to(self.device)
+        return inputs
+
+    @torch.inference_mode()
+    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
              do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
         if history is None:
             history = []
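A hedged usage sketch (not part of this commit) of the chat API defined here, using the repo id listed at the top of modeling_chatglm.py; precision and device handling are assumptions:

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda().eval()

response, history = model.chat(tokenizer, "你好", history=[])
print(response)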
@@ -953,9 +1032,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         history = history + [(query, response)]
         return response, history
 
-    @torch.…
-    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, …
-                    do_sample=True, top_p=0.…
+    @torch.inference_mode()
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
+                    max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
+                    return_past_key_values=False, **kwargs):
         if history is None:
             history = []
         if logits_processor is None:
@@ -963,15 +1043,33 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-        …
-        …
+        if past_key_values is None and not return_past_key_values:
+            inputs = self.build_inputs(tokenizer, query, history=history)
+        else:
+            inputs = self.build_stream_inputs(tokenizer, query, history=history)
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[0]
+            if self.transformer.pre_seq_len is not None:
+                past_length -= self.transformer.pre_seq_len
+            inputs.position_ids += past_length
+            attention_mask = inputs.attention_mask
+            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
+            inputs['attention_mask'] = attention_mask
+        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
+                                            return_past_key_values=return_past_key_values, **gen_kwargs):
+            if return_past_key_values:
+                outputs, past_key_values = outputs
             outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
             response = tokenizer.decode(outputs)
-            response …
-            …
-            …
-            …
-            …
+            if response and response[-1] != "�":
+                response = self.process_response(response)
+                new_history = history + [(query, response)]
+                if return_past_key_values:
+                    yield response, new_history, past_key_values
+                else:
+                    yield response, new_history
+
+    @torch.inference_mode()
     def stream_generate(
             self,
             input_ids,
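A hedged usage sketch (not part of this commit) of stream_chat with the new past_key_values / return_past_key_values options, assuming model and tokenizer are loaded as in the previous example:

current_length = 0
past_key_values, history = None, []
for response, history, past_key_values in model.stream_chat(tokenizer, "你好", history=history,
                                                             past_key_values=past_key_values,
                                                             return_past_key_values=True):
    print(response[current_length:], end="", flush=True)
    current_length = len(response)
print()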
@@ -979,6 +1077,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             logits_processor: Optional[LogitsProcessorList] = None,
             stopping_criteria: Optional[StoppingCriteriaList] = None,
             prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+            return_past_key_values=False,
             **kwargs,
     ):
         batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
@@ -1067,11 +1166,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
             unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
-            …
+            …
             # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 break
-            yield input_ids
 
     def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
         if bits == 0:
         | 
| 1167 | 
             
                        )
         | 
| 1168 | 
             
                        unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
         | 
| 1169 | 
            +
                        if return_past_key_values:
         | 
| 1170 | 
            +
                            yield input_ids, outputs.past_key_values
         | 
| 1171 | 
            +
                        else:
         | 
| 1172 | 
            +
                            yield input_ids
         | 
| 1173 | 
             
                        # stop when each sentence is finished, or if we exceed the maximum length
         | 
| 1174 | 
             
                        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
         | 
| 1175 | 
             
                            break
         | 
|  | |
| 1176 |  | 
| 1177 | 
             
                def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
         | 
| 1178 | 
             
                    if bits == 0:
         | 
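For context, a minimal usage sketch (not part of this diff) of the streaming API changed above: stream_chat can now carry the KV cache across turns through past_key_values / return_past_key_values, and build_stream_inputs encodes only the newest round so the shared prefix is not re-tokenized. The checkpoint id below is a placeholder, not something named in this commit.

    from transformers import AutoModel, AutoTokenizer

    model_id = "THUDM/chatglm2-6b"  # placeholder; substitute the checkpoint that ships this modeling code
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_id, trust_remote_code=True).half().cuda().eval()

    history, past_key_values = [], None
    for response, history, past_key_values in model.stream_chat(
            tokenizer, "Write a Python function that reverses a string.",
            history=history, past_key_values=past_key_values, return_past_key_values=True):
        pass  # each yield is the partial response so far plus the updated history and cache
    print(response)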
    	
        quantization.py
    CHANGED
    
    | @@ -24,7 +24,7 @@ try: | |
| 24 | 
             
                        for name in self._function_names:
         | 
| 25 | 
             
                            setattr(self, name, KernelFunction(self._cmodule, name))
         | 
| 26 |  | 
| 27 | 
            -
                quantization_code = " | 
| 28 |  | 
| 29 | 
             
                kernels = Kernel(
         | 
| 30 | 
             
                    bz2.decompress(base64.b64decode(quantization_code)),
         | 
| @@ -32,10 +32,8 @@ try: | |
| 32 | 
             
                        "int4WeightCompression",
         | 
| 33 | 
             
                        "int4WeightExtractionFloat",
         | 
| 34 | 
             
                        "int4WeightExtractionHalf",
         | 
| 35 | 
            -
                        "int4WeightExtractionBFloat16",
         | 
| 36 | 
             
                        "int8WeightExtractionFloat",
         | 
| 37 | 
             
                        "int8WeightExtractionHalf",
         | 
| 38 | 
            -
                        "int8WeightExtractionBFloat16",
         | 
| 39 | 
             
                    ],
         | 
| 40 | 
             
                )
         | 
| 41 | 
             
            except Exception as exception:
         | 
|  | |
| 24 | 
             
                        for name in self._function_names:
         | 
| 25 | 
             
                            setattr(self, name, KernelFunction(self._cmodule, name))
         | 
| 26 |  | 
| 27 | 
            +
                quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8Dp
rDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV
23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"
         | 
| 28 |  | 
| 29 | 
             
                kernels = Kernel(
         | 
| 30 | 
             
                    bz2.decompress(base64.b64decode(quantization_code)),
         | 
|  | |
| 32 | 
             
                        "int4WeightCompression",
         | 
| 33 | 
             
                        "int4WeightExtractionFloat",
         | 
| 34 | 
             
                        "int4WeightExtractionHalf",
         | 
|  | |
| 35 | 
             
                        "int8WeightExtractionFloat",
         | 
| 36 | 
             
                        "int8WeightExtractionHalf",
         | 
|  | |
| 37 | 
             
                    ],
         | 
| 38 | 
             
                )
         | 
| 39 | 
             
            except Exception as exception:
         | 
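The edit above drops the int4/int8 BFloat16 extraction entries from the kernel list and replaces the embedded quantization_code blob. The blob is simply the kernel source, bz2-compressed and base64-encoded into a Python string; the small self-contained round-trip below (illustrative only, with a dummy payload) shows the scheme the module relies on.

    import base64
    import bz2

    # Dummy stand-in for the real CUDA/C kernel source embedded in quantization.py.
    kernel_source = b"// int4/int8 weight extraction kernels would live here"

    # Encode the way the embedded blob was produced, then decode the way the module does
    # at import time before handing the text to its kernel loader.
    blob = base64.b64encode(bz2.compress(kernel_source)).decode()
    assert bz2.decompress(base64.b64decode(blob)) == kernel_source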
    	
        save_model.py
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from transformers import AutoModel
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            model = AutoModel.from_pretrained("/mnt/vepfs/qinkai/release/codegeex2-6b/", trust_remote_code=True).cuda()
         | 
| 4 | 
            +
            model.save_pretrained("./", max_shard_size="2000MB")
         | 
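save_model.py is a one-off resharding script: it loads the checkpoint with remote code enabled and rewrites it into roughly 2 GB shards in the working directory. A hedged sketch of reloading the resharded output (the path mirrors the "./" used above; note that save_pretrained here writes only the model files, not the tokenizer):

    from transformers import AutoModel

    # Reload the resharded weights written by save_model.py; adjust the path as needed.
    model = AutoModel.from_pretrained("./", trust_remote_code=True)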
    	
        tokenization_chatglm.py
    CHANGED
    
    | @@ -17,7 +17,7 @@ class SPTokenizer: | |
| 17 | 
             
                    self.n_words: int = self.sp_model.vocab_size()
         | 
| 18 | 
             
                    self.bos_id: int = self.sp_model.bos_id()
         | 
| 19 | 
             
                    self.eos_id: int = self.sp_model.eos_id()
         | 
| 20 | 
            -
                    self.pad_id: int = self.sp_model. | 
| 21 | 
             
                    assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
         | 
| 22 |  | 
| 23 | 
             
                    special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"]
         | 
| @@ -55,7 +55,7 @@ class SPTokenizer: | |
| 55 |  | 
| 56 | 
             
                def convert_id_to_token(self, index):
         | 
| 57 | 
             
                    """Converts an index (integer) in a token (str) using the vocab."""
         | 
| 58 | 
            -
                    if index in self.index_special_tokens:
         | 
| 59 | 
             
                        return ""
         | 
| 60 | 
             
                    return self.sp_model.IdToPiece(index)
         | 
| 61 |  | 
| @@ -65,10 +65,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer): | |
| 65 |  | 
| 66 | 
             
                model_input_names = ["input_ids", "attention_mask", "position_ids"]
         | 
| 67 |  | 
| 68 | 
            -
                def __init__(self, vocab_file, padding_side="left", **kwargs):
         | 
| 69 | 
            -
                    super().__init__(padding_side=padding_side, clean_up_tokenization_spaces= | 
| 70 | 
             
                    self.name = "GLMTokenizer"
         | 
| 71 |  | 
|  | |
| 72 | 
             
                    self.tokenizer = SPTokenizer(vocab_file)
         | 
| 73 | 
             
                    self.special_tokens = {
         | 
| 74 | 
             
                        "<bos>": self.tokenizer.bos_id,
         | 
| @@ -82,14 +83,26 @@ class ChatGLMTokenizer(PreTrainedTokenizer): | |
| 82 | 
             
                    assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
         | 
| 83 | 
             
                    return self.tokenizer.special_tokens[token]
         | 
| 84 |  | 
|  | |
|  | |
|  | |
|  | |
| 85 | 
             
                @property
         | 
| 86 | 
             
                def pad_token(self) -> str:
         | 
| 87 | 
            -
                    return " | 
| 88 |  | 
| 89 | 
             
                @property
         | 
| 90 | 
             
                def pad_token_id(self):
         | 
| 91 | 
             
                    return self.get_command("<pad>")
         | 
| 92 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 93 | 
             
                @property
         | 
| 94 | 
             
                def vocab_size(self):
         | 
| 95 | 
             
                    return self.tokenizer.n_words
         | 
| @@ -146,6 +159,15 @@ class ChatGLMTokenizer(PreTrainedTokenizer): | |
| 146 | 
             
                    prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
         | 
| 147 | 
             
                    return prefix_tokens
         | 
| 148 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 149 | 
             
                def build_inputs_with_special_tokens(
         | 
| 150 | 
             
                        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
         | 
| 151 | 
             
                ) -> List[int]:
         | 
| @@ -217,12 +239,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer): | |
| 217 | 
             
                    needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
         | 
| 218 |  | 
| 219 | 
             
                    # Initialize attention mask if not present.
         | 
| 220 | 
            -
                    if  | 
| 221 | 
            -
                         | 
| 222 | 
            -
                            encoded_inputs["attention_mask"] = [1] * seq_length
         | 
| 223 |  | 
| 224 | 
            -
             | 
| 225 | 
            -
             | 
| 226 |  | 
| 227 | 
             
                    if needs_to_be_padded:
         | 
| 228 | 
             
                        difference = max_length - len(required_input)
         | 
|  | |
| 17 | 
             
                    self.n_words: int = self.sp_model.vocab_size()
         | 
| 18 | 
             
                    self.bos_id: int = self.sp_model.bos_id()
         | 
| 19 | 
             
                    self.eos_id: int = self.sp_model.eos_id()
         | 
| 20 | 
            +
                    self.pad_id: int = self.sp_model.unk_id()
         | 
| 21 | 
             
                    assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
         | 
| 22 |  | 
| 23 | 
             
                    special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"]
         | 
|  | |
| 55 |  | 
| 56 | 
             
                def convert_id_to_token(self, index):
         | 
| 57 | 
             
                    """Converts an index (integer) in a token (str) using the vocab."""
         | 
| 58 | 
            +
                    if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
         | 
| 59 | 
             
                        return ""
         | 
| 60 | 
             
                    return self.sp_model.IdToPiece(index)
         | 
| 61 |  | 
|  | |
| 65 |  | 
| 66 | 
             
                model_input_names = ["input_ids", "attention_mask", "position_ids"]
         | 
| 67 |  | 
| 68 | 
            +
                def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
         | 
| 69 | 
            +
                    super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
         | 
| 70 | 
             
                    self.name = "GLMTokenizer"
         | 
| 71 |  | 
| 72 | 
            +
                    self.vocab_file = vocab_file
         | 
| 73 | 
             
                    self.tokenizer = SPTokenizer(vocab_file)
         | 
| 74 | 
             
                    self.special_tokens = {
         | 
| 75 | 
             
                        "<bos>": self.tokenizer.bos_id,
         | 
|  | |
| 83 | 
             
                    assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
         | 
| 84 | 
             
                    return self.tokenizer.special_tokens[token]
         | 
| 85 |  | 
| 86 | 
            +
                @property
         | 
| 87 | 
            +
                def unk_token(self) -> str:
         | 
| 88 | 
            +
                    return "<unk>"
         | 
| 89 | 
            +
             | 
| 90 | 
             
                @property
         | 
| 91 | 
             
                def pad_token(self) -> str:
         | 
| 92 | 
            +
                    return "<unk>"
         | 
| 93 |  | 
| 94 | 
             
                @property
         | 
| 95 | 
             
                def pad_token_id(self):
         | 
| 96 | 
             
                    return self.get_command("<pad>")
         | 
| 97 |  | 
| 98 | 
            +
                @property
         | 
| 99 | 
            +
                def eos_token(self) -> str:
         | 
| 100 | 
            +
                    return "</s>"
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                @property
         | 
| 103 | 
            +
                def eos_token_id(self):
         | 
| 104 | 
            +
                    return self.get_command("<eos>")
         | 
| 105 | 
            +
             | 
| 106 | 
             
                @property
         | 
| 107 | 
             
                def vocab_size(self):
         | 
| 108 | 
             
                    return self.tokenizer.n_words
         | 
|  | |
| 159 | 
             
                    prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
         | 
| 160 | 
             
                    return prefix_tokens
         | 
| 161 |  | 
| 162 | 
            +
                def build_prompt(self, query, history=None):
         | 
| 163 | 
            +
                    if history is None:
         | 
| 164 | 
            +
                        history = []
         | 
| 165 | 
            +
                    prompt = ""
         | 
| 166 | 
            +
                    for i, (old_query, response) in enumerate(history):
         | 
| 167 | 
            +
                        prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
         | 
| 168 | 
            +
                    prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
         | 
| 169 | 
            +
                    return prompt
         | 
| 170 | 
            +
             | 
| 171 | 
             
                def build_inputs_with_special_tokens(
         | 
| 172 | 
             
                        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
         | 
| 173 | 
             
                ) -> List[int]:
         | 
|  | |
| 239 | 
             
                    needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
         | 
| 240 |  | 
| 241 | 
             
                    # Initialize attention mask if not present.
         | 
| 242 | 
            +
                    if "attention_mask" not in encoded_inputs:
         | 
| 243 | 
            +
                        encoded_inputs["attention_mask"] = [1] * seq_length
         | 
|  | |
| 244 |  | 
| 245 | 
            +
                    if "position_ids" not in encoded_inputs:
         | 
| 246 | 
            +
                        encoded_inputs["position_ids"] = list(range(seq_length))
         | 
| 247 |  | 
| 248 | 
             
                    if needs_to_be_padded:
         | 
| 249 | 
             
                        difference = max_length - len(required_input)
         | 
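A brief usage sketch (not from this diff) tying the tokenizer additions together: build_prompt reproduces the "[Round N]" chat template used by the model's chat helpers, and _pad now also fills position_ids, so a padded batch carries all three inputs listed in model_input_names. The checkpoint id is again a placeholder.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)  # placeholder id

    prompt = tokenizer.build_prompt("How do I sort a list in Python?", history=[])
    batch = tokenizer([prompt], return_tensors="pt", padding="max_length", max_length=64)

    # Left padding (padding_side="left"); attention_mask and position_ids are filled in by _pad.
    print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["position_ids"].shape)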

