Fix precision error
modeling_chatglm.py (+27, -8)
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -5,7 +5,7 @@ import copy
 import warnings
 import re
 import sys
-
+import functools
 import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
@@ -177,15 +177,21 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten
 
 
 class RMSNorm(torch.nn.Module):
-    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
+    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, quantized=False, **kwargs):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
         self.eps = eps
+        self.quantized = quantized
 
     def forward(self, hidden_states: torch.Tensor):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        if not self.quantized:
+            norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
+            x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
+            return self.weight * x_normed
+        else:
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
 
         return (self.weight * hidden_states).to(input_dtype)
 
@@ -515,10 +521,17 @@ class GLMBlock(torch.nn.Module):
 
         self.fp32_residual_connection = config.fp32_residual_connection
 
-        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+        if config.rmsnorm:
+            if config.quantization_bit != 0:
+                LayerNormFunc = functools.partial(RMSNorm, quantized=True)
+            else:
+                LayerNormFunc = RMSNorm
+        else:
+            LayerNormFunc = LayerNorm
+
         # Layernorm on the input data.
         self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                             dtype=config.torch_dtype)
+                                                 dtype=config.torch_dtype)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
@@ -593,7 +606,13 @@ class GLMTransformer(torch.nn.Module):
         self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
 
         if self.post_layer_norm:
-            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            if config.rmsnorm:
+                if config.quantization_bit != 0:
+                    LayerNormFunc = functools.partial(RMSNorm, quantized=True)
+                else:
+                    LayerNormFunc = RMSNorm
+            else:
+                LayerNormFunc = LayerNorm
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)
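
For readers comparing the two RMSNorm forward paths introduced above, here is a minimal, self-contained sketch (not part of the commit) that reproduces them as free functions. The tensor shape, the float16 dtype, and the final comparison are illustrative assumptions; running float16 reductions on CPU assumes a reasonably recent PyTorch build.

import torch

def rmsnorm_native(x, weight, eps=1e-5):
    # Non-quantized path: the mean of squares stays in the input dtype.
    norm_x = torch.mean(x * x, dim=-1, keepdim=True)
    return weight * (x * torch.rsqrt(norm_x + eps))

def rmsnorm_fp32(x, weight, eps=1e-5):
    # Quantized path: accumulate the variance in float32, then cast the result back.
    input_dtype = x.dtype
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (weight * x).to(input_dtype)

x = torch.randn(2, 4096).to(torch.float16)   # illustrative shape and dtype
w = torch.ones(4096, dtype=torch.float16)
# The two paths agree only up to float16 rounding of the intermediate reduction.
print((rmsnorm_native(x, w) - rmsnorm_fp32(x, w)).abs().max())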
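
The LayerNormFunc selection in GLMBlock and GLMTransformer follows the same pattern, so a short standalone sketch of just that pattern may help. DummyConfig and its two fields are hypothetical stand-ins for the model config; only the functools.partial mechanics mirror the diff.

import functools
import torch

class RMSNorm(torch.nn.Module):
    # Trimmed copy of the class above; only __init__ is needed for this sketch.
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, quantized=False, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps
        self.quantized = quantized

class DummyConfig:          # hypothetical stand-in for the HF config object
    rmsnorm = True
    quantization_bit = 4    # non-zero means a quantized checkpoint

config = DummyConfig()
if config.rmsnorm:
    # functools.partial pins quantized=True, so every layer built from
    # LayerNormFunc takes the float32 accumulation branch shown in the diff.
    LayerNormFunc = functools.partial(RMSNorm, quantized=True) if config.quantization_bit != 0 else RMSNorm
else:
    LayerNormFunc = torch.nn.LayerNorm

norm = LayerNormFunc(8, eps=1e-5, dtype=torch.float16)
print(type(norm).__name__, norm.quantized)   # -> RMSNorm True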

