reference the flash attention GitHub

- bert_padding.py +5 -0
- block.py +5 -0
- embedding.py +5 -0
- mha.py +9 -0
- mlp.py +5 -0
    	
bert_padding.py CHANGED

```diff
@@ -1,5 +1,10 @@
 # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

+"""
+The implementation was further adapted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
+"""
+
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
```
    	
block.py CHANGED

```diff
@@ -1,5 +1,10 @@
 # Copyright (c) 2024, Tri Dao.

+"""
+The implementation was adapted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
+"""
+
 from functools import partial
 from typing import Optional

```
    	
embedding.py CHANGED

```diff
@@ -1,5 +1,10 @@
 # Copyright (c) 2022, Tri Dao.

+"""
+The implementation was adapted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0/flash_attn/models/bert.py
+"""
+
 import torch
 import torch.nn as nn
 from torch import Tensor
```
    	
mha.py CHANGED

```diff
@@ -1,5 +1,14 @@
 # Copyright (c) 2023, Tri Dao.

+"""
+The implementation was adapted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
+with modifications to
+    - support QK normalization
+    - make ALiBi run with MHA (needed to cast alibi slopes to fp32)
+    - make ALiBi run on CPU
+"""
+
 import math
 from functools import partial

```
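The three mha.py modifications listed in the docstring are easiest to see in code. Below is a minimal, hypothetical sketch, not the repo's actual class: names like `QKNormAttention`, `q_norm`, and `k_norm` are illustrative assumptions. QK normalization applies a norm to the query and key heads before the dot product, and the ALiBi bias is built from slopes kept in fp32 on whatever device the scores live on, so the same code also runs on CPU.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def alibi_slopes(nheads: int) -> torch.Tensor:
    # Standard geometric ALiBi slopes, created in fp32 so the bias math
    # keeps full precision even when q/k/v are fp16 or bf16.
    ratio = 2.0 ** (-8.0 / nheads)
    return torch.tensor([ratio ** (i + 1) for i in range(nheads)], dtype=torch.float32)


class QKNormAttention(nn.Module):
    """Illustrative attention core: QK normalization + fp32 ALiBi bias (hypothetical names)."""

    def __init__(self, head_dim: int, nheads: int):
        super().__init__()
        # QK normalization: LayerNorm over the head dimension before the dot product.
        self.q_norm = nn.LayerNorm(head_dim)
        self.k_norm = nn.LayerNorm(head_dim)
        self.register_buffer("alibi_slopes", alibi_slopes(nheads))

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # q, k, v: (batch, nheads, seqlen, head_dim)
        q, k = self.q_norm(q), self.k_norm(k)
        scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / math.sqrt(q.shape[-1])
        # Re-cast the slopes to fp32 in case the module was cast to half precision,
        # and build the bias on the scores' device so this also works on CPU.
        slopes = self.alibi_slopes.float().to(scores.device)
        pos = torch.arange(scores.shape[-1], device=scores.device, dtype=torch.float32)
        # Symmetric (bidirectional, BERT-style) ALiBi distance penalty.
        bias = -slopes.view(1, -1, 1, 1) * (pos.view(1, 1, 1, -1) - pos.view(1, 1, -1, 1)).abs()
        scores = scores + bias.to(scores.dtype)
        return torch.einsum("bhqk,bhkd->bhqd", F.softmax(scores, dim=-1), v)
```

The dense bias tensor here is only to make the fp32 and device handling visible; a fused attention kernel would consume the slopes directly instead of materializing the full seqlen-by-seqlen bias.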
    	
mlp.py CHANGED

```diff
@@ -1,5 +1,10 @@
 # Copyright (c) 2023, Tri Dao.

+"""
+The implementation was adapted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
+"""
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
```