Commit d11c6aa · 0 Parent(s)

Add chatglm-6b
Files changed:
- .gitattributes +34 -0
- LICENSE +201 -0
- MODEL_LICENSE +33 -0
- README.md +81 -0
- config.json +25 -0
- configuration_chatglm.py +92 -0
- ice_text.model +3 -0
- modeling_chatglm.py +1152 -0
- pytorch_model-00001-of-00008.bin +1 -0
- pytorch_model-00002-of-00008.bin +1 -0
- pytorch_model-00003-of-00008.bin +1 -0
- pytorch_model-00004-of-00008.bin +1 -0
- pytorch_model-00005-of-00008.bin +1 -0
- pytorch_model-00006-of-00008.bin +1 -0
- pytorch_model-00007-of-00008.bin +1 -0
- pytorch_model-00008-of-00008.bin +1 -0
- pytorch_model.bin.index.json +375 -0
- quantization.py +187 -0
- tokenization_chatglm.py +347 -0
- tokenizer_config.json +19 -0
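The checkpoint is split across eight `pytorch_model-*.bin` shards tied together by `pytorch_model.bin.index.json`. As a minimal sketch, assuming the standard Hugging Face sharded-checkpoint layout (an index file with a `weight_map` mapping parameter names to shard filenames, which files of this name conventionally follow), you can inspect which shard holds each tensor:

```python
# Sketch: inspect the shard layout of the checkpoint added in this commit.
# Assumes the standard Hugging Face sharded-checkpoint index format.
import json

with open("pytorch_model.bin.index.json") as f:
    index = json.load(f)

weight_map = index["weight_map"]  # parameter name -> shard filename
shards = sorted(set(weight_map.values()))
print(f"{len(weight_map)} tensors across {len(shards)} shards")
for name, shard in list(weight_map.items())[:3]:
    print(name, "->", shard)
```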
 
    	
.gitattributes ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
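Every pattern above routes matching files through Git LFS, so the large weight files in this commit are stored as small pointer files in git history. A minimal sketch of which of this commit's files those patterns would catch, using Python's `fnmatch` as a rough approximation of gitattributes glob matching (it does not implement `**`, but covers the simple suffix patterns here):

```python
# Sketch: approximate which paths the LFS patterns above would match.
# fnmatch is only an approximation of gitattributes glob semantics.
from fnmatch import fnmatch

lfs_patterns = ["*.bin", "*.model", "*.safetensors", "*tfevents*"]
paths = ["pytorch_model-00001-of-00008.bin", "ice_text.model", "README.md"]

for path in paths:
    routed = any(fnmatch(path, pat) for pat in lfs_patterns)
    print(path, "-> LFS" if routed else "-> plain git")
```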
    	
LICENSE ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright Zhengxiao Du
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
    	
MODEL_LICENSE ADDED
@@ -0,0 +1,33 @@
+The GLM-130B License
+
+1. Definitions
+
+“Licensor” means the GLM-130B Model Team that distributes its Software.
+
+“Software” means the GLM-130B model parameters made available under this license.
+
+2. License Grant
+
+Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes.
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+3. Restriction
+
+You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes.
+
+You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
+
+4. Disclaimer
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+5. Limitation of Liability
+
+EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
+
+6. Dispute Resolution
+
+This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
+
+Note that the license is subject to update to a more comprehensive version.  For any questions related to the license and copyright, please contact us at [email protected].
    	
README.md ADDED
@@ -0,0 +1,81 @@
+---
+language:
+- zh
+- en
+tags:
+- glm
+- chatglm
+- thudm
+---
+# ChatGLM-6B
+## Introduction
+ChatGLM-6B is an open-source pre-trained language model that supports bilingual (Chinese and English) question answering and dialogue. It is based on the [GLM](https://github.com/THUDM/GLM) architecture and has 6.2 billion parameters. ChatGLM-6B uses the same technology as ChatGLM (currently in closed beta at [https://chatglm.cn](https://chatglm.cn)), optimized for Chinese question answering and dialogue.
+
+## Usage
+Before using the model, install `transformers>=4.23.1` and `icetk`:
+
+```shell
+pip install "transformers>=4.23.1" icetk
+```
+
+### Calling from code
+
+You can generate dialogue with the ChatGLM-6B model using the following code:
+
+```python
+from transformers import AutoTokenizer, AutoModel
+
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+model = model.eval()
+
+history = []
+query = "你好"  # "Hello"
+response, history = model.chat(tokenizer, query, history=history)
+print(response)
+
+query = "晚上睡不着应该怎么办"  # "What should I do if I can't sleep at night?"
+response, history = model.chat(tokenizer, query, history=history)
+print(history)
+```
+
+For more usage instructions, including how to run the command-line and web demos, please refer to our [Github repo](https://github.com/THUDM/ChatGLM-6B).
+
+## INT8 Quantization
+By default the model is loaded at FP16 precision, and running the code above takes roughly 13GB of GPU memory. If your GPU memory is limited, you can try the 8-bit quantization provided by `transformers`, i.e. replace the line
+
+```python
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+```
+
+with
+
+```python
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", device_map="auto", load_in_8bit=True, trust_remote_code=True)
+```
+
+With 8-bit quantization, about 9.5GB of GPU memory is needed.
+
+## Citation
+
+If you find our work helpful, please consider citing the following papers:
+
+```
+@inproceedings{
+  zeng2023glm-130b,
+  title={{GLM}-130B: An Open Bilingual Pre-trained Model},
+  author={Aohan Zeng and Xiao Liu and Zhengxiao Du and Zihan Wang and Hanyu Lai and Ming Ding and Zhuoyi Yang and Yifan Xu and Wendi Zheng and Xiao Xia and Weng Lam Tam and Zixuan Ma and Yufei Xue and Jidong Zhai and Wenguang Chen and Zhiyuan Liu and Peng Zhang and Yuxiao Dong and Jie Tang},
+  booktitle={The Eleventh International Conference on Learning Representations (ICLR)},
+  year={2023},
+  url={https://openreview.net/forum?id=-Aw0rrrPUF}
+}
+```
+```
+@inproceedings{du2022glm,
+  title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
+  author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
+  booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
+  pages={320--335},
+  year={2022}
+}
+```
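The README quotes roughly 13GB of GPU memory at FP16 and about 9.5GB with 8-bit quantization. A minimal sketch for checking the actual figure on your own hardware, run right after loading the model either way (assumes a CUDA GPU; exact numbers vary with driver and `transformers` version):

```python
# Sketch: report GPU memory use after AutoModel.from_pretrained(...).
import torch

if torch.cuda.is_available():
    torch.cuda.synchronize()
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"allocated: {allocated:.1f} GiB, reserved: {reserved:.1f} GiB")
else:
    print("No CUDA device; the README's memory figures assume a GPU.")
```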
    	
config.json ADDED
@@ -0,0 +1,25 @@
+{
+  "_name_or_path": "THUDM/chatglm-6b",
+  "architectures": [
+    "ChatGLMModel"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_chatglm.ChatGLMConfig",
+    "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
+    "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration"
+  },
+  "bos_token_id": 150004,
+  "eos_token_id": 150005,
+  "hidden_size": 4096,
+  "inner_hidden_size": 16384,
+  "layernorm_epsilon": 1e-05,
+  "max_sequence_length": 2048,
+  "model_type": "chatglm",
+  "num_attention_heads": 32,
+  "num_layers": 28,
+  "position_encoding_2d": true,
+  "torch_dtype": "float16",
+  "transformers_version": "4.23.1",
+  "use_cache": true,
+  "vocab_size": 150528
+}
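The `auto_map` block is what lets the generic `Auto*` classes resolve to the custom code shipped in this repo: with `trust_remote_code=True`, `AutoConfig` is redirected to `configuration_chatglm.ChatGLMConfig` and `AutoModel` to `modeling_chatglm.ChatGLMForConditionalGeneration`. A minimal sketch of loading just the config this way (assumes Hub access or a local clone of the repo):

```python
# Sketch: load the custom config through auto_map. trust_remote_code=True is
# required because ChatGLMConfig lives in this repo, not in transformers.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
print(type(config).__name__)   # ChatGLMConfig
print(config.model_type)       # "chatglm"
print(config.hidden_size)      # 4096
print(config.num_layers)       # 28
```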
    	
configuration_chatglm.py ADDED
@@ -0,0 +1,92 @@
| 1 | 
         
            +
            """ ChatGLM model configuration """
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            from transformers.configuration_utils import PretrainedConfig
         
     | 
| 4 | 
         
            +
            from transformers.utils import logging
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            logger = logging.get_logger(__name__)
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            class ChatGLMConfig(PretrainedConfig):
         
     | 
| 10 | 
         
            +
                r"""
         
     | 
| 11 | 
         
            +
                This is the configuration class to store the configuration of a [`~ChatGLMModel`].
         
     | 
| 12 | 
         
            +
                It is used to instantiate an ChatGLM model according to the specified arguments, defining the model
         
     | 
| 13 | 
         
            +
                architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
         
     | 
| 14 | 
         
            +
                the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture.
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
                Configuration objects inherit from  [`PretrainedConfig`] and can be used
         
     | 
| 17 | 
         
            +
                to control the model outputs. Read the documentation from  [`PretrainedConfig`]
         
     | 
| 18 | 
         
            +
                for more information.
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
                Args:
         
     | 
| 22 | 
         
            +
                    vocab_size (`int`, *optional*, defaults to 150528):
         
     | 
| 23 | 
         
            +
                        Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the
         
     | 
| 24 | 
         
            +
                        `inputs_ids` passed when calling [`~ChatGLMModel`] or
         
     | 
| 25 | 
         
            +
                        [`~TFChatGLMModel`].
         
     | 
| 26 | 
         
            +
                    hidden_size (`int`, *optional*, defaults to 4096):
         
     | 
| 27 | 
         
            +
                        Dimension of the encoder layers and the pooler layer.
         
     | 
| 28 | 
         
            +
                    num_hidden_layers (`int`, *optional*, defaults to 28):
         
     | 
| 29 | 
         
            +
                        Number of hidden layers in the Transformer encoder.
         
     | 
| 30 | 
         
            +
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        inner_hidden_size (`int`, *optional*, defaults to 16384):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        max_sequence_length (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
            Typically set this to something large just in case (e.g., 512, 1024, or 2048).
        layernorm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether the model should return the last key/values attentions (not used by all models).
    Example:

    ```python
    >>> from configuration_chatglm import ChatGLMConfig
    >>> from modeling_chatglm import ChatGLMModel

    >>> # Initializing a ChatGLM-6B configuration in the THUDM/ChatGLM-6B style
    >>> configuration = ChatGLMConfig()

    >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration
    >>> model = ChatGLMModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """
    model_type = "chatglm"

    def __init__(
            self,
            vocab_size=150528,
            hidden_size=4096,
            num_layers=28,
            num_attention_heads=32,
            layernorm_epsilon=1e-5,
            use_cache=False,
            bos_token_id=150004,
            eos_token_id=150005,
            pad_token_id=0,
            max_sequence_length=2048,
            inner_hidden_size=16384,
            position_encoding_2d=True,
            **kwargs
    ):
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.max_sequence_length = max_sequence_length
        self.layernorm_epsilon = layernorm_epsilon
        self.inner_hidden_size = inner_hidden_size
        self.use_cache = use_cache
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.position_encoding_2d = position_encoding_2d
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )
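# Editorial note: ChatGLMConfig subclasses transformers.PretrainedConfig (the
# super().__init__ call above forwards the special token ids and any extra
# **kwargs to it), so the defaults can be overridden at construction time, e.g.
#     config = ChatGLMConfig(max_sequence_length=1024, use_cache=True)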
    	
        ice_text.model
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:99871e0c85db81ad7af1028854fd091cd5778c8414ae9d94bbbc10d02c831c21
size 2699926
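The weights of this file are stored in Git LFS; the repository keeps only this pointer, which records the LFS spec version, the SHA-256 object id, and the size in bytes (2,699,926 bytes, about 2.7 MB, for what appears to be the SentencePiece tokenizer model).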
    	
        modeling_chatglm.py
    ADDED
    
@@ -0,0 +1,1152 @@
""" PyTorch ChatGLM model. """

import math
import copy
import os

import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List

from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions,
)
from transformers.modeling_utils import PreTrainedModel

from transformers.utils import logging
from .configuration_chatglm import ChatGLMConfig

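# Editorial note: torch.nn.utils.skip_init (available since PyTorch 1.10) creates a
# module while skipping its default parameter initialization; it is used for the
# Linear layers below so they can be allocated quickly and then filled in from the
# checkpoint weights.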
# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B"
_CONFIG_FOR_DOC = "ChatGLM6BConfig"

CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "THUDM/chatglm-6b",
    # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
]

def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(
                n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
                for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert (
                    pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model


@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))


def gelu(x):
    return gelu_impl(x)


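# Editorial note: 0.7978845608028654 is sqrt(2 / pi), so gelu_impl computes the
# tanh approximation gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))),
# which matches torch.nn.functional.gelu(x, approximate="tanh") in recent PyTorch
# versions up to floating-point error.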
class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from the paper, but it uses a different permutation
            # in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]


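# Editorial note: for rotary dimension `dim`, inv_freq holds the dim // 2 frequencies
# theta_i = base ** (-2 * i / dim); forward() returns cos/sin caches of shape
# [seq_len, 1, dim], with each frequency duplicated across the two halves of the
# last dimension by the torch.cat((freqs, freqs), dim=-1) above.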
def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions


@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
        F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
    return q, k


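# Editorial note: rotate_half pairs dimension i with dimension i + hn // 2, so
# q * cos + rotate_half(q) * sin realizes, per position p and frequency theta_i,
# the planar rotation
#     q_i'           = q_i * cos(p * theta_i) - q_{i + hn//2} * sin(p * theta_i)
#     q_{i + hn//2}' = q_{i + hn//2} * cos(p * theta_i) + q_i * sin(p * theta_i)
# F.embedding is used here only as a gather, selecting the cos/sin rows for the
# given (possibly non-monotonic) position ids.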
def attention_fn(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        hidden_size_per_partition,
        layer_id,
        layer_past=None,
        scaling_attention_score=True,
        use_cache=False,
):
    if layer_past is not None:
        past_key, past_value = layer_past
        key_layer = torch.cat((past_key, key_layer), dim=0)
        value_layer = torch.cat((past_value, value_layer), dim=0)

    # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
    seq_len, b, nh, hidden_size = key_layer.shape

    if use_cache:
        present = (key_layer, value_layer)
    else:
        present = None

    query_key_layer_scaling_coeff = float(layer_id + 1)
    if scaling_attention_score:
        query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)

    # ===================================
    # Raw attention scores. [b, np, s, s]
    # ===================================

    # [b, np, sq, sk]
    output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

    # [sq, b, np, hn] -> [sq, b * np, hn]
    query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
    # [sk, b, np, hn] -> [sk, b * np, hn]
    key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

    matmul_result = torch.empty(
        output_size[0] * output_size[1],
        output_size[2],
        output_size[3],
        dtype=query_layer.dtype,
        device=query_layer.device,
    )

    matmul_result = torch.baddbmm(
        matmul_result,
        query_layer.transpose(0, 1),  # [b * np, sq, hn]
        key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
        beta=0.0,
        alpha=1.0,
    )

    # change view to [b, np, sq, sk]
    attention_scores = matmul_result.view(*output_size)

    if self.scale_mask_softmax:
        self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())
    else:
        if not (attention_mask == 0).all():
            # if auto-regressive, skip
            attention_scores.masked_fill_(attention_mask, -10000.0)

        attention_scores = attention_scores.float()
        attention_scores = attention_scores * query_key_layer_scaling_coeff

        attention_probs = F.softmax(attention_scores, dim=-1)

        attention_probs = attention_probs.half()

    # =========================
    # Context layer. [sq, b, hp]
    # =========================

    # value_layer -> context layer.
    # [sk, b, np, hn] --> [b, np, sq, hn]

    # context layer shape: [b, np, sq, hn]
    output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))

    # change view [sk, b * np, hn]
    value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

    # change view [b * np, sq, sk]
    attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

    # matmul: [b * np, sq, hn]
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

    # change view [b, np, sq, hn]
    context_layer = context_layer.view(*output_size)

    # [b, np, sq, hn] --> [sq, b, np, hn]
    context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

    # [sq, b, np, hn] --> [sq, b, hp]
    new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
    context_layer = context_layer.view(*new_context_layer_shape)

    outputs = (context_layer, present, attention_probs)

    return outputs

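# Editorial note: query_key_layer_scaling_coeff implements per-layer scaling for
# fp16 stability: the queries are pre-divided by (sqrt(hn) * (layer_id + 1)) above,
# and the scores are multiplied back by (layer_id + 1) in float32 just before the
# softmax, so the overall scale is the usual 1 / sqrt(hn) while the half-precision
# matmul sees smaller magnitudes.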

class SelfAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_attention_heads,
                 layer_id, hidden_size_per_attention_head=None, bias=True,
                 params_dtype=torch.float, position_encoding_2d=True):
        super(SelfAttention, self).__init__()

        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.hidden_size_per_partition = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_heads_per_partition = num_attention_heads
        self.position_encoding_2d = position_encoding_2d
        self.rotary_emb = RotaryEmbedding(
            self.hidden_size // (self.num_attention_heads * 2)
            if position_encoding_2d
            else self.hidden_size // self.num_attention_heads,
            base=10000,
            precision=torch.half,
            learnable=False,
        )

        self.scale_mask_softmax = None

        if hidden_size_per_attention_head is None:
            self.hidden_size_per_attention_head = hidden_size // num_attention_heads
        else:
            self.hidden_size_per_attention_head = hidden_size_per_attention_head

        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head

        # Strided linear layer.
        self.query_key_value = skip_init(
            torch.nn.Linear,
            hidden_size,
            3 * self.inner_hidden_size,
            bias=bias,
            dtype=params_dtype,
        )

        self.dense = skip_init(
            torch.nn.Linear,
            self.inner_hidden_size,
            hidden_size,
            bias=bias,
            dtype=params_dtype,
        )

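    # Editorial note (an assumption based on how position_ids are used later in
    # this file): with position_encoding_2d=True the rotary dimension is halved
    # because ChatGLM-6B applies two independent rotary encodings, one for the
    # block position and one for the position within the block, to the two halves
    # of each attention head.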
    @staticmethod
    def attention_mask_func(attention_scores, attention_mask):
        attention_scores.masked_fill_(attention_mask, -10000.0)
        return attention_scores

    def split_tensor_along_last_dim(self, tensor, num_partitions,
                                    contiguous_split_chunks=False):
        """Split a tensor along its last dimension.
        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor
            contiguous_split_chunks: If True, make each chunk contiguous
                                     in memory.
        """
        # Get the size and dimension.
        last_dim = tensor.dim() - 1
        last_dim_size = tensor.size()[last_dim] // num_partitions
        # Split.
        tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
        # Note: torch.split does not create contiguous tensors by default.
        if contiguous_split_chunks:
            return tuple(chunk.contiguous() for chunk in tensor_list)

        return tensor_list

    def forward(
            self,
            hidden_states: torch.Tensor,
            position_ids,
            attention_mask: torch.Tensor,
            layer_id,
            layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            use_cache: bool = False,
            output_attentions: bool = False,
    ):
        """
        hidden_states: [seq_len, batch, hidden_size]
        attention_mask: [(1, 1), seq_len, seq_len]
        """

        # [seq_len, batch, 3 * hidden_size]
        mixed_raw_layer = self.query_key_value(hidden_states)
         
     | 
| 388 | 
         
            +
             
     | 
| 389 | 
         
            +
                    # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head]
         
     | 
| 390 | 
         
            +
                    new_tensor_shape = mixed_raw_layer.size()[:-1] + (
         
     | 
| 391 | 
         
            +
                        self.num_attention_heads_per_partition,
         
     | 
| 392 | 
         
            +
                        3 * self.hidden_size_per_attention_head,
         
     | 
| 393 | 
         
            +
                    )
         
     | 
| 394 | 
         
            +
                    mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)
         
     | 
| 395 | 
         
            +
             
     | 
| 396 | 
         
            +
                    # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
         
     | 
| 397 | 
         
            +
                    (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3)
         
     | 
| 398 | 
         
            +
             
     | 
| 399 | 
         
            +
                    if self.position_encoding_2d:
         
     | 
| 400 | 
         
            +
                        q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
         
     | 
| 401 | 
         
            +
                        k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
         
     | 
| 402 | 
         
            +
                        cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
         
     | 
| 403 | 
         
            +
                        position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
         
     | 
| 404 | 
         
            +
                            position_ids[:, 1, :].transpose(0, 1).contiguous()
         
     | 
| 405 | 
         
            +
                        q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
         
     | 
| 406 | 
         
            +
                        q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
         
     | 
| 407 | 
         
            +
                        query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
         
     | 
| 408 | 
         
            +
                        key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
         
     | 
| 409 | 
         
            +
                    else:
         
     | 
| 410 | 
         
            +
                        position_ids = position_ids.transpose(0, 1)
         
     | 
| 411 | 
         
            +
                        cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
         
     | 
| 412 | 
         
            +
                        # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
         
     | 
| 413 | 
         
            +
                        query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids)
         
     | 
| 414 | 
         
            +
             
     | 
| 415 | 
         
            +
                    # [seq_len, batch, hidden_size]
         
     | 
| 416 | 
         
            +
                    context_layer, present, attention_probs = attention_fn(
         
     | 
| 417 | 
         
            +
                        self=self,
         
     | 
| 418 | 
         
            +
                        query_layer=query_layer,
         
     | 
| 419 | 
         
            +
                        key_layer=key_layer,
         
     | 
| 420 | 
         
            +
                        value_layer=value_layer,
         
     | 
| 421 | 
         
            +
                        attention_mask=attention_mask,
         
     | 
| 422 | 
         
            +
                        hidden_size_per_partition=self.hidden_size_per_partition,
         
     | 
| 423 | 
         
            +
                        layer_id=layer_id,
         
     | 
| 424 | 
         
            +
                        layer_past=layer_past,
         
     | 
| 425 | 
         
            +
                        use_cache=use_cache
         
     | 
| 426 | 
         
            +
                    )
         
     | 
| 427 | 
         
            +
             
     | 
| 428 | 
         
            +
                    output = self.dense(context_layer)
         
     | 
| 429 | 
         
            +
             
     | 
| 430 | 
         
            +
                    outputs = (output, present)
         
     | 
| 431 | 
         
            +
             
     | 
| 432 | 
         
            +
                    if output_attentions:
         
     | 
| 433 | 
         
            +
                        outputs += (attention_probs,)
         
     | 
| 434 | 
         
            +
             
     | 
| 435 | 
         
            +
                    return outputs  # output, present, attention_probs
         
     | 
| 436 | 
         
            +
             
     | 
| 437 | 
         
            +
             
     | 
| 438 | 
         
            +
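To see the 2D rotary bookkeeping above with concrete numbers, here is a minimal sketch (not part of the model file) assuming the ChatGLM-6B sizes hidden_size=4096 and num_attention_heads=32 from config.json:

import torch

hidden_size, num_heads, seq_len, batch = 4096, 32, 5, 1
head_dim = hidden_size // num_heads             # 128 channels per head
rotary_dim = hidden_size // (num_heads * 2)     # 64: rotary table width in the 2D case

q = torch.randn(seq_len, batch, num_heads, head_dim)
q1, q2 = q.chunk(2, dim=-1)                     # the two 64-channel halves used in forward()
assert q1.shape[-1] == rotary_dim == q2.shape[-1]

# q1 is rotated with the absolute-position stream, q2 with the block-position
# stream; concatenating the halves restores the original head width.
assert torch.cat([q1, q2], dim=-1).shape == q.shape

This is why RotaryEmbedding is constructed with hidden_size // (num_attention_heads * 2) channels when position_encoding_2d is set.
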
class GEGLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.activation_fn = F.gelu

    def forward(self, x):
        # dim=-1 breaks in jit for pt<1.10
        x1, x2 = x.chunk(2, dim=(x.ndim - 1))
        return x1 * self.activation_fn(x2)

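A quick sketch of what GEGLU computes, using nothing beyond plain PyTorch: the last dimension is split in half and the first half is gated by GELU of the second, so the output is half as wide as the input.

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)            # last dimension must be even
x1, x2 = x.chunk(2, dim=-1)      # two [2, 4] halves
out = x1 * F.gelu(x2)            # gate: output shape is [2, 4]
assert out.shape == (2, 4)
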
class GLU(torch.nn.Module):
    def __init__(self, hidden_size, inner_hidden_size=None,
                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float):
        super(GLU, self).__init__()
        self.layer_id = layer_id
        self.activation_func = activation_func

        # Project to 4h.
        self.hidden_size = hidden_size
        if inner_hidden_size is None:
            inner_hidden_size = 4 * hidden_size
        self.inner_hidden_size = inner_hidden_size
        self.dense_h_to_4h = skip_init(
            torch.nn.Linear,
            self.hidden_size,
            self.inner_hidden_size,
            bias=bias,
            dtype=params_dtype,
        )
        # Project back to h.
        self.dense_4h_to_h = skip_init(
            torch.nn.Linear,
            self.inner_hidden_size,
            self.hidden_size,
            bias=bias,
            dtype=params_dtype,
        )

    def forward(self, hidden_states):
        """
        hidden_states: [seq_len, batch, hidden_size]
        """

        # [seq_len, batch, inner_hidden_size]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)

        intermediate_parallel = self.activation_func(intermediate_parallel)

        output = self.dense_4h_to_h(intermediate_parallel)

        return output

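Despite its name, GLU above is a plain two-layer feed-forward block (h -> 4h -> GELU -> h) with no multiplicative gate. A standalone sketch of the shape flow, with skip_init dropped and PyTorch's stock GELU standing in for the module's activation_func:

import torch

hidden_size = 4096
inner_hidden_size = 4 * hidden_size     # the default when the config does not set it

dense_h_to_4h = torch.nn.Linear(hidden_size, inner_hidden_size)
dense_4h_to_h = torch.nn.Linear(inner_hidden_size, hidden_size)

x = torch.randn(5, 1, hidden_size)                  # [seq_len, batch, hidden_size]
h = torch.nn.functional.gelu(dense_h_to_4h(x))      # [5, 1, 16384]
y = dense_4h_to_h(h)                                # back to [5, 1, 4096]
assert y.shape == x.shape
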
class GLMBlock(torch.nn.Module):
    def __init__(
            self,
            hidden_size,
            num_attention_heads,
            layernorm_epsilon,
            layer_id,
            inner_hidden_size=None,
            hidden_size_per_attention_head=None,
            layernorm=LayerNorm,
            use_bias=True,
            params_dtype=torch.float,
            num_layers=28,
            position_encoding_2d=True
    ):
        super(GLMBlock, self).__init__()

        self.layer_id = layer_id

        # Layernorm on the input data.
        self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

        self.position_encoding_2d = position_encoding_2d

        # Self attention.
        self.attention = SelfAttention(
            hidden_size,
            num_attention_heads,
            layer_id,
            hidden_size_per_attention_head=hidden_size_per_attention_head,
            bias=use_bias,
            params_dtype=params_dtype,
            position_encoding_2d=self.position_encoding_2d
        )

        # Layernorm after the self attention.
        self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

        self.num_layers = num_layers

        # GLU
        self.mlp = GLU(
            hidden_size,
            inner_hidden_size=inner_hidden_size,
            bias=use_bias,
            layer_id=layer_id,
            params_dtype=params_dtype,
        )

    def forward(
            self,
            hidden_states: torch.Tensor,
            position_ids,
            attention_mask: torch.Tensor,
            layer_id,
            layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            use_cache: bool = False,
            output_attentions: bool = False,
    ):
        """
        hidden_states: [seq_len, batch, hidden_size]
        attention_mask: [(1, 1), seq_len, seq_len]
        """

        # Layer norm at the beginning of the transformer layer.
        # [seq_len, batch, hidden_size]
        attention_input = self.input_layernorm(hidden_states)

        # Self attention.
        attention_outputs = self.attention(
            attention_input,
            position_ids,
            attention_mask=attention_mask,
            layer_id=layer_id,
            layer_past=layer_past,
            use_cache=use_cache,
            output_attentions=output_attentions
        )

        attention_output = attention_outputs[0]

        outputs = attention_outputs[1:]

        # Residual connection, with the post-layernorm input scaled by alpha.
        alpha = (2 * self.num_layers) ** 0.5
        hidden_states = attention_input * alpha + attention_output

        mlp_input = self.post_attention_layernorm(hidden_states)

        # MLP.
        mlp_output = self.mlp(mlp_input)

        # Second residual connection.
        output = mlp_input * alpha + mlp_output

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions

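Note that the residual pattern in GLMBlock.forward scales the post-layernorm input, not the raw residual stream, by alpha = sqrt(2 * num_layers). A minimal sketch with the 28-layer default from the signature above; the zero tensor is a hypothetical stand-in for the attention output:

import torch

num_layers = 28
alpha = (2 * num_layers) ** 0.5                         # ~7.48 for the 28-layer model

ln = torch.nn.LayerNorm(8)
x = torch.randn(3, 1, 8)                                # [seq_len, batch, hidden]
attention_input = ln(x)
attention_output = torch.zeros_like(x)                  # stand-in for self.attention(...)
hidden = attention_input * alpha + attention_output     # first residual in GLMBlock.forward
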
class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = True
    supports_gradient_checkpointing = False
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    # Must name the block class defined above so that automatic device
    # placement never splits a transformer layer across devices.
    _no_split_modules = ["GLMBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

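Because _init_weights is a no-op, the parameters only become meaningful once a checkpoint is loaded through from_pretrained. A typical loading sketch, mirroring the usage shown in this repo's README (trust_remote_code is needed since the modeling code ships with the checkpoint):

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
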
CHATGLM_6B_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config ([`~ChatGLMConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

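As the docstring notes, building the model from a configuration alone leaves the weights randomly initialized; a sketch (the from_config path here is an assumption — from_pretrained is the documented route):

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_config(config, trust_remote_code=True)   # architecture only, no weights
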
CHATGLM_6B_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`ChatGLMTokenizer`].
            See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range `[0, config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert *input_ids* indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.",
    CHATGLM_6B_START_DOCSTRING,
)
class ChatGLMModel(ChatGLMPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well
    as a decoder, in which case a layer of cross-attention is added between
    the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the
    `is_decoder` argument of the configuration set to `True`.
    To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder`
    argument and `add_cross_attention` set to `True`; an
    `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config: ChatGLMConfig):
        super().__init__(config)

        # recording parameters
        self.max_sequence_length = config.max_sequence_length
        self.hidden_size = config.hidden_size
        self.params_dtype = torch.half
        self.num_attention_heads = config.num_attention_heads
        self.vocab_size = config.vocab_size
        self.num_layers = config.num_layers
        self.layernorm_epsilon = config.layernorm_epsilon
        self.inner_hidden_size = config.inner_hidden_size
        self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
        self.position_encoding_2d = config.position_encoding_2d

        self.word_embeddings = skip_init(
            torch.nn.Embedding,
            num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
            dtype=self.params_dtype
        )

        def get_layer(layer_id):
            return GLMBlock(
                self.hidden_size,
                self.num_attention_heads,
                self.layernorm_epsilon,
                layer_id,
                inner_hidden_size=self.inner_hidden_size,
                hidden_size_per_attention_head=self.hidden_size_per_attention_head,
                layernorm=LayerNorm,
                use_bias=True,
                params_dtype=self.params_dtype,
                position_encoding_2d=self.position_encoding_2d,
            )

        self.layers = torch.nn.ModuleList(
            [get_layer(layer_id) for layer_id in range(self.num_layers)]
        )

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings

    @staticmethod
    def get_masks(seq, device):
        # 150004 is the <bos> token id (config.json: bos_token_id); everything
        # up to and including it is treated as bidirectional context.
        context_length = seq.index(150004) + 1

        attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
        attention_mask.tril_()
        attention_mask[..., :context_length - 1] = 1
        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()

        return attention_mask

    def get_position_ids(self, seq, mask_position, device, gmask=False):
        context_length = seq.index(150004) + 1
        if self.position_encoding_2d:
            seq_length = seq.index(150004)
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                # For [MASK] spans, positions after the context collapse onto
                # the mask slot.
                position_ids[seq_length:] = mask_position
            block_position_ids = torch.cat((
                torch.zeros(seq_length, dtype=torch.long, device=device),
                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
            ))
            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
        else:
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[context_length - 1:] = mask_position

        position_ids = position_ids.unsqueeze(0)

        return position_ids

    @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.layers))

            MASK, gMASK = 150000, 150001
            mask_token = MASK if MASK in input_ids else gMASK
            use_gmask = False if MASK in input_ids else True
            seq = input_ids[0].tolist()

            mask_position = seq.index(mask_token)

            if attention_mask is None:
                attention_mask = self.get_masks(
                    seq=seq,
                    device=input_ids.device
                )

            if position_ids is None:
                position_ids = self.get_position_ids(
                    seq=seq,
                    mask_position=mask_position,
                    device=input_ids.device,
                    gmask=use_gmask
                )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # [seq_len, batch, hidden_size]
        hidden_states = inputs_embeds.transpose(0, 1)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[0]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
        else:
            attention_mask = attention_mask.to(input_ids.device)

        for i, layer in enumerate(self.layers):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_ret = layer(
                hidden_states,
                position_ids=position_ids,
                attention_mask=attention_mask,
                layer_id=torch.tensor(i),
                layer_past=past_key_values[i],
                use_cache=use_cache,
                output_attentions=output_attentions
            )

            hidden_states = layer_ret[0]

            if use_cache:
                presents = presents + (layer_ret[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)

        # Final layer norm.
        hidden_states = self.final_layernorm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

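To make the mask and position-id construction in ChatGLMModel concrete, here is a self-contained sketch of the first forward pass of a gMASK prompt, using the special ids from this repo's config.json (150000 = [MASK], 150001 = [gMASK], 150004 = <bos>) and toy token ids 5 and 6:

import torch

MASK, gMASK, BOS = 150000, 150001, 150004
seq = [5, 6, gMASK, BOS]                        # prompt just before generation starts
context_length = seq.index(BOS) + 1             # 4

# get_masks: causal mask, then make the prefix before <bos> fully visible;
# True marks positions that must not be attended.
attention_mask = torch.ones(1, len(seq), len(seq)).tril_()
attention_mask[..., :context_length - 1] = 1
attention_mask = (attention_mask.unsqueeze(1) < 0.5)

# get_position_ids with gmask=True: absolute positions plus a block stream
# that starts counting from 1 at <bos>.
seq_length = seq.index(BOS)                     # 3
position_ids = torch.arange(context_length, dtype=torch.long)
block_position_ids = torch.cat((
    torch.zeros(seq_length, dtype=torch.long),
    torch.arange(context_length - seq_length, dtype=torch.long) + 1,
))
position_ids = torch.stack((position_ids, block_position_ids), dim=0).unsqueeze(0)
# position_ids[0, 0] -> tensor([0, 1, 2, 3]); position_ids[0, 1] -> tensor([0, 0, 0, 1])
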
            class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         
     | 
| 899 | 
         
            +
                def __init__(self, config):
         
     | 
| 900 | 
         
            +
                    super().__init__(config)
         
     | 
| 901 | 
         
            +
             
     | 
| 902 | 
         
            +
                    # self.hidden_size = config.hidden_size
         
     | 
| 903 | 
         
            +
                    # self.params_dtype = torch.half
         
     | 
| 904 | 
         
            +
                    # self.vocab_size = config.vocab_size
         
     | 
| 905 | 
         
            +
                    self.max_sequence_length = config.max_sequence_length
         
     | 
| 906 | 
         
            +
             
     | 
| 907 | 
         
            +
                    self.position_encoding_2d = config.position_encoding_2d
         
     | 
| 908 | 
         
            +
             
     | 
| 909 | 
         
            +
                    self.transformer = ChatGLMModel(config)
         
     | 
| 910 | 
         
            +
             
     | 
| 911 | 
         
            +
                    self.lm_head = skip_init(
         
     | 
| 912 | 
         
            +
                        nn.Linear,
         
     | 
| 913 | 
         
            +
                        config.hidden_size,
         
     | 
| 914 | 
         
            +
                        config.vocab_size,
         
     | 
| 915 | 
         
            +
                        bias=False,
         
     | 
| 916 | 
         
            +
                        dtype=torch.half
         
     | 
| 917 | 
         
            +
                    )
         
     | 
| 918 | 
         
            +
             
     | 
| 919 | 
         
            +
                def get_output_embeddings(self):
         
     | 
| 920 | 
         
            +
                    return self.lm_head
         
     | 
| 921 | 
         
            +
             
     | 
| 922 | 
         
            +
                def set_output_embeddings(self, new_embeddings):
         
     | 
| 923 | 
         
            +
                    self.lm_head = new_embeddings
         
     | 
| 924 | 
         
            +
             
     | 
| 925 | 
         
            +
                def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
         
     | 
| 926 | 
         
            +
                    attention_mask = torch.ones((1, context_length, context_length), device=device)
         
     | 
| 927 | 
         
            +
                    attention_mask.tril_()
         
     | 
| 928 | 
         
            +
                    attention_mask[..., :context_length - 1] = 1
         
     | 
| 929 | 
         
            +
                    attention_mask.unsqueeze_(1)
         
     | 
| 930 | 
         
            +
                    attention_mask = (attention_mask < 0.5).bool()
         
     | 
| 931 | 
         
            +
             
     | 
| 932 | 
         
            +
                    if self.position_encoding_2d:
         
     | 
| 933 | 
         
            +
                        seq_length = seq.index(150004)
         
     | 
| 934 | 
         
            +
                        position_ids = torch.arange(context_length, dtype=torch.long, device=device)
         
     | 
| 935 | 
         
            +
                        if not gmask:
         
     | 
| 936 | 
         
            +
                            position_ids[seq_length:] = mask_position
         
     | 
| 937 | 
         
            +
                        block_position_ids = torch.cat((
         
     | 
| 938 | 
         
            +
                            torch.zeros(seq_length, dtype=torch.long, device=device),
         
     | 
| 939 | 
         
            +
                            torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
         
     | 
| 940 | 
         
            +
                        ))
         
     | 
| 941 | 
         
            +
                        position_ids = torch.stack((position_ids, block_position_ids), dim=0)
         
     | 
| 942 | 
         
            +
                    else:
         
     | 
| 943 | 
         
            +
                        position_ids = torch.arange(context_length, dtype=torch.long, device=device)
         
     | 
| 944 | 
         
            +
                        if not gmask:
         
     | 
| 945 | 
         
            +
                            position_ids[context_length - 1:] = mask_position
         
     | 
| 946 | 
         
            +
             
     | 
| 947 | 
         
            +
                    position_ids = position_ids.unsqueeze(0)
         
     | 
| 948 | 
         
            +
             
     | 
| 949 | 
         
            +
                    return attention_mask, position_ids
         
     | 
| 950 | 
         
            +
             
     | 
| 951 | 
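    # Illustrative example for the 2D case above: a prompt tokenized as
    #   [q1, q2, gMASK, <bos>]        (<bos> = 150004 at index 3)
    # with gmask=True yields
    #   position_ids       = [0, 1, 2, 3]
    #   block_position_ids = [0, 0, 0, 1]
    # stacked and unsqueezed to shape (1, 2, 4): the first row indexes the
    # prompt, the second counts positions inside the generated block.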
         
    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            **kwargs
    ) -> dict:

        MASK, gMASK = 150000, 150001
        mask_token = MASK if MASK in input_ids else gMASK
        use_gmask = False if MASK in input_ids else True  # boolean flag, not the token id
        seq = input_ids[0].tolist()

        if mask_token not in seq:
            raise ValueError("You have to add either [MASK] or [gMASK] in your input")
        mask_position = seq.index(mask_token)

        # only last token for input_ids if past is not None
        if past:
            context_length = seq.index(150004)
            last_token = input_ids[:, -1].unsqueeze(-1)
            if self.position_encoding_2d:
                position_ids = torch.tensor(
                    [[[mask_position], [len(seq) - context_length]]],
                    dtype=torch.long, device=input_ids.device
                )
            else:
                position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)

            return {
                "input_ids": last_token,
                "past_key_values": past,
                "position_ids": position_ids,
            }
        else:
            attention_mask, position_ids = self.get_masks_and_position_ids(
                seq=seq,
                mask_position=mask_position,
                context_length=len(seq),
                device=input_ids.device,
                gmask=use_gmask
            )

            return {
                "input_ids": input_ids,
                "past_key_values": past,
                "position_ids": position_ids,
                "attention_mask": attention_mask
            }
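    # The two return shapes above (illustrative): the first call, with no
    # cache, feeds the whole prompt plus its mask and position ids; every
    # later call feeds only the newest token, whose 2D position is
    # [[mask_position], [block index of the next token]], so the cached
    # key/values cover everything else.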
         
    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]

        # hidden_states are sequence-first; permute to batch-first logits
        lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous()

        loss = None
        if labels is not None:
            # compute the loss in fp32 for numerical stability
            lm_logits = lm_logits.to(torch.float32)

            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
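    # Illustrative shapes for forward(): the transformer returns hidden_states
    # as [seq_len, batch, hidden]; after lm_head and permute(1, 0, 2) the
    # logits are [batch, seq_len, vocab]. With labels, logits[:, :-1] are
    # scored against labels[:, 1:] (next-token prediction).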
         
    @staticmethod
    def _reorder_cache(
            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )
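    # e.g. with beam_idx = tensor([1, 0]) the cached keys/values of every
    # layer (laid out [seq_len, batch * beams, ...]) are swapped along dim 1
    # so the cache follows the surviving beams.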
         
    @torch.no_grad()
    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], max_length: int = 2048, num_beams=1,
             do_sample=True, top_p=0.7, temperature=0.95, **kwargs):
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, **kwargs}
        if not history:
            prompt = query
        else:
            # Multi-round prompt format; 问/答 are "Q"/"A" in Chinese and are
            # part of the format the model was trained on.
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
        input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
        input_ids = input_ids.to(self.device)
        outputs = self.generate(**input_ids, **gen_kwargs)
        # drop the prompt, offset by its two trailing special tokens ([gMASK], <bos>)
        outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]) - 2:]
        response = tokenizer.decode(outputs)
        response = response.strip()
        # "[[训练时间]]" ("[[training time]]") is a placeholder emitted by the
        # model; render it as "2023年" (the year 2023)
        response = response.replace("[[训练时间]]", "2023年")
        history.append((query, response))
        return response, history
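    # Example prompt built by chat() for one turn of history (illustrative):
    #   [Round 0]
    #   问:Hi
    #   答:Hello!
    #   [Round 1]
    #   问:How are you?
    #   答: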
         
    @torch.no_grad()
    def generate(
            self,
            **kwargs,
    ):
        MASK, gMASK = 150000, 150001
        bos, eos = 150004, 150005

        if "eos_token_id" not in kwargs:
            kwargs["eos_token_id"] = eos

        stop = False

        return_seqs = []

        while True:
            output_ids = super().generate(**kwargs)

            return_seqs = []
            max_length = 0

            for i in range(output_ids.shape[0]):
                output_seq = output_ids[i].tolist()
                mask_token = MASK if MASK in output_seq else gMASK
                mask_position = output_seq.index(mask_token)
                bos_position = output_seq.index(bos)
                if eos in output_seq:
                    eos_position = output_seq.index(eos)
                else:
                    eos_position = len(output_seq)

                # GLM blank infilling: splice the span generated after <bos>
                # back into the position of the mask token
                return_seq = (output_seq[:mask_position]
                              + output_seq[bos_position + 1:eos_position]
                              + output_seq[mask_position + 1:bos_position])
                max_length = max(max_length, len(return_seq))
                return_seqs.append(return_seq)

            for i in range(output_ids.shape[0]):
                return_seqs[i] = [0] * (max_length - len(return_seqs[i])) + return_seqs[i]  # left-pad to max_length
                if mask_token not in return_seqs[i]:
                    stop = True

            if stop:
                break

            # some mask tokens are still unfilled: append <bos> and generate again
            for return_seq in return_seqs:
                return_seq += [bos]

            kwargs['input_ids'] = torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)

        return torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)

    def quantize(self, bits: int):
        from .quantization import quantize
        self.transformer = quantize(self.transformer, bits)
        return self
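A minimal usage sketch for the class above (assumptions: the repo is consumed through transformers with trust_remote_code, and "THUDM/chatglm-6b" stands in for the actual repo id):

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
# optional INT4/INT8 weight quantization via quantize() defined above
# model = model.quantize(4)
response, history = model.chat(tokenizer, "Hello", history=[])
print(response)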
    	
pytorch_model-00001-of-00008.bin
ADDED
/mnt/vepfs/zxdu/checkpoints/qa-glm-6b-sft-v0.8-v2-original-lr/pytorch_model-00001-of-00008.bin

pytorch_model-00002-of-00008.bin
ADDED
/mnt/vepfs/zxdu/checkpoints/qa-glm-6b-sft-v0.8-v2-original-lr/pytorch_model-00002-of-00008.bin

pytorch_model-00003-of-00008.bin
ADDED
/mnt/vepfs/zxdu/checkpoints/qa-glm-6b-sft-v0.8-v2-original-lr/pytorch_model-00003-of-00008.bin

pytorch_model-00004-of-00008.bin
ADDED
/mnt/vepfs/zxdu/checkpoints/qa-glm-6b-sft-v0.8-v2-original-lr/pytorch_model-00004-of-00008.bin

pytorch_model-00005-of-00008.bin
ADDED
/mnt/vepfs/zxdu/checkpoints/qa-glm-6b-sft-v0.8-v2-original-lr/pytorch_model-00005-of-00008.bin

pytorch_model-00006-of-00008.bin
ADDED
/mnt/vepfs/zxdu/checkpoints/qa-glm-6b-sft-v0.8-v2-original-lr/pytorch_model-00006-of-00008.bin

pytorch_model-00007-of-00008.bin
ADDED
/mnt/vepfs/zxdu/checkpoints/qa-glm-6b-sft-v0.8-v2-original-lr/pytorch_model-00007-of-00008.bin

pytorch_model-00008-of-00008.bin
ADDED
/mnt/vepfs/zxdu/checkpoints/qa-glm-6b-sft-v0.8-v2-original-lr/pytorch_model-00008-of-00008.bin
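The eight shard files above are tied together by pytorch_model.bin.index.json below. A minimal sketch (not the transformers loader) of how such an index is consumed — each parameter name in weight_map points at the shard that stores it, so shards can be loaded one file at a time:

import json
import torch

with open("pytorch_model.bin.index.json") as f:
    index = json.load(f)

# e.g. resolves to "pytorch_model-00008-of-00008.bin"
shard_file = index["weight_map"]["lm_head.weight"]
state_dict = torch.load(shard_file, map_location="cpu")  # load only that shard
tensor = state_dict["lm_head.weight"]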
    	
pytorch_model.bin.index.json
ADDED
{
  "metadata": {
    "total_size": 13744473856
  },
  "weight_map": {
    "lm_head.weight": "pytorch_model-00008-of-00008.bin",
    "transformer.final_layernorm.bias": "pytorch_model-00007-of-00008.bin",
    "transformer.final_layernorm.weight": "pytorch_model-00007-of-00008.bin",
    "transformer.layers.0.attention.dense.bias": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.0.attention.dense.weight": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.0.attention.query_key_value.bias": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.0.attention.query_key_value.weight": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.0.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.0.input_layernorm.bias": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.0.mlp.dense_4h_to_h.bias": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.0.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.0.mlp.dense_h_to_4h.bias": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.0.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.0.post_attention_layernorm.bias": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.1.attention.dense.bias": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.1.attention.dense.weight": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.1.attention.query_key_value.bias": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.1.attention.query_key_value.weight": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.1.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.1.input_layernorm.bias": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.1.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.1.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.1.mlp.dense_h_to_4h.bias": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.1.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.1.post_attention_layernorm.bias": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00008.bin",
    "transformer.layers.10.attention.dense.bias": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.10.attention.dense.weight": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.10.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.10.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.10.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.10.input_layernorm.bias": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.10.input_layernorm.weight": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.10.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.10.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.10.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.10.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.10.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.10.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.11.attention.dense.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.11.attention.dense.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.11.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.11.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.11.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.11.input_layernorm.bias": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.11.input_layernorm.weight": "pytorch_model-00003-of-00008.bin",
    "transformer.layers.11.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.11.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.11.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.11.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.11.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.11.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.attention.dense.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.attention.dense.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.input_layernorm.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.input_layernorm.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.12.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.attention.dense.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.attention.dense.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.input_layernorm.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.input_layernorm.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.13.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.attention.dense.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.attention.dense.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.input_layernorm.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.input_layernorm.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.14.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.attention.dense.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.attention.dense.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.input_layernorm.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.input_layernorm.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.15.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.16.attention.dense.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.16.attention.dense.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.16.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.16.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.16.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.16.input_layernorm.bias": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.16.input_layernorm.weight": "pytorch_model-00004-of-00008.bin",
    "transformer.layers.16.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.16.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.16.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.16.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.16.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.16.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.attention.dense.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.attention.dense.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.input_layernorm.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.input_layernorm.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.17.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.attention.dense.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.attention.dense.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.input_layernorm.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.input_layernorm.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.18.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.attention.dense.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.attention.dense.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.input_layernorm.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.input_layernorm.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.19.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.2.attention.dense.bias": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.2.attention.dense.weight": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.2.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.2.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.2.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.2.input_layernorm.bias": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.2.input_layernorm.weight": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.2.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.2.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.2.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.2.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.2.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.2.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin",
    "transformer.layers.20.attention.dense.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.20.attention.dense.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.20.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.20.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.20.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.20.input_layernorm.bias": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.20.input_layernorm.weight": "pytorch_model-00005-of-00008.bin",
    "transformer.layers.20.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin",
    "transformer.layers.20.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin",
         
            +
                "transformer.layers.20.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin",
         
     | 
| 188 | 
         
            +
                "transformer.layers.20.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin",
         
     | 
| 189 | 
         
            +
                "transformer.layers.20.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin",
         
     | 
| 190 | 
         
            +
                "transformer.layers.20.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin",
         
     | 
| 191 | 
         
            +
                "transformer.layers.21.attention.dense.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 192 | 
         
            +
                "transformer.layers.21.attention.dense.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 193 | 
         
            +
                "transformer.layers.21.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 194 | 
         
            +
                "transformer.layers.21.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 195 | 
         
            +
                "transformer.layers.21.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin",
         
     | 
| 196 | 
         
            +
                "transformer.layers.21.input_layernorm.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 197 | 
         
            +
                "transformer.layers.21.input_layernorm.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 198 | 
         
            +
                "transformer.layers.21.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 199 | 
         
            +
                "transformer.layers.21.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 200 | 
         
            +
                "transformer.layers.21.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 201 | 
         
            +
                "transformer.layers.21.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 202 | 
         
            +
                "transformer.layers.21.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 203 | 
         
            +
                "transformer.layers.21.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 204 | 
         
            +
                "transformer.layers.22.attention.dense.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 205 | 
         
            +
                "transformer.layers.22.attention.dense.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 206 | 
         
            +
                "transformer.layers.22.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 207 | 
         
            +
                "transformer.layers.22.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 208 | 
         
            +
                "transformer.layers.22.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin",
         
     | 
| 209 | 
         
            +
                "transformer.layers.22.input_layernorm.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 210 | 
         
            +
                "transformer.layers.22.input_layernorm.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 211 | 
         
            +
                "transformer.layers.22.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 212 | 
         
            +
                "transformer.layers.22.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 213 | 
         
            +
                "transformer.layers.22.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 214 | 
         
            +
                "transformer.layers.22.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 215 | 
         
            +
                "transformer.layers.22.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 216 | 
         
            +
                "transformer.layers.22.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 217 | 
         
            +
                "transformer.layers.23.attention.dense.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 218 | 
         
            +
                "transformer.layers.23.attention.dense.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 219 | 
         
            +
                "transformer.layers.23.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 220 | 
         
            +
                "transformer.layers.23.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 221 | 
         
            +
                "transformer.layers.23.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin",
         
     | 
| 222 | 
         
            +
                "transformer.layers.23.input_layernorm.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 223 | 
         
            +
                "transformer.layers.23.input_layernorm.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 224 | 
         
            +
                "transformer.layers.23.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 225 | 
         
            +
                "transformer.layers.23.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 226 | 
         
            +
                "transformer.layers.23.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 227 | 
         
            +
                "transformer.layers.23.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 228 | 
         
            +
                "transformer.layers.23.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 229 | 
         
            +
                "transformer.layers.23.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 230 | 
         
            +
                "transformer.layers.24.attention.dense.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 231 | 
         
            +
                "transformer.layers.24.attention.dense.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 232 | 
         
            +
                "transformer.layers.24.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 233 | 
         
            +
                "transformer.layers.24.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 234 | 
         
            +
                "transformer.layers.24.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin",
         
     | 
| 235 | 
         
            +
                "transformer.layers.24.input_layernorm.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 236 | 
         
            +
                "transformer.layers.24.input_layernorm.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 237 | 
         
            +
                "transformer.layers.24.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 238 | 
         
            +
                "transformer.layers.24.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 239 | 
         
            +
                "transformer.layers.24.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 240 | 
         
            +
                "transformer.layers.24.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 241 | 
         
            +
                "transformer.layers.24.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 242 | 
         
            +
                "transformer.layers.24.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 243 | 
         
            +
                "transformer.layers.25.attention.dense.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 244 | 
         
            +
                "transformer.layers.25.attention.dense.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 245 | 
         
            +
                "transformer.layers.25.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 246 | 
         
            +
                "transformer.layers.25.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 247 | 
         
            +
                "transformer.layers.25.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin",
         
     | 
| 248 | 
         
            +
                "transformer.layers.25.input_layernorm.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 249 | 
         
            +
                "transformer.layers.25.input_layernorm.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 250 | 
         
            +
                "transformer.layers.25.mlp.dense_4h_to_h.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 251 | 
         
            +
                "transformer.layers.25.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 252 | 
         
            +
                "transformer.layers.25.mlp.dense_h_to_4h.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 253 | 
         
            +
                "transformer.layers.25.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 254 | 
         
            +
                "transformer.layers.25.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin",
         
     | 
| 255 | 
         
            +
                "transformer.layers.25.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin",
         
     | 
| 256 | 
         
            +
                "transformer.layers.26.attention.dense.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 257 | 
         
            +
                "transformer.layers.26.attention.dense.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 258 | 
         
            +
                "transformer.layers.26.attention.query_key_value.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 259 | 
         
            +
                "transformer.layers.26.attention.query_key_value.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 260 | 
         
            +
                "transformer.layers.26.attention.rotary_emb.inv_freq": "pytorch_model-00007-of-00008.bin",
         
     | 
| 261 | 
         
            +
                "transformer.layers.26.input_layernorm.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 262 | 
         
            +
                "transformer.layers.26.input_layernorm.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 263 | 
         
            +
                "transformer.layers.26.mlp.dense_4h_to_h.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 264 | 
         
            +
                "transformer.layers.26.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 265 | 
         
            +
                "transformer.layers.26.mlp.dense_h_to_4h.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 266 | 
         
            +
                "transformer.layers.26.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 267 | 
         
            +
                "transformer.layers.26.post_attention_layernorm.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 268 | 
         
            +
                "transformer.layers.26.post_attention_layernorm.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 269 | 
         
            +
                "transformer.layers.27.attention.dense.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 270 | 
         
            +
                "transformer.layers.27.attention.dense.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 271 | 
         
            +
                "transformer.layers.27.attention.query_key_value.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 272 | 
         
            +
                "transformer.layers.27.attention.query_key_value.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 273 | 
         
            +
                "transformer.layers.27.attention.rotary_emb.inv_freq": "pytorch_model-00007-of-00008.bin",
         
     | 
| 274 | 
         
            +
                "transformer.layers.27.input_layernorm.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 275 | 
         
            +
                "transformer.layers.27.input_layernorm.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 276 | 
         
            +
                "transformer.layers.27.mlp.dense_4h_to_h.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 277 | 
         
            +
                "transformer.layers.27.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 278 | 
         
            +
                "transformer.layers.27.mlp.dense_h_to_4h.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 279 | 
         
            +
                "transformer.layers.27.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 280 | 
         
            +
                "transformer.layers.27.post_attention_layernorm.bias": "pytorch_model-00007-of-00008.bin",
         
     | 
| 281 | 
         
            +
                "transformer.layers.27.post_attention_layernorm.weight": "pytorch_model-00007-of-00008.bin",
         
     | 
| 282 | 
         
            +
                "transformer.layers.3.attention.dense.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 283 | 
         
            +
                "transformer.layers.3.attention.dense.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 284 | 
         
            +
                "transformer.layers.3.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 285 | 
         
            +
                "transformer.layers.3.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 286 | 
         
            +
                "transformer.layers.3.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin",
         
     | 
| 287 | 
         
            +
                "transformer.layers.3.input_layernorm.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 288 | 
         
            +
                "transformer.layers.3.input_layernorm.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 289 | 
         
            +
                "transformer.layers.3.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 290 | 
         
            +
                "transformer.layers.3.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 291 | 
         
            +
                "transformer.layers.3.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 292 | 
         
            +
                "transformer.layers.3.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 293 | 
         
            +
                "transformer.layers.3.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 294 | 
         
            +
                "transformer.layers.3.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 295 | 
         
            +
                "transformer.layers.4.attention.dense.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 296 | 
         
            +
                "transformer.layers.4.attention.dense.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 297 | 
         
            +
                "transformer.layers.4.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 298 | 
         
            +
                "transformer.layers.4.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 299 | 
         
            +
                "transformer.layers.4.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin",
         
     | 
| 300 | 
         
            +
                "transformer.layers.4.input_layernorm.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 301 | 
         
            +
                "transformer.layers.4.input_layernorm.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 302 | 
         
            +
                "transformer.layers.4.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 303 | 
         
            +
                "transformer.layers.4.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 304 | 
         
            +
                "transformer.layers.4.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 305 | 
         
            +
                "transformer.layers.4.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 306 | 
         
            +
                "transformer.layers.4.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 307 | 
         
            +
                "transformer.layers.4.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 308 | 
         
            +
                "transformer.layers.5.attention.dense.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 309 | 
         
            +
                "transformer.layers.5.attention.dense.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 310 | 
         
            +
                "transformer.layers.5.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 311 | 
         
            +
                "transformer.layers.5.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 312 | 
         
            +
                "transformer.layers.5.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin",
         
     | 
| 313 | 
         
            +
                "transformer.layers.5.input_layernorm.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 314 | 
         
            +
                "transformer.layers.5.input_layernorm.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 315 | 
         
            +
                "transformer.layers.5.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 316 | 
         
            +
                "transformer.layers.5.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 317 | 
         
            +
                "transformer.layers.5.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 318 | 
         
            +
                "transformer.layers.5.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 319 | 
         
            +
                "transformer.layers.5.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 320 | 
         
            +
                "transformer.layers.5.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 321 | 
         
            +
                "transformer.layers.6.attention.dense.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 322 | 
         
            +
                "transformer.layers.6.attention.dense.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 323 | 
         
            +
                "transformer.layers.6.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 324 | 
         
            +
                "transformer.layers.6.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 325 | 
         
            +
                "transformer.layers.6.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin",
         
     | 
| 326 | 
         
            +
                "transformer.layers.6.input_layernorm.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 327 | 
         
            +
                "transformer.layers.6.input_layernorm.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 328 | 
         
            +
                "transformer.layers.6.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 329 | 
         
            +
                "transformer.layers.6.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 330 | 
         
            +
                "transformer.layers.6.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 331 | 
         
            +
                "transformer.layers.6.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 332 | 
         
            +
                "transformer.layers.6.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin",
         
     | 
| 333 | 
         
            +
                "transformer.layers.6.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin",
         
     | 
| 334 | 
         
            +
                "transformer.layers.7.attention.dense.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 335 | 
         
            +
                "transformer.layers.7.attention.dense.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 336 | 
         
            +
                "transformer.layers.7.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 337 | 
         
            +
                "transformer.layers.7.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 338 | 
         
            +
                "transformer.layers.7.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin",
         
     | 
| 339 | 
         
            +
                "transformer.layers.7.input_layernorm.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 340 | 
         
            +
                "transformer.layers.7.input_layernorm.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 341 | 
         
            +
                "transformer.layers.7.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 342 | 
         
            +
                "transformer.layers.7.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 343 | 
         
            +
                "transformer.layers.7.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 344 | 
         
            +
                "transformer.layers.7.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 345 | 
         
            +
                "transformer.layers.7.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 346 | 
         
            +
                "transformer.layers.7.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 347 | 
         
            +
                "transformer.layers.8.attention.dense.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 348 | 
         
            +
                "transformer.layers.8.attention.dense.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 349 | 
         
            +
                "transformer.layers.8.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 350 | 
         
            +
                "transformer.layers.8.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 351 | 
         
            +
                "transformer.layers.8.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin",
         
     | 
| 352 | 
         
            +
                "transformer.layers.8.input_layernorm.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 353 | 
         
            +
                "transformer.layers.8.input_layernorm.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 354 | 
         
            +
                "transformer.layers.8.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 355 | 
         
            +
                "transformer.layers.8.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 356 | 
         
            +
                "transformer.layers.8.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 357 | 
         
            +
                "transformer.layers.8.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 358 | 
         
            +
                "transformer.layers.8.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 359 | 
         
            +
                "transformer.layers.8.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 360 | 
         
            +
                "transformer.layers.9.attention.dense.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 361 | 
         
            +
                "transformer.layers.9.attention.dense.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 362 | 
         
            +
                "transformer.layers.9.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 363 | 
         
            +
                "transformer.layers.9.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 364 | 
         
            +
                "transformer.layers.9.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin",
         
     | 
| 365 | 
         
            +
                "transformer.layers.9.input_layernorm.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 366 | 
         
            +
                "transformer.layers.9.input_layernorm.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 367 | 
         
            +
                "transformer.layers.9.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 368 | 
         
            +
                "transformer.layers.9.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 369 | 
         
            +
                "transformer.layers.9.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 370 | 
         
            +
                "transformer.layers.9.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 371 | 
         
            +
                "transformer.layers.9.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin",
         
     | 
| 372 | 
         
            +
                "transformer.layers.9.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin",
         
     | 
| 373 | 
         
            +
                "transformer.word_embeddings.weight": "pytorch_model-00001-of-00008.bin"
         
     | 
| 374 | 
         
            +
              }
         
     | 
| 375 | 
         
            +
            }
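The map above is the standard sharded-checkpoint index that `transformers` writes: each parameter name points at the shard file holding its tensor, so a loader only opens the shards it actually needs. A minimal sketch of that lookup, using only the standard library (`from_pretrained` performs this resolution internally):

    import json
    from collections import defaultdict

    with open("pytorch_model.bin.index.json") as f:
        index = json.load(f)

    weight_map = index["weight_map"]

    # Which shard holds one tensor?
    weight_map["transformer.layers.18.attention.dense.weight"]
    # -> 'pytorch_model-00005-of-00008.bin'

    # Group parameter names by shard so each of the eight files is opened once.
    by_shard = defaultdict(list)
    for name, shard in weight_map.items():
        by_shard[shard].append(name)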
         
quantization.py
ADDED

@@ -0,0 +1,187 @@
from torch.nn import Linear
from torch.nn.parameter import Parameter

import bz2
import torch
import base64
import ctypes

from typing import List
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up


# Autograd wrapper: matmul with int8/int4 weights and fp16 activations.
# The packed weights are dequantized to half precision on the fly, so the
# actual GEMM runs in fp16.
class W8A16Linear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
        ctx.inp_shape = inp.size()
        ctx.weight_shape = quant_w.size()
        ctx.weight_bit_width = weight_bit_width
        out_features = quant_w.size(0)
        inp = inp.contiguous().view(-1, inp.size(-1))
        weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
        output = inp.mm(weight.t())
        ctx.save_for_backward(inp, quant_w, scale_w)
        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        inp, quant_w, scale_w = ctx.saved_tensors
        weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
        grad_output = grad_output.contiguous().view(-1, weight.size(0))
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(inp)
        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
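# Note: backward differentiates through the dequantized fp16 weight, so
# grad_input matches an ordinary nn.Linear. grad_weight is reshaped to the
# *packed* weight's shape, which only lines up for 8-bit weights; in
# practice QuantizedLinear below freezes its parameters, so this path is
# rarely exercised.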
         

class Kernel:
    # Wraps the embedded CUDA module (loaded lazily by cpm_kernels) and
    # exposes every listed function name as a callable attribute,
    # e.g. kernels.int4WeightCompression.
    def __init__(self, code: bytes, function_names: List[str]):
        self.code = code
        self._function_names = function_names
        self._cmodule = LazyKernelCModule(self.code)

        for name in self._function_names:
            setattr(self, name, KernelFunction(self._cmodule, name))


# The constant below holds the compiled CUDA kernels, bz2-compressed and
# base64-encoded so the whole quantization runtime ships in this one file:
            quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek
+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BP
XvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"
         
kernels = Kernel(
    bz2.decompress(base64.b64decode(quantization_code)),
    [
        "int4WeightCompression",
        "int4WeightExtractionFloat",
        "int4WeightExtractionHalf",
        "int8WeightExtractionFloat",
        "int8WeightExtractionHalf",
    ],
)


def compress_int4_weight(weight: torch.Tensor):  # (n, m)
    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        assert m % 2 == 0
        m = m // 2
        out = torch.empty(n, m, dtype=torch.int8, device="cuda")
        stream = torch.cuda.current_stream()

        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        kernels.int4WeightCompression(
            gridDim,
            blockDim,
            0,
            stream,
            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
        )
        return out
     | 

def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
    if source_bit_width == 8:
        func = kernels.int8WeightExtractionHalf
    elif source_bit_width == 4:
        func = kernels.int4WeightExtractionHalf
    else:
        assert False, "Unsupported bit-width"

    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
        stream = torch.cuda.current_stream()

        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        func(
            gridDim,
            blockDim,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(scale_list.data_ptr()),
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_int32(n),
                ctypes.c_int32(m),
            ],
        )
        return out
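# Dequantization semantics, as a CPU-style reference (scale_list holds one
# fp16 scale per output row; for int4 the kernel first unpacks two values
# per byte, hence the m * (8 // source_bit_width) output width above):
#     weight_half = unpacked_int_weight.to(torch.half) * scale_list[:, None]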
         

class QuantizedLinear(Linear):
    def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, *args, **kwargs):
        super(QuantizedLinear, self).__init__(*args, **kwargs)
        self.weight_bit_width = weight_bit_width

        shape = self.weight.shape
        del self.weight

        if weight_tensor is None:
            self.weight = torch.empty(
                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
            )
            self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
        else:
            # Symmetric per-row quantization: one scale per output channel,
            # chosen so the row's max magnitude maps onto the int range.
            self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
            self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
            if weight_bit_width == 4:
                self.weight = compress_int4_weight(self.weight)

        self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
        self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
        self.bias = Parameter(bias_tensor.to(kwargs["device"]), requires_grad=False)

    def forward(self, input):
        output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
        if self.bias is not None:
            output = output + self.bias
        return output
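# Usage sketch for a single layer (quantize() below applies exactly this to
# every attention and MLP projection; `lin` stands for any fp16 nn.Linear):
#     ql = QuantizedLinear(
#         weight_bit_width=8,
#         weight_tensor=lin.weight.to(torch.cuda.current_device()),
#         bias_tensor=lin.bias,
#         in_features=lin.in_features,
#         out_features=lin.out_features,
#         bias=True,
#         dtype=torch.half,
#         device=lin.weight.device,
#     )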
         

def quantize(model, weight_bit_width):
    """Replace the model's fp16 Linear projections with QuantizedLinear."""

    for layer in model.layers:
        layer.attention.query_key_value = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight_tensor=layer.attention.query_key_value.weight.to(torch.cuda.current_device()),
            bias_tensor=layer.attention.query_key_value.bias,
            in_features=layer.attention.query_key_value.in_features,
            out_features=layer.attention.query_key_value.out_features,
            bias=True,
            dtype=torch.half,
            device=layer.attention.query_key_value.weight.device,
        )
        layer.attention.dense = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight_tensor=layer.attention.dense.weight.to(torch.cuda.current_device()),
            bias_tensor=layer.attention.dense.bias,
            in_features=layer.attention.dense.in_features,
            out_features=layer.attention.dense.out_features,
            bias=True,
            dtype=torch.half,
            device=layer.attention.dense.weight.device,
        )
        layer.mlp.dense_h_to_4h = QuantizedLinear(
            weight_bit_width=weight_bit_width,
            weight_tensor=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
            bias_tensor=layer.mlp.dense_h_to_4h.bias,
            in_features=layer.mlp.dense_h_to_4h.in_features,
            out_features=layer.mlp.dense_h_to_4h.out_features,
            bias=True,
            dtype=torch.half,
            device=layer.mlp.dense_h_to_4h.weight.device,
        )
        layer.mlp.dense_4h_to_h = QuantizedLinear(
         
     | 
| 178 | 
         
            +
                        weight_bit_width=weight_bit_width,
         
     | 
| 179 | 
         
            +
                        weight_tensor=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
         
     | 
| 180 | 
         
            +
                        bias_tensor=layer.mlp.dense_4h_to_h.bias,
         
     | 
| 181 | 
         
            +
                        in_features=layer.mlp.dense_4h_to_h.in_features,
         
     | 
| 182 | 
         
            +
                        out_features=layer.mlp.dense_4h_to_h.out_features,
         
     | 
| 183 | 
         
            +
                        bias=True,
         
     | 
| 184 | 
         
            +
                        dtype=torch.half,
         
     | 
| 185 | 
         
            +
                        device=layer.mlp.dense_4h_to_h.weight.device,
         
     | 
| 186 | 
         
            +
                    )
         
     | 
| 187 | 
         
            +
                return model
         
     | 
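For readers skimming the diff: `QuantizedLinear.__init__` above implements plain per-row symmetric quantization, with each output row scaled so that its largest absolute weight lands at the edge of the signed integer range. A minimal standalone sketch of that round trip (added here for illustration, not part of the commit; `w` stands in for a layer's fp16 weight):

    import torch

    w = torch.randn(4, 8)  # stand-in for a linear layer's weight matrix
    bit_width = 8
    # Per-row scale, exactly as computed in QuantizedLinear.__init__:
    scale = w.abs().max(dim=-1).values / ((2 ** (bit_width - 1)) - 1)
    w_q = torch.round(w / scale[:, None]).to(torch.int8)   # quantize
    w_dq = w_q.float() * scale[:, None]                    # dequantize, as W8A16Linear does at matmul time
    print((w - w_dq).abs().max())  # per-element error is bounded by scale / 2

Note that `quantize(model, weight_bit_width)` expects `model` to expose `.layers` with ChatGLM's attention/MLP layout; judging by the signature it is meant to receive the transformer stack from modeling_chatglm.py, though that call site is not part of this file.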
    	
tokenization_chatglm.py
ADDED
@@ -0,0 +1,347 @@
"""Tokenization classes for ChatGLM."""
import sys
import unicodedata
from typing import List, Optional, Union
from functools import lru_cache
import os
import collections
import re

from transformers.tokenization_utils import PreTrainedTokenizer
from icetk.text_tokenizer import TextTokenizer
from icetk.utils import auto_create
import icetk.sentencepiece_model_pb2 as sp_model
from transformers.utils import logging

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "ice_text.model"}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "THUDM/chatglm-6b": 2048,
}


class SPTokenizer:
    def __init__(
        self,
        vocab_file,
        max_blank_length=80,
        byte_fallback=True,
    ):
        assert vocab_file is not None
        self.vocab_file = vocab_file
        self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
        self.max_blank_length = max_blank_length
        self.byte_fallback = byte_fallback
        self.text_tokenizer = self._build_text_tokenizer(encode_special_tokens=False)
        self.special_text_tokenizer = self._build_text_tokenizer(encode_special_tokens=True)

    @staticmethod
    def _configure_tokenizer(
        text_tokenizer: TextTokenizer,
        special_tokens: List[str],
        max_blank_length: int,
        byte_fallback: bool,
        encode_special_tokens=False,
    ):
        # special tokens
        special_token_type = 4 if encode_special_tokens else 3  # 3 - CONTROL, 4 - USER_DEFINED
        for token in special_tokens:
            text_tokenizer.proto.pieces.append(
                sp_model.ModelProto.SentencePiece(piece=token, score=0.0, type=special_token_type)
            )
        # whitespaces: one piece for tab, one per blank run length 2..max_blank_length
        for token in [SPTokenizer.get_tab_token()] + [
            SPTokenizer.get_blank_token(i) for i in range(2, max_blank_length + 1)
        ]:
            text_tokenizer.proto.pieces.append(sp_model.ModelProto.SentencePiece(piece=token, score=0.0, type=4))
        # byte fallback: register the 256 <0xNN> byte pieces so unknown characters never map to <unk>
        if byte_fallback:
            text_tokenizer.proto.trainer_spec.byte_fallback = True
            for i in range(256):
                text_tokenizer.proto.pieces.append(
                    sp_model.ModelProto.SentencePiece(piece="<0x{:02X}>".format(i), score=0.0, type=6)
                )
        text_tokenizer.refresh()

    def _build_text_tokenizer(self, encode_special_tokens=False):
        tokenizer = TextTokenizer(self.vocab_file)
        self._configure_tokenizer(
            tokenizer, self.special_tokens, self.max_blank_length, self.byte_fallback, encode_special_tokens
        )
        return tokenizer

    def _get_text_tokenizer(self, encode_special_tokens=False):
        if encode_special_tokens:
            return self.special_text_tokenizer
        else:
            return self.text_tokenizer

    @staticmethod
    def get_blank_token(length: int):
        assert length >= 2
        return f"<|blank_{length}|>"

    @staticmethod
    def get_tab_token():
        return "<|tab|>"

    @property
    def num_image_tokens(self):
        # icetk reserves the first 20000 ids for image tokens; text ids are offset past them.
        return 20000

    @property
    def num_text_tokens(self):
        return self.text_tokenizer.num_tokens

    @property
    def num_tokens(self):
        return self.num_image_tokens + self.num_text_tokens

    @staticmethod
    def _encode_whitespaces(text: str, max_len: int = 80):
        text = text.replace("\t", SPTokenizer.get_tab_token())
        for i in range(max_len, 1, -1):
            text = text.replace(" " * i, SPTokenizer.get_blank_token(i))
        return text

    def _preprocess(self, text: str, linebreak=True, whitespaces=True):
        if linebreak:
            text = text.replace("\n", "<n>")
        if whitespaces:
            text = self._encode_whitespaces(text, max_len=self.max_blank_length)
        return text

    def encode(
        self, text: str, linebreak=True, whitespaces=True, special_tokens=False, add_dummy_prefix=True
    ) -> List[int]:
        """
        @param text: Text to encode.
        @param linebreak: Whether to encode newline (\n) in text.
        @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
        @param special_tokens: Whether to encode special tokens ([MASK], [gMASK], etc.) in text.
        @param add_dummy_prefix: Whether to add a dummy blank space at the beginning.
        """
        text = self._preprocess(text, linebreak, whitespaces)
        if not add_dummy_prefix:
            # prepend a throwaway "<n>" so the dummy prefix attaches to it, then drop its two pieces below
            text = "<n>" + text
        tmp = self._get_text_tokenizer(encode_special_tokens=special_tokens).encode(text)
        tokens = [x + self.num_image_tokens for x in tmp]  # shift past the reserved image-token range
        return tokens if add_dummy_prefix else tokens[2:]

    def decode(self, text_ids: List[int], special_tokens=False) -> str:
        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
        text = self._get_text_tokenizer(encode_special_tokens=special_tokens).decode(ids)
        # undo the _preprocess substitutions
        text = text.replace("<n>", "\n")
        text = text.replace(SPTokenizer.get_tab_token(), "\t")
        for i in range(2, self.max_blank_length + 1):
            text = text.replace(self.get_blank_token(i), " " * i)
        return text

    def tokenize(
        self, text: str, linebreak=True, whitespaces=True, special_tokens=False, add_dummy_prefix=True
    ) -> List[str]:
        """
        @param text: Text to encode.
        @param linebreak: Whether to encode newline (\n) in text.
        @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
        @param special_tokens: Whether to encode special tokens ([MASK], [gMASK], etc.) in text.
        @param add_dummy_prefix: Whether to add a dummy blank space at the beginning.
        """
        text = self._preprocess(text, linebreak, whitespaces)
        if not add_dummy_prefix:
            text = "<n>" + text
        tokens = self._get_text_tokenizer(encode_special_tokens=special_tokens).tokenize(text)
        return tokens if add_dummy_prefix else tokens[2:]

    def __getitem__(self, x: Union[int, str]):
        if isinstance(x, int):
            if x < self.num_image_tokens:
                return "<image_{}>".format(x)
            else:
                return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)
        elif isinstance(x, str):
            if x.startswith("<image_") and x.endswith(">") and x[7:-1].isdigit():
                return int(x[7:-1])
            else:
                return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens
        else:
            raise ValueError("The key should be str or int.")


class ChatGLMTokenizer(PreTrainedTokenizer):
    """
    Construct a ChatGLM tokenizer. Based on a SentencePiece model (via icetk) with byte fallback.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids"]

    def __init__(
            self,
            vocab_file,
            do_lower_case=False,
            remove_space=False,
            bos_token='sop',
            eos_token='eos',
            eop_token='eop',
            mask_token='[MASK]',
            gmask_token='[gMASK]',
            padding_side="left",
            **kwargs
    ) -> None:
        super().__init__(
            do_lower_case=do_lower_case,
            remove_space=remove_space,
            padding_side=padding_side,
            **kwargs
        )

        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
        self.vocab_file = vocab_file

        self.bos_token = bos_token
        self.eos_token = eos_token
        self.eop_token = eop_token
        self.mask_token = mask_token
        self.gMASK_token = gmask_token

        self.sp_tokenizer = SPTokenizer(vocab_file)

    @property
    def eop_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the end-of-paragraph (eop) token in the vocabulary. Returns `None` if the token
        has not been set.
        """
        if self.eop_token is None:
            return None
        return self.convert_tokens_to_ids(self.eop_token)

    @property
    def vocab_size(self):
        """ Returns vocab size """
        return self.sp_tokenizer.num_tokens

    def get_vocab(self):
        """ Returns vocab as a dict """
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def preprocess_text(self, inputs):
        if self.remove_space:
            outputs = " ".join(inputs.strip().split())
        else:
            outputs = inputs

        if self.do_lower_case:
            outputs = outputs.lower()

        return outputs

    def _tokenize(self, text, **kwargs):
        """ Returns a tokenized string. """
        text = self.preprocess_text(text)
        seq = self.sp_tokenizer.tokenize(text)
        return seq

    def decode(
            self,
            token_ids: Union[List[int], List[List[int]]],
            skip_special_tokens: bool = False,
            clean_up_tokenization_spaces: bool = True,
            spaces_between_special_tokens: bool = True,
            **kwargs
    ) -> str:
        if isinstance(token_ids[0], list):
            # batched input: decode each sequence separately
            tokens = []
            for single_token_ids in token_ids:
                if self.pad_token_id in single_token_ids:  # remove pad
                    single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids))
                tokens.append(self.sp_tokenizer.decode(single_token_ids))
            return tokens
        else:
            if self.pad_token_id in token_ids:  # remove pad
                token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
            return self.sp_tokenizer.decode(token_ids)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) to an id using the vocab. """
        return self.sp_tokenizer[token]

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.sp_tokenizer[index]

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                An optional prefix to add to the names of the saved files.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = save_directory

        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)

    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by concatenating them and adding the special
        tokens ChatGLM expects: a `[gMASK]` token is appended when the sequence contains no mask token, an
        end-of-text token is appended when the sequence does not already end in a mask token, and the `<sop>`
        (bos) token is always appended last.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is not None:
            token_ids_0 += token_ids_1
        mask_ids = self.sp_tokenizer[self.mask_token]
        gmask_ids = self.sp_tokenizer[self.gMASK_token]
        if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
            token_ids_0 += [gmask_ids]

        if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids:
            token_ids_0 += [self.sp_tokenizer[self.eos_token]]

        token_ids_0 += [self.sp_tokenizer[self.bos_token]]

        return token_ids_0
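A note on the non-obvious part of this file: `_preprocess` rewrites newlines to `<n>`, and `_encode_whitespaces` turns tabs and runs of 2-80 spaces into the dedicated `<|tab|>` / `<|blank_k|>` pieces registered in `_configure_tokenizer`, so indentation-heavy text such as source code round-trips exactly through `encode`/`decode`. A dependency-free sketch of just the whitespace step (mirroring the code above for illustration, not imported from it):

    def encode_whitespaces(text: str, max_len: int = 80) -> str:
        # Same substitutions as SPTokenizer._encode_whitespaces: longest runs
        # first, so a run of k spaces collapses into a single <|blank_k|> piece.
        text = text.replace("\t", "<|tab|>")
        for i in range(max_len, 1, -1):
            text = text.replace(" " * i, f"<|blank_{i}|>")
        return text

    print(encode_whitespaces("def f():\n    return 1"))
    # def f():
    # <|blank_4|>return 1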
tokenizer_config.json
ADDED
@@ -0,0 +1,19 @@
{
  "name_or_path": "THUDM/chatglm-6b",
  "bos_token": "<sop>",
  "eop_token": "<eop>",
  "eos_token": "</s>",
  "gmask_token": "[gMASK]",
  "mask_token": "[MASK]",
  "pad_token": "<pad>",
  "unk_token": "<unk>",
  "remove_space": false,
  "do_lower_case": false,
  "tokenizer_class": "ChatGLMTokenizer",
  "auto_map": {
    "AutoTokenizer": [
      "tokenization_chatglm.ChatGLMTokenizer",
      null
    ]
  }
}
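The `auto_map` entry is what lets `transformers` resolve the custom tokenizer class shipped in this repository rather than one of its built-in classes. A typical loading snippet (assuming the `transformers` and `icetk` packages are installed; `trust_remote_code=True` is required precisely because tokenization_chatglm.py lives in the repo):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    ids = tokenizer.encode("Hello")
    print(ids, tokenizer.decode(ids))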