Upload folder using huggingface_hub

- .gitattributes +1 -0
 - LICENSE +202 -0
 - Notice +79 -0
 - README.md +179 -3
 - added_tokens.json +28 -0
 - config.json +410 -0
 - configuration_aimv2.py +82 -0
 - configuration_ovis_u1.py +281 -0
 - configuration_yak.py +63 -0
 - merges.txt +0 -0
 - model-00001-of-00003.safetensors +3 -0
 - model-00002-of-00003.safetensors +3 -0
 - model-00003-of-00003.safetensors +3 -0
 - model.safetensors.index.json +0 -0
 - modeling_aimv2.py +385 -0
 - modeling_ovis_u1.py +921 -0
 - modeling_yak.py +1461 -0
 - preprocessor_config.json +32 -0
 - special_tokens_map.json +31 -0
 - tokenizer.json +3 -0
 - tokenizer_config.json +240 -0
 - vocab.json +0 -0
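
The commit title indicates these files were pushed with the huggingface_hub uploader. For reference, a minimal sketch of how such an upload is typically issued; the local folder path is a placeholder, and the target repo id is assumed from the model card below.

```python
# Sketch only: push a local snapshot of the model folder as a single commit.
# Requires `pip install huggingface_hub` and an authenticated token
# (e.g. via `huggingface-cli login`).
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./Ovis-U1-3B",        # placeholder local directory
    repo_id="AIDC-AI/Ovis-U1-3B",      # assumed target repository
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
```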
 
    	
.gitattributes CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
    	
LICENSE ADDED

@@ -0,0 +1,202 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
    	
Notice ADDED

@@ -0,0 +1,79 @@
Copyright (C) 2025 AIDC-AI
Licensed under the Apache License, Version 2.0 (the "License").

This model was trained based on the following models:

1. Ovis2-2B https://huggingface.co/AIDC-AI/Ovis2-2B
License: Apache License, Version 2.0 (https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md, SPDX-License-identifier: Apache-2.0)

Apache License
                       Version 2.0, January 2004
                    http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

  To apply the Apache License to your work, attach the following
  boilerplate notice, with the fields enclosed by brackets "[]"
  replaced with your own identifying information. (Don't include
  the brackets!)  The text should be enclosed in the appropriate
  comment syntax for the file format. We also recommend that a
  file or class name and description of purpose be included on the
  same "printed page" as the copyright notice for easier
  identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

2. sdxl-vae https://huggingface.co/stabilityai/sdxl-vae
License: MIT (https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/mit.md, SPDX-License-identifier: MIT)

MIT License

Copyright (c) [year] [fullname]

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
    	
README.md CHANGED

@@ -1,3 +1,179 @@
---
license: apache-2.0
language:
- en
---

# Ovis-U1

<div align="center">
  <img src=https://cdn-uploads.huggingface.co/production/uploads/637aebed7ce76c3b834cea37/3IK823BZ8w-mz_QfeYkDn.png width="30%"/>
</div>

<p align="center">
  <!-- <a href="https://arxiv.org/abs/2502.12579"><img src="https://img.shields.io/badge/arXiv%20paper-2502.12579-b31b1b.svg" alt="arxiv"></a> -->
  <a href="https://github.com/AIDC-AI/Ovis"><img src="https://img.shields.io/badge/GitHub-AIDC--AI/Ovis--U1-blue?style=flat&logo=github" alt="demo"></a>
  <a href="https://huggingface.co/spaces/AIDC-AI/Ovis-U1-3B"><img src="https://img.shields.io/badge/🎨_HF_Spaces-AIDC--AI/Ovis--U1--3B-lightblack" alt="demo"></a>
  <a href="https://huggingface.co/AIDC-AI/Ovis-U1-3B"><img src="https://img.shields.io/badge/🤗_Model-AIDC--AI/Ovis--U1--3B-yellow" alt="model"></a>
</p>

Building on the foundation of the Ovis series, Ovis-U1 is a 3-billion-parameter unified model that integrates multimodal understanding, text-to-image generation, and image editing capabilities.

<figure>
  <img src="https://cdn-uploads.huggingface.co/production/uploads/636f4c6b5d2050767e4a1491/EmEEGmot9JzaBfHP2uWld.jpeg" alt="Ovis-U1 architecture">
  <figcaption style="text-align: center;">The overall architecture of Ovis-U1 (cf. Fig.2 in our report).</figcaption>
</figure>

---

## 🚀 News

- [2025/6/28] 🔥 Announcing Ovis-U1-3B ([Model](https://huggingface.co/AIDC-AI/Ovis-U1-3B), [Demo](https://huggingface.co/spaces/AIDC-AI/Ovis-U1-3B))!

---

## 📦 Installation

Ovis-U1 has been tested with Python 3.10, Torch 2.4.0, Transformers 4.51.3, and DeepSpeed 0.15.4. For a comprehensive list of package dependencies, please consult the requirements.txt file.

```bash
git clone git@github.com:AIDC-AI/Ovis-U1.git
conda create -n ovis-u1 python=3.10 -y
conda activate ovis-u1
cd Ovis-U1
pip install -r requirements.txt
pip install -e .
```

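Before running anything, it can help to confirm that the environment matches the tested stack above. A tiny illustrative check (not part of the repository) is shown below; the version strings come from the paragraph above.

```python
# Illustrative sanity check against the versions the README lists as tested.
import sys
import torch
import transformers

print("python      :", sys.version.split()[0])      # tested with 3.10
print("torch       :", torch.__version__)           # tested with 2.4.0
print("transformers:", transformers.__version__)    # tested with 4.51.3
```
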
## 📂 Model Checkpoints

We provide pretrained Ovis-U1-3B checkpoints for easy download and evaluation:

- **Model Repository**: [AIDC-AI/Ovis-U1-3B](https://huggingface.co/AIDC-AI/Ovis-U1-3B)

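If you prefer to fetch the checkpoint programmatically rather than through git, the whole repository snapshot (sharded safetensors, configs, tokenizer, and custom modeling code) can be pulled with huggingface_hub. A minimal sketch follows; the local directory is a placeholder.

```python
# Sketch: download the full Ovis-U1-3B snapshot to a local folder.
from huggingface_hub import snapshot_download

local_path = snapshot_download(
    repo_id="AIDC-AI/Ovis-U1-3B",
    local_dir="./checkpoints/Ovis-U1-3B",  # placeholder destination
)
print("checkpoint downloaded to:", local_path)
```
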
## 🛠️ Inference

For multimodal understanding, please run

```bash
python ovis/eval/test_txt_generation.py
```

For text-to-image, please run

```bash
python ovis/eval/test_t2i.py \
    --height 1024 \
    --width 1024 \
    --steps 50 \
    --seed 42 \
    --txt_cfg 5
```

For image editing, please run

```bash
python ovis/eval/test_img_edit.py \
    --steps 50 \
    --img_cfg 4 \
    --txt_cfg 7.5
```

| 83 | 
         
            +
## 📊 Performance

#### OpenCompass Multi-modal Academic Benchmarks

| Model | MMB | MMS | MMMU | MathVista | Hallusion | AI2D | OCRBench | MMVet | Avg |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| GPT-4o | 86 | 70.2 | 72.9 | 71.6 | 57 | 86.3 | 82.2 | 76.9 | **75.4** |
| InternVL2.5-2B | 70.9 | 54.3 | 43.2 | 51.1 | 42.3 | 74.9 | 80.2 | 62.6 | **59.9** |
| SAIL-VL-2B | 73.7 | 56.5 | 44.1 | 62.8 | 45.9 | 77.4 | 83.1 | 44.2 | **61** |
| InternVL3-2B | 78 | 61.1 | 48.7 | 57.6 | 41.9 | 78.6 | 83.1 | 67 | **61.1** |
| Qwen2.5-VL-3B | 76.8 | 56.3 | 51.2 | 61.2 | 46.6 | 81.4 | 82.8 | 60 | **64.5** |
| Ovis2-2B | 76.9 | 56.7 | 45.6 | 64.1 | 50.2 | 82.7 | 87.3 | 58.3 | **65.2** |
| SAIL-VL-1.5-2B | 78.5 | 62.6 | 46.4 | 67 | 50 | 83.7 | 89.1 | 58.8 | **67** |
| Ristretto-3B | 80.2 | 62.8 | 51.3 | 67.6 | 50.2 | 84.2 | 84.7 | 60.7 | **67.7** |
| Ovis-U1 | 77.8 | 61.3 | 51.1 | 69.4 | 56.3 | 85.6 | 88.3 | 66.7 | **69.6** |

#### GenEval

| Model | Single object | Two object | Counting | Colors | Position | Attribute binding | Overall |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| GPT-4o | 0.99 | 0.92 | 0.85 | 0.92 | 0.75 | 0.61 | **0.84** |
| BAGEL | 0.99 | 0.94 | 0.81 | 0.88 | 0.64 | 0.63 | **0.82** |
| BAGEL 📝 | 0.98 | 0.95 | 0.84 | 0.95 | 0.78 | 0.77 | **0.88** |
| UniWorld-V1 | 0.99 | 0.93 | 0.79 | 0.89 | 0.49 | 0.70 | **0.80** |
| UniWorld-V1 📝 | 0.98 | 0.93 | 0.81 | 0.89 | 0.74 | 0.71 | **0.84** |
| OmniGen | 0.98 | 0.84 | 0.66 | 0.74 | 0.40 | 0.43 | **0.68** |
| OmniGen2 | 1 | 0.95 | 0.64 | 0.88 | 0.55 | 0.76 | **0.80** |
| OmniGen2 📝 | 0.99 | 0.96 | 0.74 | 0.98 | 0.71 | 0.75 | **0.86** |
| Ovis-U1 | 0.98 | 0.98 | 0.90 | 0.92 | 0.79 | 0.75 | **0.89** |

*📝 denotes using the rewritten prompts*

#### DPG-Bench

| Model | Global | Entity | Attribute | Relation | Other | Overall |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| BAGEL | 88.94 | 90.37 | 91.29 | 90.82 | 88.67 | **85.07** |
| UniWorld-V1 | 83.64 | 88.39 | 88.44 | 89.27 | 87.22 | **81.38** |
| OmniGen | 87.90 | 88.97 | 88.47 | 87.95 | 83.56 | **81.16** |
| OmniGen2 | 88.81 | 88.83 | 90.18 | 89.37 | 90.27 | **83.57** |
| Ovis-U1 | 82.37 | 90.08 | 88.68 | 93.35 | 85.20 | **83.72** |

#### ImgEdit-Bench

| Model | Add | Adjust | Extract | Replace | Remove | Background | Style | Hybrid | Action | Overall |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| GPT-4o | 4.61 | 4.33 | 2.9 | 4.35 | 3.66 | 4.57 | 4.93 | 3.96 | 4.89 | **4.2** |
| MagicBrush | 2.84 | 1.58 | 1.51 | 1.97 | 1.58 | 1.75 | 2.38 | 1.62 | 1.22 | **1.90** |
| Instruct-P2P | 2.45 | 1.83 | 1.44 | 2.01 | 1.50 | 1.44 | 3.55 | 1.2 | 1.46 | **1.88** |
| AnyEdit | 3.18 | 2.95 | 1.88 | 2.47 | 2.23 | 2.24 | 2.85 | 1.56 | 2.65 | **2.45** |
| UltraEdit | 3.44 | 2.81 | 2.13 | 2.96 | 1.45 | 2.83 | 3.76 | 1.91 | 2.98 | **2.7** |
| OmniGen | 3.47 | 3.04 | 1.71 | 2.94 | 2.43 | 3.21 | 4.19 | 2.24 | 3.38 | **2.96** |
| Step1X-Edit | 3.88 | 3.14 | 1.76 | 3.40 | 2.41 | 3.16 | 4.63 | 2.64 | 2.52 | **3.06** |
| ICEdit | 3.58 | 3.39 | 1.73 | 3.15 | 2.93 | 3.08 | 3.84 | 2.04 | 3.68 | **3.05** |
| BAGEL | 3.56 | 3.31 | 1.7 | 3.3 | 2.62 | 3.24 | 4.49 | 2.38 | 4.17 | **3.2** |
| UniWorld-V1 | 3.82 | 3.64 | 2.27 | 3.47 | 3.24 | 2.99 | 4.21 | 2.96 | 2.74 | **3.26** |
| OmniGen2 | 3.57 | 3.06 | 1.77 | 3.74 | 3.2 | 3.57 | 4.81 | 2.52 | 4.68 | **3.44** |
| Ovis-U1 | 4.13 | 3.62 | 2.98 | 4.45 | 4.06 | 4.22 | 4.69 | 3.45 | 4.61 | **4.00** |

#### GEdit-Bench-EN

| Model | Background Change | Color Alteration | Material Modification | Motion Change | Portrait Beautification | Style Transfer | Subject Addition | Subject Removal | Subject Replacement | Text Modification | Tone Transformation | Avg |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| GPT-4o | 7.205 | 6.491 | 6.607 | 8.096 | 7.768 | 6.961 | 7.622 | 8.331 | 8.067 | 7.427 | 8.301 | **7.534** |
| AnyEdit | 4.663 | 4.260 | 2.537 | 2.024 | 3.479 | 2.032 | 3.995 | 3.089 | 3.180 | 0.922 | 5.151 | **3.212** |
| Instruct-Pix2Pix | 3.825 | 5.182 | 3.688 | 3.509 | 4.339 | 4.560 | 3.461 | 2.031 | 4.237 | 0.955 | 4.733 | **3.684** |
| MagicBrush | 5.637 | 5.136 | 5.078 | 4.513 | 4.487 | 4.439 | 5.252 | 3.704 | 4.941 | 1.384 | 5.130 | **4.518** |
| OmniGen | 5.281 | 6.003 | 5.308 | 2.916 | 3.087 | 4.903 | 6.628 | 6.352 | 5.616 | 4.519 | 5.064 | **5.062** |
| Gemini | 6.781 | 6.369 | 6.040 | 6.938 | 5.591 | 4.676 | 7.501 | 6.447 | 7.003 | 5.765 | 6.350 | **6.315** |
| Step1X-Edit | 6.547 | 6.545 | 6.204 | 6.483 | 6.787 | 7.221 | 6.975 | 6.512 | 7.068 | 6.921 | 6.448 | **6.701** |
| Doubao | 7.430 | 7.095 | 6.339 | 6.973 | 6.972 | 6.767 | 7.674 | 6.748 | 7.447 | 3.471 | 7.383 | **6.754** |
| BAGEL | 7.324 | 6.909 | 6.381 | 4.753 | 4.573 | 6.150 | 7.896 | 7.164 | 7.021 | 7.320 | 6.218 | **6.519** |
| Ovis-U1 | 7.486 | 6.879 | 6.208 | 4.790 | 5.981 | 6.463 | 7.491 | 7.254 | 7.266 | 4.482 | 6.314 | **6.420** |

## 📚 Citation

If you find Ovis-U1 useful, please cite our paper:

```bibtex
@inproceedings{wang2025ovisu1,
  title={Ovis-U1 Technical Report},
  author={Ovis Team},
  year={2025}
}
```

## 🙏 Acknowledgments

The code is built upon [Ovis](https://github.com/AIDC-AI/Ovis) and [FLUX](https://github.com/black-forest-labs/flux).

## 📄 License

The project is released under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0, SPDX-License-Identifier: Apache-2.0).

## 🚨 Disclaimer

We used compliance-checking algorithms during the training process to ensure the compliance of the trained model to the best of our ability. Due to the complexity of the data and the diversity of language model usage scenarios, we cannot guarantee that the model is completely free of copyright issues or improper content. If you believe anything infringes on your rights or generates improper content, please contact us, and we will promptly address the matter.

added_tokens.json
ADDED
@@ -0,0 +1,28 @@
{
  "</think>": 151668,
  "</tool_call>": 151658,
  "</tool_response>": 151666,
  "<think>": 151667,
  "<tool_call>": 151657,
  "<tool_response>": 151665,
  "<|box_end|>": 151649,
  "<|box_start|>": 151648,
  "<|endoftext|>": 151643,
  "<|file_sep|>": 151664,
  "<|fim_middle|>": 151660,
  "<|fim_pad|>": 151662,
  "<|fim_prefix|>": 151659,
  "<|fim_suffix|>": 151661,
  "<|im_end|>": 151645,
  "<|im_start|>": 151644,
  "<|image_pad|>": 151655,
  "<|object_ref_end|>": 151647,
  "<|object_ref_start|>": 151646,
  "<|quad_end|>": 151651,
  "<|quad_start|>": 151650,
  "<|repo_name|>": 151663,
  "<|video_pad|>": 151656,
  "<|vision_end|>": 151653,
  "<|vision_pad|>": 151654,
  "<|vision_start|>": 151652
}
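These entries extend the underlying Qwen3 tokenizer with the model's special tokens (chat markers, vision placeholders, thinking tags). As a small, hedged check that is not part of the upload itself, the published tokenizer should resolve the same strings to the same ids:

```python
# Hedged sketch: verify a few special-token ids against added_tokens.json.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("AIDC-AI/Ovis-U1-3B", trust_remote_code=True)
for token in ("<|im_start|>", "<|im_end|>", "<|image_pad|>", "<think>", "</think>"):
    print(token, tokenizer.convert_tokens_to_ids(token))
# Expected ids (from this file): 151644, 151645, 151655, 151667, 151668
```
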
    	
config.json
ADDED
@@ -0,0 +1,410 @@
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "architectures": [
         
     | 
| 3 | 
         
            +
                "OvisU1"
         
     | 
| 4 | 
         
            +
              ],
         
     | 
| 5 | 
         
            +
              "auto_map": {
         
     | 
| 6 | 
         
            +
                "AutoConfig": "configuration_ovis_u1.OvisU1Config",
         
     | 
| 7 | 
         
            +
                "AutoModelForCausalLM": "modeling_ovis_u1.OvisU1"
         
     | 
| 8 | 
         
            +
              },
         
     | 
| 9 | 
         
            +
              "conversation_formatter_class": "Qwen3ConversationFormatter",
         
     | 
| 10 | 
         
            +
              "disable_tie_weight": false,
         
     | 
| 11 | 
         
            +
              "hidden_size": 2048,
         
     | 
| 12 | 
         
            +
              "llm_attn_implementation": null,
         
     | 
| 13 | 
         
            +
              "llm_config": {
         
     | 
| 14 | 
         
            +
                "_attn_implementation_autoset": true,
         
     | 
| 15 | 
         
            +
                "_name_or_path": "Qwen/Qwen3-1.7B",
         
     | 
| 16 | 
         
            +
                "add_cross_attention": false,
         
     | 
| 17 | 
         
            +
                "architectures": [
         
     | 
| 18 | 
         
            +
                  "Qwen3ForCausalLM"
         
     | 
| 19 | 
         
            +
                ],
         
     | 
| 20 | 
         
            +
                "attention_bias": false,
         
     | 
| 21 | 
         
            +
                "attention_dropout": 0.0,
         
     | 
| 22 | 
         
            +
                "bad_words_ids": null,
         
     | 
| 23 | 
         
            +
                "begin_suppress_tokens": null,
         
     | 
| 24 | 
         
            +
                "bos_token_id": 151643,
         
     | 
| 25 | 
         
            +
                "chunk_size_feed_forward": 0,
         
     | 
| 26 | 
         
            +
                "cross_attention_hidden_size": null,
         
     | 
| 27 | 
         
            +
                "decoder_start_token_id": null,
         
     | 
| 28 | 
         
            +
                "diversity_penalty": 0.0,
         
     | 
| 29 | 
         
            +
                "do_sample": false,
         
     | 
| 30 | 
         
            +
                "early_stopping": false,
         
     | 
| 31 | 
         
            +
                "encoder_no_repeat_ngram_size": 0,
         
     | 
| 32 | 
         
            +
                "eos_token_id": 151645,
         
     | 
| 33 | 
         
            +
                "exponential_decay_length_penalty": null,
         
     | 
| 34 | 
         
            +
                "finetuning_task": null,
         
     | 
| 35 | 
         
            +
                "forced_bos_token_id": null,
         
     | 
| 36 | 
         
            +
                "forced_eos_token_id": null,
         
     | 
| 37 | 
         
            +
                "head_dim": 128,
         
     | 
| 38 | 
         
            +
                "hidden_act": "silu",
         
     | 
| 39 | 
         
            +
                "hidden_size": 2048,
         
     | 
| 40 | 
         
            +
                "id2label": {
         
     | 
| 41 | 
         
            +
                  "0": "LABEL_0",
         
     | 
| 42 | 
         
            +
                  "1": "LABEL_1"
         
     | 
| 43 | 
         
            +
                },
         
     | 
| 44 | 
         
            +
                "initializer_range": 0.02,
         
     | 
| 45 | 
         
            +
                "intermediate_size": 6144,
         
     | 
| 46 | 
         
            +
                "is_decoder": false,
         
     | 
| 47 | 
         
            +
                "is_encoder_decoder": false,
         
     | 
| 48 | 
         
            +
                "label2id": {
         
     | 
| 49 | 
         
            +
                  "LABEL_0": 0,
         
     | 
| 50 | 
         
            +
                  "LABEL_1": 1
         
     | 
| 51 | 
         
            +
                },
         
     | 
| 52 | 
         
            +
                "length_penalty": 1.0,
         
     | 
| 53 | 
         
            +
                "max_length": 20,
         
     | 
| 54 | 
         
            +
                "max_position_embeddings": 40960,
         
     | 
| 55 | 
         
            +
                "max_window_layers": 28,
         
     | 
| 56 | 
         
            +
                "min_length": 0,
         
     | 
| 57 | 
         
            +
                "model_type": "qwen3",
         
     | 
| 58 | 
         
            +
                "no_repeat_ngram_size": 0,
         
     | 
| 59 | 
         
            +
                "num_attention_heads": 16,
         
     | 
| 60 | 
         
            +
                "num_beam_groups": 1,
         
     | 
| 61 | 
         
            +
                "num_beams": 1,
         
     | 
| 62 | 
         
            +
                "num_hidden_layers": 28,
         
     | 
| 63 | 
         
            +
                "num_key_value_heads": 8,
         
     | 
| 64 | 
         
            +
                "num_return_sequences": 1,
         
     | 
| 65 | 
         
            +
                "output_attentions": false,
         
     | 
| 66 | 
         
            +
                "output_hidden_states": false,
         
     | 
| 67 | 
         
            +
                "output_scores": false,
         
     | 
| 68 | 
         
            +
                "pad_token_id": null,
         
     | 
| 69 | 
         
            +
                "prefix": null,
         
     | 
| 70 | 
         
            +
                "problem_type": null,
         
     | 
| 71 | 
         
            +
                "pruned_heads": {},
         
     | 
| 72 | 
         
            +
                "remove_invalid_values": false,
         
     | 
| 73 | 
         
            +
                "repetition_penalty": 1.0,
         
     | 
| 74 | 
         
            +
                "return_dict": true,
         
     | 
| 75 | 
         
            +
                "return_dict_in_generate": false,
         
     | 
| 76 | 
         
            +
                "rms_norm_eps": 1e-06,
         
     | 
| 77 | 
         
            +
                "rope_scaling": null,
         
     | 
| 78 | 
         
            +
                "rope_theta": 1000000,
         
     | 
| 79 | 
         
            +
                "sep_token_id": null,
         
     | 
| 80 | 
         
            +
                "sliding_window": null,
         
     | 
| 81 | 
         
            +
                "suppress_tokens": null,
         
     | 
| 82 | 
         
            +
                "task_specific_params": null,
         
     | 
| 83 | 
         
            +
                "temperature": 1.0,
         
     | 
| 84 | 
         
            +
                "tf_legacy_loss": false,
         
     | 
| 85 | 
         
            +
                "tie_encoder_decoder": false,
         
     | 
| 86 | 
         
            +
                "tie_word_embeddings": true,
         
     | 
| 87 | 
         
            +
                "tokenizer_class": null,
         
     | 
| 88 | 
         
            +
                "top_k": 50,
         
     | 
| 89 | 
         
            +
                "top_p": 1.0,
         
     | 
| 90 | 
         
            +
                "torch_dtype": "bfloat16",
         
     | 
| 91 | 
         
            +
                "torchscript": false,
         
     | 
| 92 | 
         
            +
                "typical_p": 1.0,
         
     | 
| 93 | 
         
            +
                "use_bfloat16": false,
         
     | 
| 94 | 
         
            +
                "use_cache": true,
         
     | 
| 95 | 
         
            +
                "use_sliding_window": false,
         
     | 
| 96 | 
         
            +
                "vocab_size": 151936
         
     | 
| 97 | 
         
            +
              },
         
     | 
| 98 | 
         
            +
              "model_type": "ovis_u1",
         
     | 
| 99 | 
         
            +
              "multimodal_max_length": 4496,
         
     | 
| 100 | 
         
            +
              "torch_dtype": "bfloat16",
         
     | 
| 101 | 
         
            +
              "transformers_version": "4.51.3",
         
     | 
| 102 | 
         
            +
              "use_cache": false,
         
     | 
| 103 | 
         
            +
              "visual_generator_config": {
         
     | 
| 104 | 
         
            +
                "_attn_implementation_autoset": true,
         
     | 
| 105 | 
         
            +
                "_name_or_path": "yak_1b/yak_1b_qwen3_trained_S8",
         
     | 
| 106 | 
         
            +
                "add_cross_attention": false,
         
     | 
| 107 | 
         
            +
                "architectures": [
         
     | 
| 108 | 
         
            +
                  "YakModel"
         
     | 
| 109 | 
         
            +
                ],
         
     | 
| 110 | 
         
            +
                "axes_dim": [
         
     | 
| 111 | 
         
            +
                  16,
         
     | 
| 112 | 
         
            +
                  56,
         
     | 
| 113 | 
         
            +
                  56
         
     | 
| 114 | 
         
            +
                ],
         
     | 
| 115 | 
         
            +
                "bad_words_ids": null,
         
     | 
| 116 | 
         
            +
                "base_shift": 0.5,
         
     | 
| 117 | 
         
            +
                "begin_suppress_tokens": null,
         
     | 
| 118 | 
         
            +
                "bos_token_id": null,
         
     | 
| 119 | 
         
            +
                "checkpoint": false,
         
     | 
| 120 | 
         
            +
                "chunk_size_feed_forward": 0,
         
     | 
| 121 | 
         
            +
                "context_in_dim": 4096,
         
     | 
| 122 | 
         
            +
                "cross_attention_hidden_size": null,
         
     | 
| 123 | 
         
            +
                "decoder_start_token_id": null,
         
     | 
| 124 | 
         
            +
                "depth": 6,
         
     | 
| 125 | 
         
            +
                "depth_single_blocks": 12,
         
     | 
| 126 | 
         
            +
                "diversity_penalty": 0.0,
         
     | 
| 127 | 
         
            +
                "do_sample": false,
         
     | 
| 128 | 
         
            +
                "early_stopping": false,
         
     | 
| 129 | 
         
            +
                "encoder_no_repeat_ngram_size": 0,
         
     | 
| 130 | 
         
            +
                "eos_token_id": null,
         
     | 
| 131 | 
         
            +
                "exponential_decay_length_penalty": null,
         
     | 
| 132 | 
         
            +
                "finetuning_task": null,
         
     | 
| 133 | 
         
            +
                "forced_bos_token_id": null,
         
     | 
| 134 | 
         
            +
                "forced_eos_token_id": null,
         
     | 
| 135 | 
         
            +
                "guidance_embed": false,
         
     | 
| 136 | 
         
            +
                "hidden_size": 1536,
         
     | 
| 137 | 
         
            +
                "id2label": {
         
     | 
| 138 | 
         
            +
                  "0": "LABEL_0",
         
     | 
| 139 | 
         
            +
                  "1": "LABEL_1"
         
     | 
| 140 | 
         
            +
                },
         
     | 
| 141 | 
         
            +
                "in_channels": 16,
         
     | 
| 142 | 
         
            +
                "is_decoder": false,
         
     | 
| 143 | 
         
            +
                "is_encoder_decoder": false,
         
     | 
| 144 | 
         
            +
                "label2id": {
         
     | 
| 145 | 
         
            +
                  "LABEL_0": 0,
         
     | 
| 146 | 
         
            +
                  "LABEL_1": 1
         
     | 
| 147 | 
         
            +
                },
         
     | 
| 148 | 
         
            +
                "length_penalty": 1.0,
         
     | 
| 149 | 
         
            +
                "max_length": 20,
         
     | 
| 150 | 
         
            +
                "max_shift": 1.15,
         
     | 
| 151 | 
         
            +
                "min_length": 0,
         
     | 
| 152 | 
         
            +
                "mlp_ratio": 4.0,
         
     | 
| 153 | 
         
            +
                "model_type": "yak",
         
     | 
| 154 | 
         
            +
                "no_repeat_ngram_size": 0,
         
     | 
| 155 | 
         
            +
                "num_beam_groups": 1,
         
     | 
| 156 | 
         
            +
                "num_beams": 1,
         
     | 
| 157 | 
         
            +
                "num_heads": 12,
         
     | 
| 158 | 
         
            +
                "num_return_sequences": 1,
         
     | 
| 159 | 
         
            +
                "out_channels": 16,
         
     | 
| 160 | 
         
            +
                "output_attentions": false,
         
     | 
| 161 | 
         
            +
                "output_hidden_states": false,
         
     | 
| 162 | 
         
            +
                "output_scores": false,
         
     | 
| 163 | 
         
            +
                "pad_token_id": null,
         
     | 
| 164 | 
         
            +
                "prefix": null,
         
     | 
| 165 | 
         
            +
                "problem_type": null,
         
     | 
| 166 | 
         
            +
                "pruned_heads": {},
         
     | 
| 167 | 
         
            +
                "qkv_bias": true,
         
     | 
| 168 | 
         
            +
                "remove_invalid_values": false,
         
     | 
| 169 | 
         
            +
                "repetition_penalty": 1.0,
         
     | 
| 170 | 
         
            +
                "return_dict": true,
         
     | 
| 171 | 
         
            +
                "return_dict_in_generate": false,
         
     | 
| 172 | 
         
            +
                "sep_token_id": null,
         
     | 
| 173 | 
         
            +
                "suppress_tokens": null,
         
     | 
| 174 | 
         
            +
                "task_specific_params": null,
         
     | 
| 175 | 
         
            +
                "temperature": 1.0,
         
     | 
| 176 | 
         
            +
                "tf_legacy_loss": false,
         
     | 
| 177 | 
         
            +
                "theta": 10000,
         
     | 
| 178 | 
         
            +
                "tie_encoder_decoder": false,
         
     | 
| 179 | 
         
            +
                "tie_word_embeddings": true,
         
     | 
| 180 | 
         
            +
                "timestep_shift": true,
         
     | 
| 181 | 
         
            +
                "tokenizer_class": null,
         
     | 
| 182 | 
         
            +
                "top_k": 50,
         
     | 
| 183 | 
         
            +
                "top_p": 1.0,
         
     | 
| 184 | 
         
            +
                "torch_dtype": "float32",
         
     | 
| 185 | 
         
            +
                "torchscript": false,
         
     | 
| 186 | 
         
            +
                "txt_type": "refiner",
         
     | 
| 187 | 
         
            +
                "typical_p": 1.0,
         
     | 
| 188 | 
         
            +
                "use_bfloat16": false,
         
     | 
| 189 | 
         
            +
                "vae_config": {
         
     | 
| 190 | 
         
            +
                  "_class_name": "AutoencoderKL",
         
     | 
| 191 | 
         
            +
                  "_diffusers_version": "0.18.0.dev0",
         
     | 
| 192 | 
         
            +
                  "_name_or_path": ".",
         
     | 
| 193 | 
         
            +
                  "act_fn": "silu",
         
     | 
| 194 | 
         
            +
                  "block_out_channels": [
         
     | 
| 195 | 
         
            +
                    128,
         
     | 
| 196 | 
         
            +
                    256,
         
     | 
| 197 | 
         
            +
                    512,
         
     | 
| 198 | 
         
            +
                    512
         
     | 
| 199 | 
         
            +
                  ],
         
     | 
| 200 | 
         
            +
                  "down_block_types": [
         
     | 
| 201 | 
         
            +
                    "DownEncoderBlock2D",
         
     | 
| 202 | 
         
            +
                    "DownEncoderBlock2D",
         
     | 
| 203 | 
         
            +
                    "DownEncoderBlock2D",
         
     | 
| 204 | 
         
            +
                    "DownEncoderBlock2D"
         
     | 
| 205 | 
         
            +
                  ],
         
     | 
| 206 | 
         
            +
                  "force_upcast": false,
         
     | 
| 207 | 
         
            +
                  "in_channels": 3,
         
     | 
| 208 | 
         
            +
                  "latent_channels": 4,
         
     | 
| 209 | 
         
            +
                  "layers_per_block": 2,
         
     | 
| 210 | 
         
            +
                  "norm_num_groups": 32,
         
     | 
| 211 | 
         
            +
                  "out_channels": 3,
         
     | 
| 212 | 
         
            +
                  "sample_size": 512,
         
     | 
| 213 | 
         
            +
                  "scaling_factor": 0.13025,
         
     | 
| 214 | 
         
            +
                  "up_block_types": [
         
     | 
| 215 | 
         
            +
                    "UpDecoderBlock2D",
         
     | 
| 216 | 
         
            +
                    "UpDecoderBlock2D",
         
     | 
| 217 | 
         
            +
                    "UpDecoderBlock2D",
         
     | 
| 218 | 
         
            +
                    "UpDecoderBlock2D"
         
     | 
| 219 | 
         
            +
                  ]
         
     | 
| 220 | 
         
            +
                },
         
     | 
| 221 | 
         
            +
                "vec_in_dim": 1536
         
     | 
| 222 | 
         
            +
              },
         
     | 
| 223 | 
         
            +
              "visual_tokenizer_config": {
         
     | 
| 224 | 
         
            +
                "_attn_implementation_autoset": true,
         
     | 
| 225 | 
         
            +
                "_name_or_path": "",
         
     | 
| 226 | 
         
            +
                "add_cross_attention": false,
         
     | 
| 227 | 
         
            +
                "architectures": null,
         
     | 
| 228 | 
         
            +
                "backbone_config": {
         
     | 
| 229 | 
         
            +
                  "_attn_implementation_autoset": false,
         
     | 
| 230 | 
         
            +
                  "_name_or_path": "aimv2/visual_tokenizer_backbone",
         
     | 
| 231 | 
         
            +
                  "add_cross_attention": false,
         
     | 
| 232 | 
         
            +
                  "architectures": [
         
     | 
| 233 | 
         
            +
                    "AIMv2Model"
         
     | 
| 234 | 
         
            +
                  ],
         
     | 
| 235 | 
         
            +
                  "attention_dropout": 0.0,
         
     | 
| 236 | 
         
            +
                  "auto_map": {
         
     | 
| 237 | 
         
            +
                    "AutoConfig": "configuration_aimv2.AIMv2Config",
         
     | 
| 238 | 
         
            +
                    "AutoModel": "modeling_aimv2.AIMv2Model",
         
     | 
| 239 | 
         
            +
                    "FlaxAutoModel": "modeling_flax_aimv2.FlaxAIMv2Model"
         
     | 
| 240 | 
         
            +
                  },
         
     | 
| 241 | 
         
            +
                  "bad_words_ids": null,
         
     | 
| 242 | 
         
            +
                  "begin_suppress_tokens": null,
         
     | 
| 243 | 
         
            +
                  "bos_token_id": null,
         
     | 
| 244 | 
         
            +
                  "chunk_size_feed_forward": 0,
         
     | 
| 245 | 
         
            +
                  "cross_attention_hidden_size": null,
         
     | 
| 246 | 
         
            +
                  "decoder_start_token_id": null,
         
     | 
| 247 | 
         
            +
                  "disable_rope": false,
         
     | 
| 248 | 
         
            +
                  "diversity_penalty": 0.0,
         
     | 
| 249 | 
         
            +
                  "do_sample": false,
         
     | 
| 250 | 
         
            +
                  "early_stopping": false,
         
     | 
| 251 | 
         
            +
                  "encoder_no_repeat_ngram_size": 0,
         
     | 
| 252 | 
         
            +
                  "eos_token_id": null,
         
     | 
| 253 | 
         
            +
                  "exponential_decay_length_penalty": null,
         
     | 
| 254 | 
         
            +
                  "finetuning_task": null,
         
     | 
| 255 | 
         
            +
                  "forced_bos_token_id": null,
         
     | 
| 256 | 
         
            +
                  "forced_eos_token_id": null,
         
     | 
| 257 | 
         
            +
                  "fullatt_block_indexes": null,
         
     | 
| 258 | 
         
            +
                  "hidden_size": 1024,
         
     | 
| 259 | 
         
            +
                  "hidden_stride": 2,
         
     | 
| 260 | 
         
            +
                  "id2label": {
         
     | 
| 261 | 
         
            +
                    "0": "LABEL_0",
         
     | 
| 262 | 
         
            +
                    "1": "LABEL_1"
         
     | 
| 263 | 
         
            +
                  },
         
     | 
| 264 | 
         
            +
                  "image_size": 448,
         
     | 
| 265 | 
         
            +
                  "intermediate_size": 2816,
         
     | 
| 266 | 
         
            +
                  "interpolate_pe_method": "two_dim",
         
     | 
| 267 | 
         
            +
                  "is_decoder": false,
         
     | 
| 268 | 
         
            +
                  "is_encoder_decoder": false,
         
     | 
| 269 | 
         
            +
                  "label2id": {
         
     | 
| 270 | 
         
            +
                    "LABEL_0": 0,
         
     | 
| 271 | 
         
            +
                    "LABEL_1": 1
         
     | 
| 272 | 
         
            +
                  },
         
     | 
| 273 | 
         
            +
                  "length_penalty": 1.0,
         
     | 
| 274 | 
         
            +
                  "max_length": 20,
         
     | 
| 275 | 
         
            +
                  "max_pixels": 2408448,
         
     | 
| 276 | 
         
            +
                  "min_length": 0,
         
     | 
| 277 | 
         
            +
                  "min_pixels": 200704,
         
     | 
| 278 | 
         
            +
                  "model_type": "aimv2",
         
     | 
| 279 | 
         
            +
                  "no_repeat_ngram_size": 0,
         
     | 
| 280 | 
         
            +
                  "num_attention_heads": 8,
         
     | 
| 281 | 
         
            +
                  "num_beam_groups": 1,
         
     | 
| 282 | 
         
            +
                  "num_beams": 1,
         
     | 
| 283 | 
         
            +
                  "num_channels": 3,
         
     | 
| 284 | 
         
            +
                  "num_hidden_layers": 24,
         
     | 
| 285 | 
         
            +
                  "num_return_sequences": 1,
         
     | 
| 286 | 
         
            +
                  "output_attentions": false,
         
     | 
| 287 | 
         
            +
                  "output_hidden_states": false,
         
     | 
| 288 | 
         
            +
                  "output_scores": false,
         
     | 
| 289 | 
         
            +
                  "pad_token_id": null,
         
     | 
| 290 | 
         
            +
                  "patch_size": 14,
         
     | 
| 291 | 
         
            +
                  "prefix": null,
         
     | 
| 292 | 
         
            +
                  "preserve_original_pe": true,
         
     | 
| 293 | 
         
            +
                  "problem_type": null,
         
     | 
| 294 | 
         
            +
                  "projection_dropout": 0.0,
         
     | 
| 295 | 
         
            +
                  "pruned_heads": {},
         
     | 
| 296 | 
         
            +
                  "qkv_bias": false,
         
     | 
| 297 | 
         
            +
                  "remove_invalid_values": false,
         
     | 
| 298 | 
         
            +
                  "repetition_penalty": 1.0,
         
     | 
| 299 | 
         
            +
                  "return_dict": true,
         
     | 
| 300 | 
         
            +
                  "return_dict_in_generate": false,
         
     | 
| 301 | 
         
            +
                  "rms_norm_eps": 1e-05,
         
     | 
| 302 | 
         
            +
                  "sep_token_id": null,
         
     | 
| 303 | 
         
            +
                  "suppress_tokens": null,
         
     | 
| 304 | 
         
            +
                  "task_specific_params": null,
         
     | 
| 305 | 
         
            +
                  "temperature": 1.0,
         
     | 
| 306 | 
         
            +
                  "temporal_patch_size": 1,
         
     | 
| 307 | 
         
            +
                  "tf_legacy_loss": false,
         
     | 
| 308 | 
         
            +
                  "tie_encoder_decoder": false,
         
     | 
| 309 | 
         
            +
                  "tie_word_embeddings": true,
         
     | 
| 310 | 
         
            +
                  "tokenizer_class": null,
         
     | 
| 311 | 
         
            +
                  "top_k": 50,
         
     | 
| 312 | 
         
            +
                  "top_p": 1.0,
         
     | 
| 313 | 
         
            +
                  "torch_dtype": "bfloat16",
         
     | 
| 314 | 
         
            +
                  "torchscript": false,
         
     | 
| 315 | 
         
            +
                  "typical_p": 1.0,
         
     | 
| 316 | 
         
            +
                  "use_bfloat16": false,
         
     | 
| 317 | 
         
            +
                  "use_bias": false,
         
     | 
| 318 | 
         
            +
                  "window_size": 112
         
     | 
| 319 | 
         
            +
                },
         
     | 
| 320 | 
         
            +
                "backbone_kwargs": {
         
     | 
| 321 | 
         
            +
                  "disable_rope": false,
         
     | 
| 322 | 
         
            +
                  "hidden_stride": 2,
         
     | 
| 323 | 
         
            +
                  "interpolate_pe_method": "two_dim",
         
     | 
| 324 | 
         
            +
                  "max_pixels": 2408448,
         
     | 
| 325 | 
         
            +
                  "min_pixels": 200704,
         
     | 
| 326 | 
         
            +
                  "preserve_original_pe": true,
         
     | 
| 327 | 
         
            +
                  "temporal_patch_size": 1,
         
     | 
| 328 | 
         
            +
                  "window_size": 112
         
     | 
| 329 | 
         
            +
                },
         
     | 
| 330 | 
         
            +
                "bad_words_ids": null,
         
     | 
| 331 | 
         
            +
                "begin_suppress_tokens": null,
         
     | 
| 332 | 
         
            +
                "bos_token_id": null,
         
     | 
| 333 | 
         
            +
                "chunk_size_feed_forward": 0,
         
     | 
| 334 | 
         
            +
                "cross_attention_hidden_size": null,
         
     | 
| 335 | 
         
            +
                "decoder_start_token_id": null,
         
     | 
| 336 | 
         
            +
                "depths": null,
         
     | 
| 337 | 
         
            +
                "disable_rope": false,
         
     | 
| 338 | 
         
            +
                "diversity_penalty": 0.0,
         
     | 
| 339 | 
         
            +
                "do_sample": false,
         
     | 
| 340 | 
         
            +
                "drop_cls_token": false,
         
     | 
| 341 | 
         
            +
                "early_stopping": false,
         
     | 
| 342 | 
         
            +
                "encoder_no_repeat_ngram_size": 0,
         
     | 
| 343 | 
         
            +
                "eos_token_id": null,
         
     | 
| 344 | 
         
+    "exponential_decay_length_penalty": null,
+    "finetuning_task": null,
+    "forced_bos_token_id": null,
+    "forced_eos_token_id": null,
+    "fullatt_block_indexes": null,
+    "hidden_stride": 2,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1"
+    },
+    "image_processor_new_kwargs": {
+      "hidden_stride": 2,
+      "max_pixels": 2408448,
+      "min_pixels": 200704,
+      "temporal_patch_size": 1
+    },
+    "interpolate_pe_method": "two_dim",
+    "is_decoder": false,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1
+    },
+    "length_penalty": 1.0,
+    "max_length": 20,
+    "max_pixels": 2408448,
+    "min_length": 0,
+    "min_pixels": 200704,
+    "model_type": "aimv2_visual_tokenizer",
+    "no_repeat_ngram_size": 0,
+    "num_beam_groups": 1,
+    "num_beams": 1,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": null,
+    "prefix": null,
+    "preserve_original_pe": true,
+    "problem_type": null,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "sep_token_id": null,
+    "suppress_tokens": null,
+    "task_specific_params": null,
+    "tau": 1.0,
+    "temperature": 1.0,
+    "temporal_patch_size": 1,
+    "tf_legacy_loss": false,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": true,
+    "tokenize_function": "softmax",
+    "tokenizer_class": null,
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": null,
+    "torchscript": false,
+    "typical_p": 1.0,
+    "use_bfloat16": false,
+    "use_indicators": false,
+    "vocab_size": 65536,
+    "window_size": 112
+  }
+}
    	
configuration_aimv2.py
ADDED
@@ -0,0 +1,82 @@
+# copied from https://huggingface.co/apple/aimv2-huge-patch14-448
+from typing import Any
+
+from transformers.configuration_utils import PretrainedConfig
+
+__all__ = ["AIMv2Config"]
+
+
+class AIMv2Config(PretrainedConfig):
+    """This is the configuration class to store the configuration of an [`AIMv2Model`].
+
+    Instantiating a configuration with the defaults will yield a similar configuration
+    to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).
+
+    Args:
+        hidden_size: Dimension of the hidden representations.
+        intermediate_size: Dimension of the SwiGLU representations.
+        num_hidden_layers: Number of hidden layers in the Transformer.
+        num_attention_heads: Number of attention heads for each attention layer
+            in the Transformer.
+        num_channels: Number of input channels.
+        image_size: Image size.
+        patch_size: Patch size.
+        rms_norm_eps: Epsilon value used for the RMS normalization layer.
+        attention_dropout: Dropout ratio for attention probabilities.
+        projection_dropout: Dropout ratio for the projection layer after the attention.
+        qkv_bias: Whether to add a bias to the queries, keys and values.
+        use_bias: Whether to add a bias in the feed-forward and projection layers.
+        kwargs: Keyword arguments for the [`PretrainedConfig`].
+    """
+
+    model_type: str = "aimv2"
+
+    def __init__(
+        self,
+        hidden_size: int = 1024,
+        intermediate_size: int = 2816,
+        num_hidden_layers: int = 24,
+        num_attention_heads: int = 8,
+        num_channels: int = 3,
+        image_size: int = 224,
+        patch_size: int = 14,
+        rms_norm_eps: float = 1e-5,
+        attention_dropout: float = 0.0,
+        projection_dropout: float = 0.0,
+        qkv_bias: bool = False,
+        use_bias: bool = False,
+        hidden_stride: int = 2,
+        window_size: int = 112,
+        fullatt_block_indexes: list = None,
+        temporal_patch_size: int = 1,
+        preserve_original_pe: bool = False,
+        interpolate_pe_method: str = 'one_dim',
+        disable_rope: bool = False,
+        min_pixels: int = 3136,
+        max_pixels: int = 1960000,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.attention_dropout = attention_dropout
+        self.rms_norm_eps = rms_norm_eps
+
+        self.projection_dropout = projection_dropout
+        self.qkv_bias = qkv_bias
+        self.use_bias = use_bias
+
+        self.hidden_stride = hidden_stride
+        self.window_size = window_size
+        self.fullatt_block_indexes = fullatt_block_indexes
+        self.temporal_patch_size = temporal_patch_size
+        self.preserve_original_pe = preserve_original_pe
+        self.interpolate_pe_method = interpolate_pe_method
+        self.disable_rope = disable_rope
+        self.min_pixels = min_pixels
+        self.max_pixels = max_pixels
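A minimal usage sketch (not part of the upload) for the vision-backbone config above, assuming configuration_aimv2.py is importable from the working directory; the override values are illustrative, not the checkpoint's shipped settings, which live in config.json:

    from configuration_aimv2 import AIMv2Config

    cfg = AIMv2Config(
        image_size=448,               # illustrative values only
        patch_size=14,
        hidden_stride=2,
        window_size=112,
        preserve_original_pe=True,
        interpolate_pe_method="two_dim",
    )
    assert cfg.model_type == "aimv2"
    print(cfg.hidden_size, cfg.num_attention_heads)  # defaults: 1024 8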
    	
configuration_ovis_u1.py
ADDED
@@ -0,0 +1,281 @@
+from typing import List, Dict, Union, Optional
+from abc import ABC, abstractmethod
+
+from transformers import PretrainedConfig, AutoConfig, AutoModel
+
+# Model Constants
+IGNORE_ID = -100
+IMAGE_TOKEN_ID = -200
+VIDEO_TOKEN_ID = -201
+IMAGE_TOKEN = "<image>"
+VIDEO_TOKEN = "<video>"
+IMAGE_ATOM_ID = -300
+IMAGE_INDICATOR_IDS = [-301, -302, -303, -304]
+
+from .configuration_aimv2 import AIMv2Config
+from .modeling_aimv2 import AIMv2Model
+AutoConfig.register("aimv2", AIMv2Config)
+AutoModel.register(AIMv2Config, AIMv2Model)
+
+from .configuration_yak import YakConfig
+from .modeling_yak import YakModel
+AutoConfig.register("yak", YakConfig)
+AutoModel.register(YakConfig, YakModel)
+
+
+# ----------------------------------------------------------------------
+#                     Visual Tokenizer Configuration
+# ----------------------------------------------------------------------
+class BaseVisualTokenizerConfig(PretrainedConfig):
+    def __init__(self,
+        vocab_size=16384,
+        tokenize_function="softmax",
+        tau=1.0,
+        depths=None,
+        use_indicators=False,
+        drop_cls_token=False,
+        backbone_config: Optional[Union[PretrainedConfig, dict]] = None,
+        hidden_stride: int = 1,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.tokenize_function = tokenize_function
+        self.tau = tau
+        if isinstance(depths, str):
+            depths = [int(x) for x in depths.split('|')]
+        self.depths = depths
+        self.backbone_kwargs = {}
+        self.use_indicators = use_indicators
+        self.drop_cls_token = drop_cls_token
+        if backbone_config is not None:
+            assert isinstance(backbone_config, (PretrainedConfig, dict)), \
+                f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type"
+            if not isinstance(backbone_config, PretrainedConfig):
+                model_type = backbone_config['model_type']
+                backbone_config.pop('model_type')
+                backbone_config = AutoConfig.for_model(model_type, **backbone_config)
+        self.backbone_config = backbone_config
+        self.hidden_stride = hidden_stride
+
+
+class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig):
+    model_type = "aimv2_visual_tokenizer"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if self.drop_cls_token:
+            self.drop_cls_token = False
+        if self.depths:
+            assert len(self.depths) == 1
+            self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
+
+        self.image_processor_new_kwargs = {}
+
+        if kwargs.get("min_pixels", None) is not None:
+            self.image_processor_new_kwargs['min_pixels'] = kwargs.get("min_pixels")
+            self.backbone_kwargs['min_pixels'] = self.min_pixels
+
+        if kwargs.get("max_pixels", None) is not None:
+            self.image_processor_new_kwargs['max_pixels'] = kwargs.get("max_pixels")
+            self.backbone_kwargs['max_pixels'] = self.max_pixels
+
+        if kwargs.get("temporal_patch_size", None) is not None:
+            self.image_processor_new_kwargs['temporal_patch_size'] = kwargs.get("temporal_patch_size")
+            self.backbone_kwargs['temporal_patch_size'] = self.temporal_patch_size
+
+        if kwargs.get("hidden_stride", None) is not None:
+            self.image_processor_new_kwargs['hidden_stride'] = kwargs.get("hidden_stride")
+
+        if kwargs.get("patch_size", None) is not None:
+            self.image_processor_new_kwargs['patch_size'] = kwargs.get("patch_size")
+            self.backbone_kwargs['patch_size'] = self.patch_size
+
+        if kwargs.get("window_size", None) is not None:
+            self.backbone_kwargs['window_size'] = kwargs.get("window_size")
+
+        if kwargs.get("hidden_stride", None) is not None:
+            self.backbone_kwargs['hidden_stride'] = kwargs.get("hidden_stride")
+
+        if kwargs.get('fullatt_block_indexes', None) is not None:
+            self.backbone_kwargs['fullatt_block_indexes'] = [int(i) for i in kwargs.get('fullatt_block_indexes').replace(' ','').split('|')]
+
+        if kwargs.get("preserve_original_pe", None) is not None:
+            self.backbone_kwargs['preserve_original_pe'] = kwargs.get("preserve_original_pe")
+
+        if kwargs.get("interpolate_pe_method", None) is not None:
+            self.backbone_kwargs['interpolate_pe_method'] = kwargs.get("interpolate_pe_method")
+
+        if kwargs.get("disable_rope", None) is not None:
+            self.backbone_kwargs['disable_rope'] = kwargs.get("disable_rope")
+
+AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig)
+
+
+
+# ----------------------------------------------------------------------
+#                          OvisU1 Configuration
+# ----------------------------------------------------------------------
+class OvisU1Config(PretrainedConfig):
+    model_type = "ovis_u1"
+
+    def __init__(self,
+                 llm_config: Optional[Union[PretrainedConfig, dict]] = None,
+                 visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None,
+                 visual_generator_config: Optional[Union[PretrainedConfig, dict]] = None,
+                 multimodal_max_length=2048,
+                 hidden_size=None,
+                 conversation_formatter_class=None,
+                 llm_attn_implementation=None,
+                 disable_tie_weight=False,
+                 **kwargs):
+        super().__init__(**kwargs)
+        if llm_config is not None:
+            assert isinstance(llm_config, (PretrainedConfig, dict)), \
+                f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type"
+            if not isinstance(llm_config, PretrainedConfig):
+                model_type = llm_config['model_type']
+                llm_config.pop('model_type')
+                llm_config = AutoConfig.for_model(model_type, **llm_config)
+        self.llm_config = llm_config
+        if visual_tokenizer_config is not None:
+            assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \
+                f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type"
+            if not isinstance(visual_tokenizer_config, PretrainedConfig):
+                model_type = visual_tokenizer_config['model_type']
+                visual_tokenizer_config.pop('model_type')
+                if model_type == "aimv2_native_visual_tokenizer":
+                    model_type = "aimv2_visual_tokenizer"
+                if visual_tokenizer_config['backbone_config']['model_type'] == "aimv2_native":
+                    visual_tokenizer_config['backbone_config']['model_type'] = "aimv2"
+                visual_tokenizer_config = AutoConfig.for_model(model_type, **visual_tokenizer_config)
+        self.visual_tokenizer_config = visual_tokenizer_config
+        if visual_generator_config is not None:
+            assert isinstance(visual_generator_config, (PretrainedConfig, dict)), \
+                f"expect `visual_generator_config` to be instance of PretrainedConfig or dict, but got {type(visual_generator_config)} type"
+            if not isinstance(visual_generator_config, PretrainedConfig):
+                model_type = visual_generator_config['model_type']
+                visual_generator_config.pop('model_type')
+                visual_generator_config = AutoConfig.for_model(model_type, **visual_generator_config)
+        self.visual_generator_config = visual_generator_config
+        self.multimodal_max_length = multimodal_max_length
+        self.hidden_size = hidden_size
+        self.conversation_formatter_class = conversation_formatter_class
+        self.llm_attn_implementation = llm_attn_implementation
+        self.disable_tie_weight = disable_tie_weight
+
+
+# ----------------------------------------------------------------------
+#                         Conversation Formatter
+# ----------------------------------------------------------------------
+class ConversationFormatter(ABC):
+    support_tokenizer_types = None
+
+    def __init__(self, tokenizer):
+        tokenizer_type = type(tokenizer).__name__
+        assert tokenizer_type in self.support_tokenizer_types, \
+            f'Invalid tokenizer type, expected one from `{self.support_tokenizer_types}`, but got `{tokenizer_type}`'
+        self.tokenizer = tokenizer
+        self.image_token = IMAGE_TOKEN
+        self.image_token_id = IMAGE_TOKEN_ID
+        self.ignore_id = IGNORE_ID
+        self.im_end = None
+        self.video_token = VIDEO_TOKEN
+        self.video_token_id = VIDEO_TOKEN_ID
+
+    def _tokenize_with_image_symbol(self, text):
+        if text.find(self.video_token) != -1:
+            token = self.video_token
+            token_id = self.video_token_id
+        else:
+            token = self.image_token
+            token_id = self.image_token_id
+
+        text_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in
+                       text.split(token)]
+        token_ids = []
+        num_chuck = len(text_chunks)
+        for i, chunk in enumerate(text_chunks):
+            token_ids.extend(chunk)
+            if i < num_chuck - 1:
+                token_ids.append(token_id)
+        return token_ids
+
+    @abstractmethod
+    def format(self, conversations: List[Dict], generation_preface=None):
+        pass
+
+    @abstractmethod
+    def format_query(self, query, generation_preface=""):
+        pass
+
+
+class Qwen3ConversationFormatter(ConversationFormatter):
+    support_tokenizer_types = ['QWenTokenizer', 'Qwen2TokenizerFast']
+
+    def __init__(self, tokenizer):
+        super().__init__(tokenizer)
+        self.from2role = {
+            "system": "<|im_start|>system\n",
+            "human": "<|im_start|>user\n",
+            "gpt": "<|im_start|>assistant\n",
+            "ignored_gpt": "<|im_start|>assistant\n",
+        }
+        self.gpt_token_num = None
+        self.im_end = "<|im_end|>\n"
+        self.empty_think = "<think>\n\n</think>\n\n"
+
+    def format(self, conversations: List[Dict], generation_preface=None, enable_thinking=False):
+        if self.gpt_token_num is None:
+            prefilled_think = "" if enable_thinking else self.empty_think
+            self.gpt_token_num = len(
+                self.tokenizer(self.from2role["gpt"] + prefilled_think, add_special_tokens=False).input_ids
+            )
+
+        if generation_preface is not None:
+            conversations.append({
+                "from": "gpt",
+                "value": generation_preface
+            })
+
+        prompt = ""
+        input_ids = []
+        labels = []
+        num_conversation = len(conversations)
+        for i, conversation in enumerate(conversations):
+            frm = conversation["from"]
+            role = self.from2role[frm]
+            message = conversation["value"]
+            if frm == 'gpt' and not enable_thinking:
+                text = role + self.empty_think + message
+            else:
+                text = role + message
+            if i < num_conversation - 1 or generation_preface is None:
+                text += self.im_end
+            prompt += text
+            token_ids = self._tokenize_with_image_symbol(text)
+            input_ids.extend(token_ids)
+            label_ids = [self.ignore_id] * len(token_ids)
+            if frm == "gpt" and generation_preface is None:
+                # learning `\n` following `im_end` is meaningless, so the last `\n` token is ignored in label
+                label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
+            labels.extend(label_ids)
+
+        assert self._tokenize_with_image_symbol(prompt) == input_ids
+        assert len(input_ids) == len(labels)
+
+        if conversations[-1]['from'] == "gpt" and generation_preface is None:
+            # remove the last `\n` following `im_end` in input_ids
+            input_ids.pop()
+            labels.pop()
+
+        return prompt, input_ids, labels
+
+    def format_query(self, query, generation_preface="", enable_thinking=False):
+        prompt, input_ids, _ = self.format([{
+            "from": "human",
+            "value": query
+        }], generation_preface=generation_preface, enable_thinking=enable_thinking)
+
+        return prompt, input_ids
+
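A short, hedged example of how these registered classes are expected to be resolved at load time, assuming the repository's config.json routes to them through its auto_map entries (the repo id below is a placeholder, not filled in):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("<path-or-repo-id>", trust_remote_code=True)
    print(type(config).__name__)                      # expected: OvisU1Config
    print(config.visual_tokenizer_config.model_type)  # expected: aimv2_visual_tokenizer
    print(config.visual_generator_config.model_type)  # expected: yak
    print(config.multimodal_max_length)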
    	
configuration_yak.py
ADDED
@@ -0,0 +1,63 @@
+from typing import Any
+from typing import Union, Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+__all__ = ["YakConfig"]
+
+
+class YakConfig(PretrainedConfig):
+    """This is the configuration class to store the configuration of an [`YakModel`].
+
+    Args:
+    """
+
+    model_type: str = "yak"
+
+    def __init__(
+        self,
+        in_channels: int = 16,
+        out_channels: int = 16,
+        vec_in_dim: int = 1536,
+        context_in_dim: int = 3072,
+        hidden_size: int = 1536,
+        mlp_ratio: int = 4,
+        num_heads: int = 12,
+        depth: int = 6,
+        depth_single_blocks: int = 12,
+        axes_dim: list = [16, 56, 56],
+        theta: int = 10_000,
+        qkv_bias: bool = True,
+        guidance_embed: bool = False,
+        checkpoint: bool = False,
+        txt_type: str = "refiner",
+        timestep_shift: bool = False,
+        base_shift: float = 0.5,
+        max_shift: float = 1.15,
+        vae_config: Optional[Union[PretrainedConfig, dict]] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.vec_in_dim = vec_in_dim
+        self.context_in_dim = context_in_dim
+        self.hidden_size = hidden_size
+        self.mlp_ratio = mlp_ratio
+        self.num_heads = num_heads
+        self.depth = depth
+        self.depth_single_blocks = depth_single_blocks
+        self.axes_dim = axes_dim
+        self.theta = theta
+        self.qkv_bias = qkv_bias
+        self.guidance_embed = guidance_embed
+        self.checkpoint = checkpoint
+        self.txt_type = txt_type
+        self.timestep_shift = timestep_shift
+        self.base_shift = base_shift
+        self.max_shift = max_shift
+
+        self.vae_config = vae_config
+
        merges.txt
    ADDED
    
    | 
         The diff for this file is too large to render. 
		See raw diff 
     | 
| 
         | 
    	
model-00001-of-00003.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:70d8644198c2531f6466b3aea7e34953fb9c5b67e04c71642f44a7b1e062b01b
+size 4061178424
    	
model-00002-of-00003.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c099dec97f53b4a18c79a206a6e2eed0893c4e8c91c701e995d85cdf53012d5f
+size 4973530020
    	
model-00003-of-00003.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:744d9cdf3f96e946dd761454e1743992ef79b7629e2a285e086678c186db3506
+size 1212709496
    	
model.safetensors.index.json
ADDED
The diff for this file is too large to render. See raw diff.
    	
modeling_aimv2.py
ADDED
@@ -0,0 +1,385 @@
# adapted from https://huggingface.co/apple/aimv2-huge-patch14-448 (modification: add gradient checkpoint support)
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
from transformers.modeling_utils import PreTrainedModel
from flash_attn.layers.rotary import apply_rotary_emb
from flash_attn import flash_attn_varlen_func

from .configuration_aimv2 import AIMv2Config


__all__ = ["AIMv2Model"]


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.eps}"

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class AIMv2SwiGLUFFN(nn.Module):
    def __init__(self, config: AIMv2Config):
        super().__init__()
        hidden_features = config.intermediate_size
        in_features = config.hidden_size
        bias = config.use_bias

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
        self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.fc1(x)) * self.fc3(x)
        x = self.fc2(x)
        return x


# copied from qwen2.5-vl
class VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs

# Note: in qwen2-vl and qwen2.5-vl, 3d convolution is used.
class AIMv2PatchEmbed(nn.Module):
    def __init__(self, config: AIMv2Config):
        super().__init__()
        self.config = config
        self.proj = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=(config.patch_size, config.patch_size),
            stride=(config.patch_size, config.patch_size),
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(-1, self.config.num_channels * self.config.temporal_patch_size, self.config.patch_size, self.config.patch_size)
        x = self.proj(x).view(-1, self.config.hidden_size)  # .flatten(2).transpose(1, 2)  # token_len x hidden_size
        x = self.norm(x)
        return x

class AIMv2ViTPreprocessor(nn.Module):
    def __init__(self, config: AIMv2Config):
        super().__init__()

        num_patches = (config.image_size // config.patch_size) ** 2

        self.patchifier = AIMv2PatchEmbed(config)

        self.preserve_original_pe = config.preserve_original_pe
        self.hidden_stride = config.hidden_stride

        if self.preserve_original_pe:
            self.interpolate_pe_method = config.interpolate_pe_method
            self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size)))

    def forward(self, x: torch.Tensor, grid_thws: Optional[torch.Tensor] = None) -> torch.Tensor:
        tokens = self.patchifier(x)

        if self.preserve_original_pe:
            assert grid_thws is not None
            pos_embed_new = torch.zeros_like(tokens)
            if self.interpolate_pe_method == 'one_dim':
                pos_embed = self.pos_embed.transpose(1, 2).to(tokens.device)
            elif self.interpolate_pe_method == 'two_dim':
                ori_h = ori_w = int(self.pos_embed.shape[1] ** 0.5)
                pos_embed = self.pos_embed.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2)
            else:
                raise TypeError("The interpolation method for pe should be one_dim, two_dim.")
            cnt = 0
            for t, h, w in grid_thws:
                num_patches = h * w
                thw = t * h * w
                if self.interpolate_pe_method == 'one_dim':
                    pe = F.interpolate(pos_embed, size=num_patches, mode='linear', align_corners=False).transpose(1, 2)
                elif self.interpolate_pe_method == 'two_dim':
                    # 1, 1024, 32, 32
                    pe = F.interpolate(pos_embed, size=(h, w), mode='bicubic', align_corners=False)
                    # 1, 1024, 1024
                    pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
                # 1024, 1024
                pe = pe[0].repeat(t, 1)
                # 1, 16, 2, 16, 2, 1024
                pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride, w // self.hidden_stride, self.hidden_stride, -1)
                # 1024, 1024
                pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(thw, -1)
                pos_embed_new[cnt:cnt + thw] = pe

                cnt += thw

            tokens = tokens + pos_embed_new
        return tokens

# copied from qwen2.5-vl
def apply_rotary_pos_emb_flashatt(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()
    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
    return q_embed, k_embed

class AIMv2FlashAttention2(nn.Module):
    def __init__(self, config: AIMv2Config) -> None:
        super().__init__()
        dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=config.use_bias)

        self.use_rope = not config.disable_rope

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:

        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        if self.use_rope:
            cos, sin = position_embeddings
            q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
            q = q.squeeze(0)
            k = k.squeeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
            seq_length, -1
        )
        attn_output = self.proj(attn_output)
        return attn_output

class AIMv2Block(nn.Module):
    def __init__(self, config: AIMv2Config):
        super().__init__()
        self.attn = AIMv2FlashAttention2(config)
        self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = AIMv2SwiGLUFFN(config)
        self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self, x: torch.Tensor, cu_seqlens: torch.Tensor, position_embeddings: torch.Tensor
    ) -> torch.Tensor:
        x = x + self.attn(self.norm_1(x), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
        x = x + self.mlp(self.norm_2(x))
        return x


class AIMv2Transformer(nn.Module):
    def __init__(self, config: AIMv2Config):
        super().__init__()
        self.blocks = nn.ModuleList(
            [AIMv2Block(config) for _ in range(config.num_hidden_layers)]
        )
        self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.rotary_pos_emb = VisionRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)

        self.hidden_stride = config.hidden_stride
        self.patch_size = config.patch_size
        self.window_size = config.window_size
        self.spatial_merge_unit = config.hidden_stride * config.hidden_stride

        self.fullatt_block_indexes = config.fullatt_block_indexes

    # copied from qwen2.5_vl
    def rot_pos_emb(self, grid_thw):
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.hidden_stride,
                self.hidden_stride,
                w // self.hidden_stride,
                self.hidden_stride,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.hidden_stride,
                self.hidden_stride,
                w // self.hidden_stride,
                self.hidden_stride,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = self.window_size // self.hidden_stride // self.patch_size  # patch (after merge) number in each window

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.hidden_stride,  # number of patches after merge
                grid_w // self.hidden_stride,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self,
        tokens: torch.Tensor,
        grid_thws: torch.Tensor,
        output_hidden_states: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
        # RoPE, modified from qwen2.5_vl
        rotary_pos_emb = self.rot_pos_emb(grid_thws)
        window_index, cu_window_seqlens = self.get_window_index(grid_thws)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=tokens.device,
            dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = tokens.size()
        tokens = tokens.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        tokens = tokens[window_index, :, :]
        tokens = tokens.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        reverse_indices = torch.argsort(window_index)

        hidden_states = () if output_hidden_states else None
        for index, block in enumerate(self.blocks):
            if self.fullatt_block_indexes is None or index in self.fullatt_block_indexes:
                cu_seqlens_tmp = cu_seqlens
            else:
                cu_seqlens_tmp = cu_window_seqlens
            if self.gradient_checkpointing and self.training:
                tokens = self._gradient_checkpointing_func(block.__call__, tokens, cu_seqlens_tmp, position_embeddings)
            else:
                tokens = block(tokens, cu_seqlens_tmp, position_embeddings)
            if output_hidden_states:
                tokens_ = tokens.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
                hidden_states += (tokens_[reverse_indices, :].reshape(seq_len, -1),)
        tokens = self.post_trunk_norm(tokens)
        tokens = tokens.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        tokens = tokens[reverse_indices, :].reshape(seq_len, -1)

        return tokens, hidden_states


class AIMv2PretrainedModel(PreTrainedModel):
    config_class = AIMv2Config
    base_model_prefix = "aimv2"
    supports_gradient_checkpointing = True
    main_input_name = "pixel_values"
    _no_split_modules = ["AIMv2ViTPreprocessor", "AIMv2Block"]
    _supports_sdpa = True


class AIMv2Model(AIMv2PretrainedModel):
    def __init__(self, config: AIMv2Config):
        super().__init__(config)
        self.preprocessor = AIMv2ViTPreprocessor(config)
        self.trunk = AIMv2Transformer(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thws: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[
        Tuple[torch.Tensor],
        Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
        BaseModelOutputWithNoAttention,
    ]:
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        x = self.preprocessor(pixel_values, grid_thws=grid_thws)

        x, hidden_states = self.trunk(
            x, grid_thws=grid_thws, output_hidden_states=output_hidden_states
        )

        if not return_dict:
            res = (x,)
            res += (hidden_states,) if output_hidden_states else ()
            return res

        return BaseModelOutputWithNoAttention(
            last_hidden_state=x,
            hidden_states=hidden_states,
        )
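As a quick orientation to the inputs the vision tower above expects: pixel_values is a flat stack of patches of shape (total_patches, num_channels * temporal_patch_size * patch_size**2), and grid_thws holds one (t, h, w) patch grid per image, from which AIMv2Transformer.forward derives the cumulative sequence lengths passed to flash-attention. A minimal sketch of that bookkeeping; the concrete sizes below are illustrative assumptions, not values read from this checkpoint's config:

import torch
import torch.nn.functional as F

patch_size, temporal_patch_size, num_channels = 14, 1, 3  # illustrative values
grid_thws = torch.tensor([[1, 32, 32], [1, 16, 24]])      # (t, h, w) patch grids for two images

total_patches = int((grid_thws[:, 0] * grid_thws[:, 1] * grid_thws[:, 2]).sum())
pixel_values = torch.randn(total_patches, num_channels * temporal_patch_size * patch_size * patch_size)
print(pixel_values.shape)  # torch.Size([1408, 588])

# Same computation as in AIMv2Transformer.forward: per-frame patch counts, cumulated and zero-padded.
cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)  # tensor([   0, 1024, 1408], dtype=torch.int32)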
    	
        modeling_ovis_u1.py
    ADDED
    @@ -0,0 +1,921 @@
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
import logging
import math
from datetime import datetime
from importlib import import_module
from typing import List, Union, Optional, Dict

import numpy as np
import PIL.Image
import torch
from torch import Tensor
from torch.nn import init
from torch.nn.functional import softmax, gumbel_softmax, pad
from torchvision import transforms
import transformers
from transformers import AutoImageProcessor
from transformers import PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM
from transformers.generation.utils import GenerateOutput
from transformers import CLIPImageProcessor

from .modeling_aimv2 import AIMv2Model
from .configuration_ovis_u1 import BaseVisualTokenizerConfig, Aimv2VisualTokenizerConfig
from .configuration_ovis_u1 import OvisU1Config, ConversationFormatter
from .configuration_ovis_u1 import IGNORE_ID, IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS, IMAGE_TOKEN_ID, VIDEO_TOKEN_ID

# ----------------------------------------------------------------------
#                            Visual Tokenizer
# ----------------------------------------------------------------------
class BaseVisualTokenizer(PreTrainedModel):
    base_model_prefix = "backbone"
    main_input_name = None
    _image_processor_class = None
    _image_processor_kwargs = {}
    _backbone_class = None

    def __init__(self, config: BaseVisualTokenizerConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        if kwargs.get('train_from_scratch'):
            # for key in self._image_processor_kwargs.keys():
            #     self._image_processor_kwargs[key] = getattr(self.config, key, self._image_processor_kwargs[key])
            image_processor = self._image_processor_class.from_pretrained(kwargs['backbone_name_or_path'],
                                                                          **self._image_processor_kwargs)

            self.backbone = self._backbone_class.from_pretrained(kwargs['backbone_name_or_path'], **self.config.backbone_kwargs)
            self.config.backbone_config = self.backbone.config

            config = image_processor.to_dict()
            if getattr(self.config, 'image_processor_new_kwargs', None) is not None:
                for key in self.config.image_processor_new_kwargs.keys():
                    config[key] = self.config.image_processor_new_kwargs[key]
            if 'patch_size' not in config:
                assert getattr(self.backbone.config, 'patch_size'), "Patch size must be set."
                config['patch_size'] = self.backbone.config.patch_size
            self.image_processor = self._image_processor_class.from_dict(config)

        else:
            self.image_processor = AutoImageProcessor.from_pretrained(kwargs['image_processor_name_or_path'])
            self.backbone = AutoModel.from_config(self.config.backbone_config)
        head_dim = self.config.vocab_size - len(IMAGE_INDICATOR_IDS)  # reserved tokens for IMAGE_INDICATORS
        self.head = torch.nn.Sequential(
            torch.nn.Linear(
                self.backbone.config.hidden_size * self.config.hidden_stride * self.config.hidden_stride, head_dim,
                bias=False
            ),
            torch.nn.LayerNorm(head_dim)
        )
        assert all((self.image_processor.do_resize,
                    not getattr(self.image_processor, 'do_center_crop', False),
                    self.image_processor.do_rescale,
                    self.image_processor.do_normalize
                    )), f"image_processor `{self.image_processor}` is not supported currently"

    def get_backbone(self):
        return self.backbone

    def get_monitor_tensors(self):
        raise NotImplementedError

    def get_image_processor(self):
        return self.image_processor

    def mock_input(self):
        height, width = self.get_image_size()
        return torch.zeros(1, 3, height, width), self.construct_image_placeholders((1, 1))

    def get_head(self):
        return self.head

    def get_image_size(self):
        raise NotImplementedError

    @staticmethod
    def construct_image_placeholders(grid, data_type='image'):
        if data_type == 'image':
            image_placeholders = [IMAGE_INDICATOR_IDS[0], IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS[1]]
        elif data_type == 'video':
            image_placeholders = [IMAGE_INDICATOR_IDS[2], IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS[2]]
        else:
            raise TypeError

        return image_placeholders

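    # Note (added commentary): construct_image_placeholders() brackets a single IMAGE_ATOM_ID with
    # indicator ids: IMAGE_INDICATOR_IDS[0] / IMAGE_INDICATOR_IDS[1] for still images, and
    # IMAGE_INDICATOR_IDS[2] on both sides for video. These are sentinel ids defined in
    # configuration_ovis_u1 that mark where visual content sits in a tokenized prompt.
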
    @staticmethod
    def _partition(img_size, grid):
        w, h = img_size
        row_height = h // grid[0]
        col_width = w // grid[1]

        partition = []
        for row in range(grid[0]):
            for col in range(grid[1]):
                left = col * col_width
                upper = row * row_height
                right = w if col == grid[1] - 1 else (col + 1) * col_width
                lower = h if row == grid[0] - 1 else (row + 1) * row_height
                partition.append((left, upper, right, lower))

        return partition

    @staticmethod
    def get_best_grid(img_size, side, max_partition, covering_threshold):

        def _covering_area(left, upper, right, lower, side):
            w = right - left
            h = lower - upper
            w, h = max(w, h), min(w, h)
            if w > side:
                h = h / w * side
                w = side
            return w * h

        img_area = img_size[0] * img_size[1]

        candidate_grids = []
        for i in range(1, max_partition + 1):
            for j in range(1, max_partition + 1):
                if i * j <= max_partition:
                    candidate_grids.append((i, j))

        all_grids = []
        good_grids = []
        for grid in candidate_grids:
            partition = BaseVisualTokenizer._partition(img_size, grid)
            covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area
            assert covering_ratio <= 1.0
            all_grids.append((grid, covering_ratio))
            if covering_ratio > covering_threshold:
                good_grids.append((grid, covering_ratio))

        if len(good_grids) > 0:
            # pick the good partition with minimum #sub_images and break the tie using covering_ratio
            return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0]
        else:
            # pick the partition with maximum covering_ratio and break the tie using #sub_images
            return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]

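    # Note (added commentary): get_best_grid() enumerates every grid (i, j) with i * j <= max_partition
    # and scores it by how much of the original image area the sub-crops still cover once each crop is
    # scaled down to `side`. For example, when the image's longer edge already fits within `side`, the
    # (1, 1) grid covers the image fully (ratio 1.0) and is selected, because grids above the threshold
    # are ranked by fewest sub-images first.
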
    def preprocess_image(self, image: PIL.Image.Image, max_partition=4, covering_threshold=0.9, convert_to_rgb=True):
        def _preprocess(img: PIL.Image.Image, side):
            # first resize and preprocess
            w, h = img.size
            if w == h:
                new_width = new_height = side
            elif w > h:
                new_width = side
                new_height = int(h / w * new_width)
            else:
                new_height = side
                new_width = int(w / h * new_height)
            new_size = dict(height=new_height, width=new_width)
            pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors='pt')['pixel_values']

            # then pad to square
            square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
            new_height, new_width = pixel_values.shape[2:]
            if new_height == new_width:
                square_values[:, :, :, :] = pixel_values
            elif new_height > new_width:
                from_index = (side - new_width) // 2
                square_values[:, :, :, from_index:from_index + new_width] = pixel_values
            else:
                from_index = (side - new_height) // 2
                square_values[:, :, from_index:from_index + new_height, :] = pixel_values

            return square_values

        if convert_to_rgb and image.mode != 'RGB':
            image = image.convert('RGB')

        sides = self.get_image_size()
        if sides[0] != sides[1]:
            raise ValueError('get_image_size() returns non-square size')
        side = sides[0]
        grid = self.get_best_grid(image.size, side, max_partition, covering_threshold)
        partition = self._partition(image.size, grid)
        crops = [image.crop(p) for p in partition]
        if len(crops) > 1:
            crops.insert(0, image)
        pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
        image_placeholders = self.construct_image_placeholders(grid)
        return pixel_values, image_placeholders

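    # Note (added commentary): preprocess_image() returns a [N, 3, side, side] tensor, where N equals the
    # number of grid cells, plus one extra slot holding the full (downscaled) image whenever the grid has
    # more than one cell, together with the placeholder ids from construct_image_placeholders(grid).
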
    def get_backbone_layer(self, index):
        if 'aimv2' in self.config.model_type:
            return self.backbone.trunk.blocks[index]
        else:
            return self.backbone.vision_model.encoder.layers[index]

    def tokenize(self, logits):
        def st_argmax(y_soft, dim):  # straight-through softmax
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
            return ret

        if self.config.tokenize_function == 'softmax':
            tokens = softmax(logits, dim=-1, dtype=torch.float32).to(logits.dtype)
        elif self.config.tokenize_function == 'gumbel_argmax':
            tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
        elif self.config.tokenize_function == 'st_argmax':
            tokens = st_argmax(logits, dim=-1)
        else:
            raise ValueError(
                f'Invalid `tokenize_function`, expected softmax or gumbel_argmax or st_argmax, but got {self.config.tokenize_function}')
        return tokens

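    # Note (added commentary): tokenize() maps per-patch logits to a distribution over the visual
    # vocabulary. 'softmax' keeps soft probabilities, 'gumbel_argmax' draws a hard one-hot sample via
    # Gumbel-Softmax (hard=True), and 'st_argmax' takes a hard argmax with a straight-through gradient.
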
    def encode(self, pixel_values):
        output = self.backbone(pixel_values, output_hidden_states=True, return_dict=True)
        features = output.hidden_states[-1]
        if self.config.drop_cls_token:
            features = features[:, 1:, :]

        # merge number of `hidden_stride * hidden_stride` hidden states together to reduce token sequence length
        # e.g., for hidden_stride=3, this leads to a token length reduction: 729 -> 81 for siglip
        if self.config.hidden_stride > 1:
            n, l, d = features.shape  # this `d` may differ from the `d` above
            sqrt_l = int(l ** 0.5)
            assert sqrt_l ** 2 == l, "The token sequence length should be a perfect square."
            features = features.reshape(n, sqrt_l, sqrt_l, d)
            pl = (self.config.hidden_stride - (sqrt_l % self.config.hidden_stride)) % self.config.hidden_stride
            features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
            sqrt_l += pl
            features = features.reshape(n, sqrt_l // self.config.hidden_stride, self.config.hidden_stride,
                                        sqrt_l // self.config.hidden_stride, self.config.hidden_stride, d)
            features = features.permute(0, 1, 3, 2, 4, 5)  # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
            features = features.flatten(3)  # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
            features = features.reshape(
                n, -1, self.config.hidden_stride * self.config.hidden_stride * d)

        return features

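    # Note (added commentary): the hidden_stride merge above groups each hidden_stride x hidden_stride
    # neighborhood of patch features into one feature vector; e.g. with hidden_stride=2 a 1024-token
    # (32 x 32) feature map shrinks to 256 tokens, each carrying 4x the original feature dimension.
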
    def forward(self, pixel_values) -> torch.Tensor:  # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
        features = self.encode(pixel_values)
        logits = self.head(features)
        tokens = self.tokenize(logits)
        # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with [BatchSize, #Token, 5], after
        # which, tokens' shape should become [BatchSize, #Token, VocabSize]
        batch_size, token_len, _ = tokens.shape
        padding_tensor = torch.zeros(size=(batch_size, token_len, len(IMAGE_INDICATOR_IDS)),
                                     dtype=tokens.dtype,
                                     device=tokens.device,
                                     layout=tokens.layout,
                                     requires_grad=False)
        tokens = torch.cat((tokens, padding_tensor), dim=2)
        return tokens

class Aimv2VisualTokenizer(BaseVisualTokenizer):
    config_class = Aimv2VisualTokenizerConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["AIMv2ViTPreprocessor", "AIMv2Block"]
    _image_processor_class = CLIPImageProcessor
    _image_processor_kwargs = dict(do_center_crop=False, crop_size={'height': -1, 'width': -1}, size={'shortest_edge': -1})
    _backbone_class = AIMv2Model

    # Copied from qwen2_vl
    def smart_resize(self,
        height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
    ):
        """Rescales the image so that the following conditions are met:

        1. Both dimensions (height and width) are divisible by 'factor'.

        2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

        3. The aspect ratio of the image is maintained as closely as possible.

        """

        if height < factor or width < factor:
            print(f"height:{height} or width:{width} must be larger than factor:{factor}")
            if height < width:
                width = round(factor / height * width)
                height = factor
            else:
                height = round(factor / width * height)
                width = factor

        elif max(height, width) / min(height, width) > 200:
            print(
                f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
            )
            if height > width:
                height = 200 * width
            else:
                width = 200 * height

        h_bar = round(height / factor) * factor
        w_bar = round(width / factor) * factor
        if h_bar * w_bar > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = math.floor(height / beta / factor) * factor
            w_bar = math.floor(width / beta / factor) * factor
        elif h_bar * w_bar < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor
        return h_bar, w_bar

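    # Illustrative example (added, not part of the original file): with factor=28 (patch_size 14 *
    # hidden_stride 2), min_pixels=200704 and max_pixels=3211264 (the processor values noted in
    # preprocess_image below), smart_resize(1080, 1920) rounds both edges to multiples of 28:
    # 1080 -> 1092 and 1920 -> 1932; 1092 * 1932 = 2,109,744 pixels already lies inside the allowed
    # range, so no further rescaling is applied.
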
    def get_monitor_tensors(self):
        return dict(
            backbone_bottom=self.backbone.trunk.blocks[0].attn.qkv.weight,
            backbone_top=self.backbone.trunk.blocks[-1].attn.qkv.weight,
            head=self.head[0].weight
        )

    def get_min_image_size(self):
        min_pixels = self.image_processor.min_pixels
        max_pixels = self.image_processor.max_pixels
        height = int(min_pixels ** 0.5)
        width = int(min_pixels ** 0.5)
        patch_size = self.image_processor.patch_size
        hidden_stride = self.image_processor.hidden_stride
        height, width = self.smart_resize(height, width, patch_size * hidden_stride, min_pixels, max_pixels)
        return height, width

    def get_image_size(self):
        min_pixels = self.image_processor.min_pixels
        max_pixels = self.image_processor.max_pixels
        num_pixels = (min_pixels + max_pixels) / 2
        height = int(num_pixels ** 0.5)
        width = int(num_pixels ** 0.5)
        patch_size = self.image_processor.patch_size
        hidden_stride = self.image_processor.hidden_stride
        height, width = self.smart_resize(height, width, patch_size * hidden_stride, min_pixels, max_pixels)
        return height, width

    def get_token_length(self, width: int,
                         height: int,
                         n_frames: int = 1,
                         num_images: int = 1):
        patch_size = self.image_processor.patch_size
        temporal_patch_size = self.image_processor.temporal_patch_size
        hidden_stride = self.image_processor.hidden_stride
        min_pixels = self.image_processor.min_pixels
        max_pixels = self.image_processor.max_pixels

        max_pixels = max_pixels // num_images
        min_pixels = min(max_pixels, min_pixels)

        resized_height, resized_width = height, width
        resized_height, resized_width = self.smart_resize(
            height,
            width,
            factor=patch_size * hidden_stride,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )

        if n_frames % temporal_patch_size != 0:
            n_frames = n_frames + temporal_patch_size - 1
        grid_t = n_frames // temporal_patch_size
        grid_h, grid_w = resized_height // patch_size // hidden_stride, resized_width // patch_size // hidden_stride

        return grid_t * grid_w * grid_h

    def mock_input(self):
        height, width = self.get_min_image_size()
        return torch.zeros(1, 3, height, width), self.construct_image_placeholders((1, 1))

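    # Illustrative example (added, not part of the original file): assuming the processor values noted
    # in preprocess_image below (patch_size=14, temporal_patch_size=1, hidden_stride=2, min_pixels=200704,
    # max_pixels=3211264), get_token_length(width=1920, height=1080) resizes to 1092 x 1932, giving
    # grid_h = 1092 // 14 // 2 = 39 and grid_w = 1932 // 14 // 2 = 69, so a single frame costs
    # 1 * 39 * 69 = 2691 visual tokens.
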
    def preprocess_image(self, images: Union[PIL.Image.Image, List[PIL.Image.Image]],
                         convert_to_rgb: Optional[bool] = True,
                         num_images: Optional[int] = 1,
                         min_pixels: Optional[int] = None,
                         max_pixels: Optional[int] = None,
                         multimodal_type: Optional[str] = 'single_image'):

        patch_size = self.image_processor.patch_size  # 14
        temporal_patch_size = self.image_processor.temporal_patch_size  # 1
        hidden_stride = self.image_processor.hidden_stride  # 2
        min_pixels = min_pixels or self.image_processor.min_pixels  # 200704
        max_pixels = max_pixels or self.image_processor.max_pixels  # 3211264

        max_pixels = max_pixels // num_images
        min_pixels = min(max_pixels, min_pixels)

        if not isinstance(images, list):
            images = [images]
        if multimodal_type == 'video':
            assert len(images) >= 1
        else:
            assert len(images) == 1
        images = [image.convert("RGB") if convert_to_rgb and image.mode != 'RGB' else image for image in images]
        # images = [np.array(image) for image in images]

        width, height = images[0].size
        resized_height, resized_width = height, width
        processed_images = []
        for image in images:
            resized_height, resized_width = self.smart_resize(
                height,
                width,
                factor=patch_size * hidden_stride,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            new_size = dict(height=resized_height, width=resized_width)
            image_pt = self.image_processor.preprocess(image, size=new_size, return_tensors="np")['pixel_values'][0]

            processed_images.append(image_pt)

        patches = np.array(processed_images)
        # if data_format == ChannelDimension.LAST:
        #     patches = patches.transpose(0, 3, 1, 2)
        if patches.shape[0] % temporal_patch_size != 0:
            repeats = np.repeat(patches[-1][np.newaxis], temporal_patch_size - 1, axis=0)
            patches = np.concatenate([patches, repeats], axis=0)
        channel = patches.shape[1]
        grid_t = patches.shape[0] // temporal_patch_size  # 1
        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size  # 32, 32

        patches = patches.reshape(
            grid_t,
            temporal_patch_size,
            channel,
            grid_h // hidden_stride,
            hidden_stride,
            patch_size,
            grid_w // hidden_stride,
            hidden_stride,
            patch_size,
        )
        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
        )
        # 1024, 588

        image_placeholders = self.construct_image_placeholders((1, 1), data_type='video' if multimodal_type == 'video' else 'image')  # [-301, -300, -302, -305]

        # print(flatten_patches.shape, len(images))
        return torch.tensor(flatten_patches), torch.tensor([[grid_t, grid_h, grid_w]]), image_placeholders

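    # Note (added commentary): flatten_patches has shape
    # [grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size]; the inline
    # "1024, 588" corresponds to a 448 x 448 resize (grid_h = grid_w = 448 // 14 = 32, and
    # 3 * 1 * 14 * 14 = 588). The second return value, torch.tensor([[grid_t, grid_h, grid_w]]), is what
    # encode() later passes to the backbone as grid_thws.
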
    def encode(self, pixel_values, grid_thws):
        output = self.backbone(pixel_values, grid_thws, output_hidden_states=True, return_dict=True)
        features = output.hidden_states[-1]
        # default: false
        # if self.config.drop_cls_token:
        #     features = features[:, 1:, :]

        # refer to qwen2.5-vl patchmerger
        seq_len, _ = features.shape
        features = features.reshape(seq_len // (self.config.hidden_stride ** 2), -1)

        return features

    def forward(self, pixel_values, grid_thws) -> torch.Tensor:  # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
        features = self.encode(pixel_values, grid_thws)
        logits = self.head(features)
        tokens = self.tokenize(logits)
        # tokens' shape is [#Token, VocabSize-5], so padding with [#Token, 5], after
        # which, tokens' shape should become [#Token, VocabSize];
        # this is different from original aimv2 which has [BatchSize, #Token, VocabSize-5]
        token_len, _ = tokens.shape
        padding_tensor = torch.zeros(size=(token_len, len(IMAGE_INDICATOR_IDS)),
                                     dtype=tokens.dtype,
                                     device=tokens.device,
                                     layout=tokens.layout,
                                     requires_grad=False)
        tokens = torch.cat((tokens, padding_tensor), dim=1)
        return tokens

AutoModel.register(Aimv2VisualTokenizerConfig, Aimv2VisualTokenizer)


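# Note (added commentary): unlike BaseVisualTokenizer.forward, the AIMv2 variant above operates on a
# flattened token sequence with no batch dimension (pixel_values plus grid_thws), so its padding is
# concatenated along dim=1. Registering the class with AutoModel is what allows OvisU1.__init__ below to
# build the tokenizer via AutoModel.from_config(self.config.visual_tokenizer_config).
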
# ----------------------------------------------------------------------
#                           Visual Generator
# ----------------------------------------------------------------------
from .configuration_yak import YakConfig
from .modeling_yak import YakModel
AutoConfig.register("yak", YakConfig)
AutoModel.register(YakConfig, YakModel)


| 496 | 
         
            +
# ----------------------------------------------------------------------
#                               OvisU1
# ----------------------------------------------------------------------
class VisualEmbedding(torch.nn.Embedding):
    def forward(self, visual_tokens: Tensor) -> Tensor:
        if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
            return super().forward(visual_tokens)
        return torch.matmul(visual_tokens, self.weight)

    def reset_parameters(self, mean=0., std=1.) -> None:
        init.normal_(self.weight, mean=mean, std=std)
        self._fill_padding_idx_with_zero()

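A standalone sketch of the two lookup paths in `VisualEmbedding` above: integer token ids go through the ordinary embedding lookup, while soft token distributions (float probabilities over the visual vocabulary) are mixed via a matmul with the embedding weight. The sizes below are illustrative only:

import torch

vocab_size, hidden_size = 16, 8
emb = torch.nn.Embedding(vocab_size, hidden_size)

# Hard path: integer ids -> standard embedding lookup.
hard_ids = torch.tensor([3, 7, 15])
hard_embeds = emb(hard_ids)                     # [3, hidden_size]

# Soft path: probability rows -> weighted sum of embedding rows.
soft_probs = torch.softmax(torch.randn(3, vocab_size), dim=-1)
soft_embeds = soft_probs @ emb.weight           # [3, hidden_size]

print(hard_embeds.shape, soft_embeds.shape)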
class OvisU1PreTrainedModel(PreTrainedModel):
    config_class = OvisU1Config
    base_model_prefix = "ovis_u1"


class OvisU1(OvisU1PreTrainedModel):

    def __init__(self, config: OvisU1Config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        attn_kwargs = dict()
        if self.config.llm_attn_implementation:
            attn_kwargs['attn_implementation'] = self.config.llm_attn_implementation
        self.llm = AutoModelForCausalLM.from_config(self.config.llm_config, **attn_kwargs)
        assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
        self.text_tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
        self.visual_tokenizer = AutoModel.from_config(self.config.visual_tokenizer_config,
                                                      image_processor_name_or_path=self.config.name_or_path)
        self.visual_generator = AutoModel.from_config(self.config.visual_generator_config)
        self.vte = VisualEmbedding(self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size,
                                   device=self.visual_tokenizer.device, dtype=self.visual_tokenizer.dtype)

        def _merge_modules(modules_list: tuple):
            merged_modules = []
            for modules in modules_list:
                merged_modules.extend(modules if modules else [])
            return merged_modules

        self._no_split_modules = _merge_modules((self.llm._no_split_modules, self.visual_tokenizer._no_split_modules))
        self._skip_keys_device_placement = self.llm._skip_keys_device_placement
        self._keep_in_fp32_modules = _merge_modules(
            (self.llm._keep_in_fp32_modules, self.visual_tokenizer._keep_in_fp32_modules))
        self._supports_flash_attn_2 = True
        self.is_parallelizable = all((self.llm.is_parallelizable, self.visual_tokenizer.is_parallelizable, self.visual_generator.is_parallelizable))
        self.supports_gradient_checkpointing = all(
            (self.llm.supports_gradient_checkpointing, self.visual_tokenizer.supports_gradient_checkpointing, self.visual_generator.supports_gradient_checkpointing))
        self._supports_sdpa = all((self.llm._supports_sdpa, self.visual_tokenizer._supports_sdpa, self.visual_generator._supports_sdpa))

    def get_text_tokenizer(self):
        return self.text_tokenizer

    def get_visual_tokenizer(self):
        return self.visual_tokenizer

    def get_visual_generator(self):
        return self.visual_generator

    def tie_weights(self):
        if not self.config.disable_tie_weight:
            self.get_llm().tie_weights()

    def get_lm_head(self):
        return self.get_llm().get_output_embeddings()

    def get_llm(self):
        return self.llm

    def get_vte(self):
        return self.vte

    def get_wte(self):
        return self.llm.get_input_embeddings()

    def get_conversation_formatter(self) -> ConversationFormatter:
        if getattr(self, 'conversation_formatter', None) is None:
            self.conversation_formatter = getattr(import_module(".configuration_ovis_u1", __package__),
                                                  self.config.conversation_formatter_class)(self.text_tokenizer)
        return self.conversation_formatter

    def merge_multimodal(
            self,
            text_input_ids: torch.Tensor,
            text_attention_masks: torch.Tensor,
            text_labels: Optional[torch.Tensor],
            pixel_values: Optional[torch.Tensor],
            grid_thws: Optional[torch.Tensor],
            left_padding: bool = False
    ):
        input_device = text_input_ids.device
        visual_vocab_size = self.get_visual_tokenizer().config.vocab_size
        visual_indicator_embeds = self.get_vte()(
            torch.tensor(
                list(range(visual_vocab_size - 5, visual_vocab_size)),
                dtype=torch.long,
                device=self.get_visual_tokenizer().device
            )
        ).to(device=input_device)

        if self.training:
            # When training, to be compatible with deepspeed zero, each sample has to include a pixel_values tensor.
            # For a text-only sample, one can simply use a full-zero tensor as pixel_values; it will be ignored
            # (see below in this function), so the gradient is not affected.
            num_images = [x.prod() // (self.visual_tokenizer.config.hidden_stride**2) for x in grid_thws]

            visual_tokens = self.visual_tokenizer(pixel_values, grid_thws)

            visual_embeds_ = torch.split(self.get_vte()(visual_tokens).to(dtype=self.dtype, device=input_device),
                                         split_size_or_sections=num_images, dim=0)

            visual_input_ids_ = torch.split(torch.argmax(visual_tokens, dim=-1).to(device=input_device),
                                            split_size_or_sections=num_images, dim=0)

            visual_labels_ = [torch.full(x.shape, IGNORE_ID, dtype=torch.long, device=input_device)
                              for x in visual_input_ids_]

            visual_embeds = []
            visual_input_ids = []
            visual_labels = []
            ind = 0
            for text_input_id in text_input_ids:
                image_atom_positions = torch.where(torch.eq(text_input_id, IMAGE_ATOM_ID))[0].tolist()
                n = len(image_atom_positions)
                if n > 0:
                    visual_embeds.append(visual_embeds_[ind:ind + n])
                    visual_input_ids.append(visual_input_ids_[ind:ind + n])
                    visual_labels.append(visual_labels_[ind:ind + n])
                    ind += n
                else:
                    visual_embeds.append(visual_embeds_[ind:ind + 1])
                    visual_input_ids.append(visual_input_ids_[ind:ind + 1])
                    visual_labels.append(visual_labels_[ind:ind + 1])
                    ind += 1

        else:
            # TODO: Not modified yet
            # At inference time, a sample can be text-only, with `None` pixel_values.
            # num_images = [x.shape[0] if x is not None else 0 for x in pixel_values]
            num_images = [x.prod() // (self.visual_tokenizer.config.hidden_stride**2) if x is not None else 0 for x in grid_thws]
            if sum(num_images) > 0:
                visual_tokens = self.visual_tokenizer(pixel_values, grid_thws)
                try:
                    visual_embeds_ = torch.split(self.get_vte()(visual_tokens).to(dtype=self.dtype, device=input_device),
                                                 split_size_or_sections=num_images, dim=0)
                except Exception as e:
                    print(e)
                    print(pixel_values.shape, grid_thws.shape, visual_tokens.shape, num_images)

                visual_input_ids_ = torch.split(torch.argmax(visual_tokens, dim=-1).to(device=input_device),
                                                split_size_or_sections=num_images, dim=0)

                visual_labels_ = [torch.full(x.shape, IGNORE_ID, dtype=torch.long, device=input_device)
                                  for x in visual_input_ids_]

                visual_embeds = []
                visual_input_ids = []
                visual_labels = []
                ind = 0
                for text_input_id in text_input_ids:
                    image_atom_positions = torch.where(torch.eq(text_input_id, IMAGE_ATOM_ID))[0].tolist()
                    n = len(image_atom_positions)
                    if n > 0:
                        visual_embeds.append(visual_embeds_[ind:ind + n])
                        visual_input_ids.append(visual_input_ids_[ind:ind + n])
                        visual_labels.append(visual_labels_[ind:ind + n])
                        ind += n
                    else:
                        visual_embeds.append(visual_embeds_[ind:ind + 1])
                        visual_input_ids.append(visual_input_ids_[ind:ind + 1])
                        visual_labels.append(visual_labels_[ind:ind + 1])
                        ind += 1

            else:
                # just placeholders
                visual_embeds = [None] * len(num_images)
                visual_input_ids = [None] * len(num_images)
                visual_labels = [None] * len(num_images)

        # just placeholders
        if text_labels is None:
            text_labels = torch.full(text_input_ids.shape, IGNORE_ID, dtype=torch.long, device=input_device)

        input_embeds = []
        attention_masks = []
        labels = []
        input_img_poss = []
        for text_input_id, text_label, text_attention_mask, visual_embed, visual_input_id, visual_label in zip(
            text_input_ids, text_labels, text_attention_masks, visual_embeds, visual_input_ids, visual_labels
        ):
            placeholder_token_mask = torch.lt(text_input_id, 0)
            text_embed = self.get_wte()(torch.masked_fill(text_input_id, placeholder_token_mask, 0))
            for i, indicator_id in enumerate(IMAGE_INDICATOR_IDS):
                text_embed[text_input_id == indicator_id] = visual_indicator_embeds[i]
            image_atom_positions = torch.where(torch.eq(text_input_id, IMAGE_ATOM_ID))[0].tolist()
            if len(image_atom_positions) > 0:
                input_embed_parts = []
                attention_mask_parts = []
                label_parts = []
                input_img_pos_parts = []
                prev_image_atom_position = -1
                for index, image_atom_position in enumerate(image_atom_positions):
                    input_embed_parts.append(
                        text_embed[prev_image_atom_position + 1:image_atom_position, :])
                    label_parts.append(
                        text_label[prev_image_atom_position + 1:image_atom_position])
                    input_img_pos_parts.append(
                        torch.zeros_like(text_label[prev_image_atom_position + 1:image_atom_position])
                    )
                    attention_mask_parts.append(
                        text_attention_mask[prev_image_atom_position + 1:image_atom_position])
                    input_embed_parts.append(visual_embed[index])
                    attention_mask_parts.append(
                        torch.ones_like(visual_label[index], dtype=torch.bool))
                    label_parts.append(visual_label[index])
                    input_img_pos_parts.append(
                        torch.ones_like(visual_label[index])
                    )
                    prev_image_atom_position = image_atom_position
                if prev_image_atom_position + 1 < text_input_id.shape[0]:
                    input_embed_parts.append(
                        text_embed[prev_image_atom_position + 1:, :])
                    attention_mask_parts.append(
                        text_attention_mask[prev_image_atom_position + 1:])
                    label_parts.append(
                        text_label[prev_image_atom_position + 1:])
                    input_img_pos_parts.append(
                        torch.zeros_like(text_label[prev_image_atom_position + 1:])
                    )
                input_embed = torch.cat(input_embed_parts, dim=0)
                attention_mask = torch.cat(attention_mask_parts, dim=0)
                label = torch.cat(label_parts, dim=0)
                input_img_pos = torch.cat(input_img_pos_parts, dim=0)
            else:
                input_embed = text_embed
                attention_mask = text_attention_mask
                label = text_label
                input_img_pos = torch.zeros_like(text_label)
                if self.training:
                    # Make visual_embed & visual_indicator_embeds involved in the backward graph,
                    # to be compatible with deepspeed zero and ddp.
                    input_embed += torch.sum(visual_embed[0] * 0.0) + torch.sum(visual_indicator_embeds * 0.0)
            input_embeds.append(input_embed)
            attention_masks.append(attention_mask)
            labels.append(label)
            input_img_poss.append(input_img_pos)

        batch_input_embeds = self.pad_truncate_sequence(input_embeds, batch_first=True, padding_value=0.0, left_padding=left_padding)
        batch_attention_mask = self.pad_truncate_sequence(attention_masks, batch_first=True, padding_value=False, left_padding=left_padding)
        batch_labels = self.pad_truncate_sequence(labels, batch_first=True, padding_value=IGNORE_ID, left_padding=left_padding)
        batch_input_img_labels = self.pad_truncate_sequence(input_img_poss, batch_first=True, padding_value=0.0, left_padding=left_padding)

        return visual_input_ids, batch_input_embeds, batch_labels, batch_attention_mask, batch_input_img_labels

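The core of `merge_multimodal` above is the splice at each `IMAGE_ATOM_ID` position: the text embeddings before the atom, then that image's visual embeddings, then the remaining text. A toy sketch of just that splice, with made-up shapes (the negative placeholder id below is an assumption mirroring this file's constants):

import torch

IMAGE_ATOM_ID = -300          # assumed negative placeholder id, for illustration only
hidden = 4

text_ids = torch.tensor([1, 2, IMAGE_ATOM_ID, 3, IMAGE_ATOM_ID, 4])
text_embed = torch.randn(text_ids.shape[0], hidden)
# one visual embedding block per atom (lengths may differ per image)
visual_embeds = [torch.randn(5, hidden), torch.randn(3, hidden)]

parts, prev = [], -1
atom_positions = torch.where(text_ids == IMAGE_ATOM_ID)[0].tolist()
for idx, pos in enumerate(atom_positions):
    parts.append(text_embed[prev + 1:pos])   # text before this image
    parts.append(visual_embeds[idx])         # this image's token embeddings
    prev = pos
if prev + 1 < text_ids.shape[0]:
    parts.append(text_embed[prev + 1:])      # trailing text

merged = torch.cat(parts, dim=0)
print(merged.shape)   # (6 - 2) text rows + 5 + 3 visual rows = torch.Size([12, 4])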
    def pad_truncate_sequence(self, sequences: List[torch.Tensor], batch_first: bool = True, padding_value: float = 0.0, left_padding: bool = False) -> torch.Tensor:
        if not left_padding:
            pad_sequence = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=batch_first, padding_value=padding_value)
            return pad_sequence[:, :self.config.multimodal_max_length]
        else:
            pad_sequence = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[0]) for i in sequences],
                                                           batch_first=True, padding_value=padding_value).flip(dims=[1])
            return pad_sequence[:, -self.config.multimodal_max_length:]

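`pad_truncate_sequence` above implements left padding by flipping each sequence, right-padding as usual, and flipping back along the time dimension. A minimal standalone check of that trick (values are arbitrary):

import torch

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

right = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)
left = torch.nn.utils.rnn.pad_sequence(
    [s.flip(dims=[0]) for s in seqs], batch_first=True, padding_value=0
).flip(dims=[1])

print(right)   # tensor([[1, 2, 3], [4, 5, 0]])
print(left)    # tensor([[1, 2, 3], [0, 4, 5]])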
    def preprocess_inputs(
        self,
        text_or_conversations: Union[List[Dict], str],
        images: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]],
        generation_preface='',
        return_labels=False,
        propagate_exception=True,
        frame_selector=None,
        multimodal_type="single_image",
        fix_sample_overall_length_navit=False,
        min_pixels=None,
        max_pixels=None,
        enable_thinking=False
    ):
        # convert text to conversations
        if isinstance(text_or_conversations, str):
            conversations = [{
                "from": "human",
                "value": text_or_conversations
            }]
        elif isinstance(text_or_conversations, list):
            conversations = text_or_conversations
        else:
            raise ValueError(f'[{datetime.now()}] Invalid type of `text_or_conversations`, expected `List[Dict]` or `str`,'
                             f' but got {type(text_or_conversations)}')

        if frame_selector is not None:
            conversations, images = frame_selector(conversations=conversations, frames=images, clear_prompt=True)

        # format conversations
        prompt, raw_input_ids, raw_labels = self.get_conversation_formatter().format(
            conversations, generation_preface=generation_preface, enable_thinking=enable_thinking)

        # place image placeholders
        input_ids = []
        labels = []
        pixel_values = []
        grid_thws = []
        invalidate_label = False
        image_token_indices = [i for i, v in enumerate(raw_input_ids) if v == IMAGE_TOKEN_ID or v == VIDEO_TOKEN_ID]
        last_image_token_index = -1
        for i in range(len(image_token_indices)):
            head = 0 if i == 0 else image_token_indices[i - 1] + 1
            tail = image_token_indices[i]
            last_image_token_index = tail
            input_ids.extend(raw_input_ids[head:tail])
            labels.extend(raw_labels[head:tail])
            try:
                # multiple videos are currently not supported
                if multimodal_type == "video":
                    image = images
                else:
                    image = images[i]
                raw_pixel_values, image_grid_thws, image_placeholders = self.visual_tokenizer.preprocess_image(
                    image, num_images=len(images) if fix_sample_overall_length_navit else 1,
                    min_pixels=min_pixels, max_pixels=max_pixels,
                    multimodal_type=multimodal_type)
            except Exception as e:
                if propagate_exception:
                    raise e
                logging.exception(e)
                invalidate_label = True
                # raw_pixel_values, image_placeholders = self.visual_tokenizer.mock_input() # TODO
                raw_pixel_values, _ = self.visual_tokenizer.mock_input()
                mock_image = transforms.ToPILImage()(raw_pixel_values[0])
                raw_pixel_values, image_grid_thws, image_placeholders = self.visual_tokenizer.preprocess_image(
                    mock_image, min_pixels=min_pixels, max_pixels=max_pixels)

            input_ids.extend(image_placeholders)
            labels.extend([IGNORE_ID] * len(image_placeholders))
            pixel_values.append(raw_pixel_values)
            grid_thws.append(image_grid_thws)
        input_ids.extend(raw_input_ids[last_image_token_index + 1:])
        labels.extend(raw_labels[last_image_token_index + 1:])

        # return tensors
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        labels = torch.tensor([IGNORE_ID] * len(labels) if invalidate_label else labels, dtype=torch.long)
        pixel_values = torch.cat(pixel_values, dim=0) if len(pixel_values) > 0 else None
        grid_thws = torch.cat(grid_thws, dim=0) if len(grid_thws) > 0 else None

        if return_labels:
            return prompt, input_ids, pixel_values, grid_thws, labels
        else:
            return prompt, input_ids, pixel_values, grid_thws

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        # assert inputs.shape[0] == 1, 'Currently, only support `batch_size=1`'
        _, inputs_embeds, labels, attention_mask, input_img_labels = self.merge_multimodal(
            text_input_ids=inputs,
            text_attention_masks=kwargs.pop('attention_mask'),
            text_labels=None,
            pixel_values=kwargs.pop('pixel_values'),
            grid_thws=kwargs.pop('grid_thws'),
            left_padding=True
        )
        inputs_embeds = inputs_embeds.detach()
        torch.cuda.empty_cache()
        return self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)

    def generate_condition(
            self,
            inputs: Optional[torch.Tensor] = None,
            **kwargs,
    ):
        # assert inputs.shape[0] == 1, 'Currently, only support `batch_size=1`'
        _, inputs_embeds, labels, attention_mask, input_img_labels = self.merge_multimodal(
            text_input_ids=inputs,
            text_attention_masks=kwargs.pop('attention_mask'),
            text_labels=None,
            pixel_values=kwargs.pop('pixel_values'),
            grid_thws=kwargs.pop('grid_thws'),
            left_padding=True
        )
        inputs_embeds = inputs_embeds.detach()
        torch.cuda.empty_cache()
        device = self.llm.device
        outputs = self.llm(inputs_embeds=inputs_embeds.to(device),
                           labels=labels.to(device),
                           attention_mask=attention_mask.to(device),
                           output_hidden_states=True,
                           **kwargs)
        semantic_cond_0 = outputs.hidden_states[-1]
        semantic_cond_1 = outputs.hidden_states[-2]
        semantic_cond = torch.cat([semantic_cond_0, semantic_cond_1], dim=-1)
        return dict(
            txt=semantic_cond
        )

    def generate_img(
        self,
        inputs: Optional[torch.Tensor] = None,
        cond=None,
        no_both_cond=None,
        no_txt_cond=None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        if cond is None:
            cond = self.generate_condition(inputs, **kwargs)

        height = kwargs.get('height', 1024)
        width = kwargs.get('width', 1024)
        num_steps = kwargs.get('num_steps', 50)
        seed = kwargs.get('seed', 42)
        img_cfg = kwargs.pop('img_cfg', 1.5)
        txt_cfg = kwargs.pop('txt_cfg', 5)
        yak_output = self.visual_generator.generate_image(
            cond=cond, no_txt_cond=no_txt_cond, no_both_cond=no_both_cond,
            height=height, width=width,
            num_steps=num_steps, seed=seed,
            img_cfg=img_cfg, txt_cfg=txt_cfg,
            output_type="pil")
        return yak_output
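For orientation, a rough end-to-end usage sketch of this class for image-and-text inference. The repo id, prompt format, auto class, dtype, device handling, and generation arguments below are assumptions for illustration, not values taken from this file:

import torch
from PIL import Image
from transformers import AutoModel

# Placeholder repo id; which auto class the checkpoint's config maps to is an assumption.
# Assumes a CUDA device is available.
model = AutoModel.from_pretrained(
    "AIDC-AI/Ovis-U1-3B", torch_dtype=torch.bfloat16, trust_remote_code=True).cuda()

image = Image.open("example.jpg")  # placeholder path
prompt, input_ids, pixel_values, grid_thws = model.preprocess_inputs(
    "<image>\nDescribe this picture.", [image], multimodal_type="single_image")

with torch.no_grad():
    output_ids = model.generate(
        inputs=input_ids.unsqueeze(0).cuda(),
        attention_mask=torch.ones_like(input_ids, dtype=torch.bool).unsqueeze(0).cuda(),
        pixel_values=pixel_values.cuda().to(torch.bfloat16),
        grid_thws=grid_thws.cuda(),
        max_new_tokens=256)

print(model.get_text_tokenizer().decode(output_ids[0], skip_special_tokens=True))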
    	
modeling_yak.py
ADDED
@@ -0,0 +1,1461 @@
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from typing import Optional, Callable
         
     | 
| 2 | 
         
            +
            import math
         
     | 
| 3 | 
         
            +
            from dataclasses import dataclass
         
     | 
| 4 | 
         
            +
            import collections.abc
         
     | 
| 5 | 
         
            +
            from itertools import repeat as iter_repeat
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            import numpy as np
         
     | 
| 8 | 
         
            +
            import torch
         
     | 
| 9 | 
         
            +
            from torch import Tensor, nn
         
     | 
| 10 | 
         
            +
            import torchvision
         
     | 
| 11 | 
         
            +
            from torchvision import transforms
         
     | 
| 12 | 
         
            +
            from diffusers import AutoencoderKL
         
     | 
| 13 | 
         
            +
            from PIL import Image
         
     | 
| 14 | 
         
            +
            from PIL.ImageOps import exif_transpose
         
     | 
| 15 | 
         
            +
            from torch.nn import functional as F
         
     | 
| 16 | 
         
            +
            from transformers.modeling_utils import PreTrainedModel
         
     | 
| 17 | 
         
            +
            from transformers.utils import ModelOutput
         
     | 
| 18 | 
         
            +
            from einops import rearrange, repeat
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            from .configuration_yak import YakConfig
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
            def _ntuple(n):
         
     | 
| 24 | 
         
            +
                def parse(x):
         
     | 
| 25 | 
         
            +
                    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
         
     | 
| 26 | 
         
            +
                        x = tuple(x)
         
     | 
| 27 | 
         
            +
                        if len(x) == 1:
         
     | 
| 28 | 
         
            +
                            x = tuple(iter_repeat(x[0], n))
         
     | 
| 29 | 
         
            +
                        return x
         
     | 
| 30 | 
         
            +
                    return tuple(iter_repeat(x, n))
         
     | 
| 31 | 
         
            +
                return parse
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            to_1tuple = _ntuple(1)
         
     | 
| 35 | 
         
            +
            to_2tuple = _ntuple(2)
         
     | 
| 36 | 
         
            +
            to_3tuple = _ntuple(3)
         
     | 
| 37 | 
         
            +
            to_4tuple = _ntuple(4)
         
     | 
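             # Usage sketch (illustrative, not part of the original file): these helpers
             # normalize scalar-or-tuple arguments, e.g.
             #     to_2tuple(3)      -> (3, 3)
             #     to_2tuple((4, 5)) -> (4, 5)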
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
            def as_tuple(x):
         
     | 
| 41 | 
         
            +
                if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
         
     | 
| 42 | 
         
            +
                    return tuple(x)
         
     | 
| 43 | 
         
            +
                if x is None or isinstance(x, (int, float, str)):
         
     | 
| 44 | 
         
            +
                    return (x,)
         
     | 
| 45 | 
         
            +
                else:
         
     | 
| 46 | 
         
            +
                    raise ValueError(f"Unknown type {type(x)}")
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
            def as_list_of_2tuple(x):
         
     | 
| 50 | 
         
            +
                x = as_tuple(x)
         
     | 
| 51 | 
         
            +
                if len(x) == 1:
         
     | 
| 52 | 
         
            +
                    x = (x[0], x[0])
         
     | 
| 53 | 
         
            +
                assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
         
     | 
| 54 | 
         
            +
                lst = []
         
     | 
| 55 | 
         
            +
                for i in range(0, len(x), 2):
         
     | 
| 56 | 
         
            +
                    lst.append((x[i], x[i + 1]))
         
     | 
| 57 | 
         
            +
                return lst
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
             def attention(q: Tensor, k: Tensor, v: Tensor, pe: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None) -> Tensor:
         
     | 
| 60 | 
         
            +
                if pe is None:
         
     | 
| 61 | 
         
            +
                    if attn_mask is not None and attn_mask.dtype != torch.bool:
         
     | 
| 62 | 
         
            +
                        attn_mask = attn_mask.to(q.dtype)
         
     | 
| 63 | 
         
            +
                    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
         
     | 
| 64 | 
         
            +
                    x = rearrange(x, "B H L D -> B L (H D)")
         
     | 
| 65 | 
         
            +
                else:
         
     | 
| 66 | 
         
            +
                    q, k = apply_rope(q, k, pe)
         
     | 
| 67 | 
         
            +
                    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
         
     | 
| 68 | 
         
            +
                    x = rearrange(x, "B H L D -> B L (H D)")
         
     | 
| 69 | 
         
            +
                return x
         
     | 
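             # Usage sketch (illustrative, not part of the original file): q, k, v are laid
             # out as (batch, heads, seq_len, head_dim) and the result is merged back to
             # (batch, seq_len, heads * head_dim); sizes below are arbitrary.
             #     q = k = v = torch.randn(1, 8, 16, 64)
             #     attention(q, k, v).shape            # -> torch.Size([1, 16, 512])
             #     attention(q, k, v, pe=pe).shape     # same shape, with pe (from EmbedND) applied as RoPE first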
| 70 | 
         
            +
             
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
            def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
         
     | 
| 73 | 
         
            +
                assert dim % 2 == 0
         
     | 
| 74 | 
         
            +
                scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
         
     | 
| 75 | 
         
            +
                omega = 1.0 / (theta**scale)
         
     | 
| 76 | 
         
            +
                out = torch.einsum("...n,d->...nd", pos, omega)
         
     | 
| 77 | 
         
            +
                out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
         
     | 
| 78 | 
         
            +
                out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
         
     | 
| 79 | 
         
            +
                return out.float()
         
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
             
     | 
| 82 | 
         
            +
            def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
         
     | 
| 83 | 
         
            +
                xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
         
     | 
| 84 | 
         
            +
                xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
         
     | 
| 85 | 
         
            +
                xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
         
     | 
| 86 | 
         
            +
                xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
         
     | 
| 87 | 
         
            +
                return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
         
     | 
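             # Shape note (illustrative, not part of the original file): for positions of
             # shape (batch, n) and an even dim, rope() yields one 2x2 rotation matrix per
             # frequency, i.e. (batch, n, dim // 2, 2, 2); apply_rope() broadcasts it over
             # q/k of shape (batch, heads, n, dim).
             #     rope(torch.arange(16)[None].float(), dim=64, theta=10_000).shape
             #     # -> torch.Size([1, 16, 32, 2, 2])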
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
            class EmbedND(nn.Module):
         
     | 
| 91 | 
         
            +
                def __init__(self, dim: int, theta: int, axes_dim: list[int]):
         
     | 
| 92 | 
         
            +
                    super().__init__()
         
     | 
| 93 | 
         
            +
                    self.dim = dim
         
     | 
| 94 | 
         
            +
                    self.theta = theta
         
     | 
| 95 | 
         
            +
                    self.axes_dim = axes_dim
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
                def forward(self, ids: Tensor) -> Tensor:
         
     | 
| 98 | 
         
            +
                    n_axes = ids.shape[-1]
         
     | 
| 99 | 
         
            +
                    emb = torch.cat(
         
     | 
| 100 | 
         
            +
                        [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
         
     | 
| 101 | 
         
            +
                        dim=-3,
         
     | 
| 102 | 
         
            +
                    )
         
     | 
| 103 | 
         
            +
             
     | 
| 104 | 
         
            +
                    return emb.unsqueeze(1)
         
     | 
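             # Usage sketch (illustrative, not part of the original file): `ids` holds one
             # coordinate per rotary axis, e.g. (batch, seq_len, len(axes_dim)); the result
             # is the `pe` tensor consumed by attention()/apply_rope(), with shape
             # (batch, 1, seq_len, sum(axes_dim) // 2, 2, 2).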
| 105 | 
         
            +
             
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
            def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
         
     | 
| 108 | 
         
            +
                """
         
     | 
| 109 | 
         
            +
                Create sinusoidal timestep embeddings.
         
     | 
| 110 | 
         
            +
                :param t: a 1-D Tensor of N indices, one per batch element.
         
     | 
| 111 | 
         
            +
                                  These may be fractional.
         
     | 
| 112 | 
         
            +
                :param dim: the dimension of the output.
         
     | 
| 113 | 
         
            +
                :param max_period: controls the minimum frequency of the embeddings.
         
     | 
| 114 | 
         
            +
                :return: an (N, D) Tensor of positional embeddings.
         
     | 
| 115 | 
         
            +
                """
         
     | 
| 116 | 
         
            +
                t = time_factor * t
         
     | 
| 117 | 
         
            +
                half = dim // 2
         
     | 
| 118 | 
         
            +
                freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
         
     | 
| 119 | 
         
            +
                    t.device
         
     | 
| 120 | 
         
            +
                )
         
     | 
| 121 | 
         
            +
             
     | 
| 122 | 
         
            +
                args = t[:, None].float() * freqs[None]
         
     | 
| 123 | 
         
            +
                embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         
     | 
| 124 | 
         
            +
                if dim % 2:
         
     | 
| 125 | 
         
            +
                    embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
         
     | 
| 126 | 
         
            +
                if torch.is_floating_point(t):
         
     | 
| 127 | 
         
            +
                    embedding = embedding.to(t)
         
     | 
| 128 | 
         
            +
                return embedding
         
     | 
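             # Usage sketch (illustrative, not part of the original file): t carries one
             # (possibly fractional) timestep per batch element, scaled by time_factor
             # before the sinusoidal encoding is built.
             #     timestep_embedding(torch.tensor([0.25, 1.0]), dim=256).shape
             #     # -> torch.Size([2, 256])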
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
             
     | 
| 131 | 
         
            +
            class MLPEmbedder(nn.Module):
         
     | 
| 132 | 
         
            +
                def __init__(self, in_dim: int, hidden_dim: int):
         
     | 
| 133 | 
         
            +
                    super().__init__()
         
     | 
| 134 | 
         
            +
                    self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
         
     | 
| 135 | 
         
            +
                    self.silu = nn.SiLU()
         
     | 
| 136 | 
         
            +
                    self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
         
     | 
| 137 | 
         
            +
             
     | 
| 138 | 
         
            +
                def forward(self, x: Tensor) -> Tensor:
         
     | 
| 139 | 
         
            +
                    return self.out_layer(self.silu(self.in_layer(x)))
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
             
     | 
| 142 | 
         
            +
            class RMSNorm(torch.nn.Module):
         
     | 
| 143 | 
         
            +
                def __init__(self, dim: int, scale_factor=1.0, eps:float=1e-6):
         
     | 
| 144 | 
         
            +
                    super().__init__()
         
     | 
| 145 | 
         
            +
                    self.scale = nn.Parameter(torch.ones(dim) * scale_factor)
         
     | 
| 146 | 
         
            +
                    self.eps = eps
         
     | 
| 147 | 
         
            +
             
     | 
| 148 | 
         
            +
                def forward(self, x: Tensor):
         
     | 
| 149 | 
         
            +
                    x_dtype = x.dtype
         
     | 
| 150 | 
         
            +
                    x = x.float()
         
     | 
| 151 | 
         
            +
                    rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
         
     | 
| 152 | 
         
            +
                    return (x * rrms).to(dtype=x_dtype) * self.scale
         
     | 
| 153 | 
         
            +
             
     | 
| 154 | 
         
            +
             
     | 
| 155 | 
         
            +
            class QKNorm(torch.nn.Module):
         
     | 
| 156 | 
         
            +
                def __init__(self, dim: int):
         
     | 
| 157 | 
         
            +
                    super().__init__()
         
     | 
| 158 | 
         
            +
                    self.query_norm = RMSNorm(dim)
         
     | 
| 159 | 
         
            +
                    self.key_norm = RMSNorm(dim)
         
     | 
| 160 | 
         
            +
             
     | 
| 161 | 
         
            +
                def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
         
     | 
| 162 | 
         
            +
                    q = self.query_norm(q)
         
     | 
| 163 | 
         
            +
                    k = self.key_norm(k)
         
     | 
| 164 | 
         
            +
                    return q.to(v), k.to(v)
         
     | 
| 165 | 
         
            +
             
     | 
| 166 | 
         
            +
             
     | 
| 167 | 
         
            +
            class SelfAttention(nn.Module):
         
     | 
| 168 | 
         
            +
                def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
         
     | 
| 169 | 
         
            +
                    super().__init__()
         
     | 
| 170 | 
         
            +
                    self.num_heads = num_heads
         
     | 
| 171 | 
         
            +
                    head_dim = dim // num_heads
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
                    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         
     | 
| 174 | 
         
            +
                    self.norm = QKNorm(head_dim)
         
     | 
| 175 | 
         
            +
                    self.proj = nn.Linear(dim, dim)
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                def forward(self, x: Tensor, pe: Tensor) -> Tensor:
         
     | 
| 178 | 
         
            +
                    qkv = self.qkv(x)
         
     | 
| 179 | 
         
            +
                    q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
         
     | 
| 180 | 
         
            +
                    q, k = self.norm(q, k, v)
         
     | 
| 181 | 
         
            +
                    x = attention(q, k, v, pe=pe)
         
     | 
| 182 | 
         
            +
                    x = self.proj(x)
         
     | 
| 183 | 
         
            +
                    return x
         
     | 
| 184 | 
         
            +
             
     | 
| 185 | 
         
            +
             
     | 
| 186 | 
         
            +
            @dataclass
         
     | 
| 187 | 
         
            +
            class ModulationOut:
         
     | 
| 188 | 
         
            +
                shift: Tensor
         
     | 
| 189 | 
         
            +
                scale: Tensor
         
     | 
| 190 | 
         
            +
                gate: Tensor
         
     | 
| 191 | 
         
            +
             
     | 
| 192 | 
         
            +
             
     | 
| 193 | 
         
            +
            class Modulation(nn.Module):
         
     | 
| 194 | 
         
            +
                def __init__(self, dim: int, double: bool):
         
     | 
| 195 | 
         
            +
                    super().__init__()
         
     | 
| 196 | 
         
            +
                    self.is_double = double
         
     | 
| 197 | 
         
            +
                    self.multiplier = 6 if double else 3
         
     | 
| 198 | 
         
            +
                    self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
                def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
         
     | 
| 201 | 
         
            +
                    out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
         
     | 
| 202 | 
         
            +
             
     | 
| 203 | 
         
            +
                    return (
         
     | 
| 204 | 
         
            +
                        ModulationOut(*out[:3]),
         
     | 
| 205 | 
         
            +
                        ModulationOut(*out[3:]) if self.is_double else None,
         
     | 
| 206 | 
         
            +
                    )
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
            class TriModulation(nn.Module):
         
     | 
| 209 | 
         
            +
                def __init__(self, dim: int):
         
     | 
| 210 | 
         
            +
                    super().__init__()
         
     | 
| 211 | 
         
            +
                    self.multiplier = 9
         
     | 
| 212 | 
         
            +
                    self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
         
     | 
| 213 | 
         
            +
             
     | 
| 214 | 
         
            +
                 def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut, ModulationOut]:
         
     | 
| 215 | 
         
            +
                    out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
         
     | 
| 216 | 
         
            +
             
     | 
| 217 | 
         
            +
                    return (
         
     | 
| 218 | 
         
            +
                        ModulationOut(*out[:3]),
         
     | 
| 219 | 
         
            +
                        ModulationOut(*out[3:6]),
         
     | 
| 220 | 
         
            +
                        ModulationOut(*out[6:]),
         
     | 
| 221 | 
         
            +
                    )
         
     | 
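             # Usage sketch (illustrative, not part of the original file): `vec` is the
             # pooled conditioning vector of shape (batch, dim); every ModulationOut field
             # comes out as (batch, 1, dim) so it broadcasts over the sequence axis.
             #     m1, m2 = Modulation(dim=3072, double=True)(torch.randn(2, 3072))
             #     m1.shift.shape                      # -> torch.Size([2, 1, 3072])
             #     # TriModulation returns three ModulationOut objects instead of two.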
| 222 | 
         
            +
             
     | 
| 223 | 
         
            +
             
     | 
| 224 | 
         
            +
            # from https://huggingface.co/stabilityai/stable-diffusion-3.5-medium
         
     | 
| 225 | 
         
            +
            class DoubleStreamXBlockProcessor:
         
     | 
| 226 | 
         
            +
                def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
         
     | 
| 227 | 
         
            +
                    img_mod1, img_mod2, img_mod3 = attn.img_mod(vec)
         
     | 
| 228 | 
         
            +
                    txt_mod1, txt_mod2 = attn.txt_mod(vec)
         
     | 
| 229 | 
         
            +
             
     | 
| 230 | 
         
            +
                    # prepare image for attention
         
     | 
| 231 | 
         
            +
                    img_modulated = attn.img_norm1(img)
         
     | 
| 232 | 
         
            +
                    img_cos_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         
     | 
| 233 | 
         
            +
                    img_qkv = attn.img_attn.qkv(img_cos_modulated)
         
     | 
| 234 | 
         
            +
                    img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
         
     | 
| 235 | 
         
            +
                    img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
         
     | 
| 236 | 
         
            +
             
     | 
| 237 | 
         
            +
                    # prepare image for self-attention
         
     | 
| 238 | 
         
            +
                    img_self_modulated = (1 + img_mod3.scale) * img_modulated + img_mod3.shift
         
     | 
| 239 | 
         
            +
                    img_self_qkv = attn.img_self_attn.qkv(img_self_modulated)
         
     | 
| 240 | 
         
            +
                    img_self_q, img_self_k, img_self_v = rearrange(img_self_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
         
     | 
| 241 | 
         
            +
                    img_self_q, img_self_k = attn.img_self_attn.norm(img_self_q, img_self_k, img_self_v)
         
     | 
| 242 | 
         
            +
                    txt_pe, img_pe = torch.split(pe, [txt.shape[1], img.shape[1]], dim=2)
         
     | 
| 243 | 
         
            +
                    img_self_attn = attention(img_self_q, img_self_k, img_self_v, pe=img_pe)
         
     | 
| 244 | 
         
            +
             
     | 
| 245 | 
         
            +
                    # prepare txt for attention
         
     | 
| 246 | 
         
            +
                    txt_modulated = attn.txt_norm1(txt)
         
     | 
| 247 | 
         
            +
                    txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         
     | 
| 248 | 
         
            +
                    txt_qkv = attn.txt_attn.qkv(txt_modulated)
         
     | 
| 249 | 
         
            +
                    txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
         
     | 
| 250 | 
         
            +
                    txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
         
     | 
| 251 | 
         
            +
             
     | 
| 252 | 
         
            +
                    # run actual attention
         
     | 
| 253 | 
         
            +
                    q = torch.cat((txt_q, img_q), dim=2)
         
     | 
| 254 | 
         
            +
                    k = torch.cat((txt_k, img_k), dim=2)
         
     | 
| 255 | 
         
            +
                    v = torch.cat((txt_v, img_v), dim=2)
         
     | 
| 256 | 
         
            +
             
     | 
| 257 | 
         
            +
                    attn1 = attention(q, k, v, pe=pe)
         
     | 
| 258 | 
         
            +
                    txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
         
     | 
| 259 | 
         
            +
             
     | 
| 260 | 
         
            +
                    # calculate the img bloks
         
     | 
| 261 | 
         
            +
                    img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
         
     | 
| 262 | 
         
            +
                    img = img + img_mod3.gate * attn.img_self_attn.proj(img_self_attn)
         
     | 
| 263 | 
         
            +
                    img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
         
     | 
| 264 | 
         
            +
             
     | 
| 265 | 
         
            +
                     # calculate the txt blocks
         
     | 
| 266 | 
         
            +
                    txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
         
     | 
| 267 | 
         
            +
                    txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
         
     | 
| 268 | 
         
            +
                    return img, txt
         
     | 
| 269 | 
         
            +
                
         
     | 
| 270 | 
         
            +
            class DoubleStreamXBlock(nn.Module):
         
     | 
| 271 | 
         
            +
                def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
         
     | 
| 272 | 
         
            +
                    super().__init__()
         
     | 
| 273 | 
         
            +
             
     | 
| 274 | 
         
            +
                    mlp_hidden_dim = int(hidden_size * mlp_ratio)
         
     | 
| 275 | 
         
            +
                    self.num_heads = num_heads
         
     | 
| 276 | 
         
            +
                    self.hidden_size = hidden_size
         
     | 
| 277 | 
         
            +
                    self.img_mod = TriModulation(hidden_size)
         
     | 
| 278 | 
         
            +
                    self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         
     | 
| 279 | 
         
            +
                    self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
         
     | 
| 280 | 
         
            +
                    self.img_self_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
         
     | 
| 281 | 
         
            +
             
     | 
| 282 | 
         
            +
                    self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         
     | 
| 283 | 
         
            +
                    self.img_mlp = nn.Sequential(
         
     | 
| 284 | 
         
            +
                        nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
         
     | 
| 285 | 
         
            +
                        nn.GELU(approximate="tanh"),
         
     | 
| 286 | 
         
            +
                        nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
         
     | 
| 287 | 
         
            +
                    )
         
     | 
| 288 | 
         
            +
             
     | 
| 289 | 
         
            +
                    self.txt_mod = Modulation(hidden_size, double=True)
         
     | 
| 290 | 
         
            +
                    self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         
     | 
| 291 | 
         
            +
                    self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
         
     | 
| 292 | 
         
            +
             
     | 
| 293 | 
         
            +
                    self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         
     | 
| 294 | 
         
            +
                    self.txt_mlp = nn.Sequential(
         
     | 
| 295 | 
         
            +
                        nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
         
     | 
| 296 | 
         
            +
                        nn.GELU(approximate="tanh"),
         
     | 
| 297 | 
         
            +
                        nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
         
     | 
| 298 | 
         
            +
                    )
         
     | 
| 299 | 
         
            +
                    processor = DoubleStreamXBlockProcessor()
         
     | 
| 300 | 
         
            +
                    self.set_processor(processor)
         
     | 
| 301 | 
         
            +
                
         
     | 
| 302 | 
         
            +
                def set_processor(self, processor) -> None:
         
     | 
| 303 | 
         
            +
                    self.processor = processor
         
     | 
| 304 | 
         
            +
             
     | 
| 305 | 
         
            +
                def get_processor(self):
         
     | 
| 306 | 
         
            +
                    return self.processor
         
     | 
| 307 | 
         
            +
             
     | 
| 308 | 
         
            +
                def forward(
         
     | 
| 309 | 
         
            +
                    self,
         
     | 
| 310 | 
         
            +
                    img: Tensor,
         
     | 
| 311 | 
         
            +
                    txt: Tensor,
         
     | 
| 312 | 
         
            +
                    vec: Tensor,
         
     | 
| 313 | 
         
            +
                    pe: Tensor,
         
     | 
| 314 | 
         
            +
                 image_proj: Tensor | None = None,
         
     | 
| 315 | 
         
            +
                 ip_scale: float = 1.0,
         
     | 
| 316 | 
         
            +
                ) -> tuple[Tensor, Tensor]:
         
     | 
| 317 | 
         
            +
                    if image_proj is None:
         
     | 
| 318 | 
         
            +
                        return self.processor(self, img, txt, vec, pe)
         
     | 
| 319 | 
         
            +
                    else:
         
     | 
| 320 | 
         
            +
                        return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
         
     | 
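             # Shape note (illustrative, not part of the original file): img is
             # (batch, L_img, hidden_size), txt is (batch, L_txt, hidden_size), vec is
             # (batch, hidden_size), and pe covers the concatenated txt+img sequence;
             # the block returns the updated (img, txt) pair.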
| 321 | 
         
            +
             
     | 
| 322 | 
         
            +
            class SingleStreamBlockProcessor:
         
     | 
| 323 | 
         
            +
                def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
         
     | 
| 324 | 
         
            +
                    mod, _ = attn.modulation(vec)
         
     | 
| 325 | 
         
            +
                    x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
         
     | 
| 326 | 
         
            +
                    qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
         
     | 
| 327 | 
         
            +
             
     | 
| 328 | 
         
            +
                    q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
         
     | 
| 329 | 
         
            +
                    q, k = attn.norm(q, k, v)
         
     | 
| 330 | 
         
            +
             
     | 
| 331 | 
         
            +
                    # compute attention
         
     | 
| 332 | 
         
            +
                    attn_1 = attention(q, k, v, pe=pe)
         
     | 
| 333 | 
         
            +
             
     | 
| 334 | 
         
            +
                    # compute activation in mlp stream, cat again and run second linear layer
         
     | 
| 335 | 
         
            +
                    output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
         
     | 
| 336 | 
         
            +
                    output = x + mod.gate * output
         
     | 
| 337 | 
         
            +
                    return output
         
     | 
| 338 | 
         
            +
             
     | 
| 339 | 
         
            +
             
     | 
| 340 | 
         
            +
            class SingleStreamBlock(nn.Module):
         
     | 
| 341 | 
         
            +
                """
         
     | 
| 342 | 
         
            +
                A DiT block with parallel linear layers as described in
         
     | 
| 343 | 
         
            +
                https://arxiv.org/abs/2302.05442 and adapted modulation interface.
         
     | 
| 344 | 
         
            +
                """
         
     | 
| 345 | 
         
            +
             
     | 
| 346 | 
         
            +
                def __init__(
         
     | 
| 347 | 
         
            +
                    self,
         
     | 
| 348 | 
         
            +
                    hidden_size: int,
         
     | 
| 349 | 
         
            +
                    num_heads: int,
         
     | 
| 350 | 
         
            +
                    mlp_ratio: float = 4.0,
         
     | 
| 351 | 
         
            +
                    qk_scale: float | None = None,
         
     | 
| 352 | 
         
            +
                ):
         
     | 
| 353 | 
         
            +
                    super().__init__()
         
     | 
| 354 | 
         
            +
                    self.hidden_dim = hidden_size
         
     | 
| 355 | 
         
            +
                    self.num_heads = num_heads
         
     | 
| 356 | 
         
            +
                    head_dim = hidden_size // num_heads
         
     | 
| 357 | 
         
            +
                    self.scale = qk_scale or head_dim**-0.5
         
     | 
| 358 | 
         
            +
             
     | 
| 359 | 
         
            +
                    self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
         
     | 
| 360 | 
         
            +
                    # qkv and mlp_in
         
     | 
| 361 | 
         
            +
                    self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
         
     | 
| 362 | 
         
            +
                    # proj and mlp_out
         
     | 
| 363 | 
         
            +
                    self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
         
     | 
| 364 | 
         
            +
             
     | 
| 365 | 
         
            +
                    self.norm = QKNorm(head_dim)
         
     | 
| 366 | 
         
            +
             
     | 
| 367 | 
         
            +
                    self.hidden_size = hidden_size
         
     | 
| 368 | 
         
            +
                    self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         
     | 
| 369 | 
         
            +
             
     | 
| 370 | 
         
            +
                    self.mlp_act = nn.GELU(approximate="tanh")
         
     | 
| 371 | 
         
            +
                    self.modulation = Modulation(hidden_size, double=False)
         
     | 
| 372 | 
         
            +
             
     | 
| 373 | 
         
            +
                    processor = SingleStreamBlockProcessor()
         
     | 
| 374 | 
         
            +
                    self.set_processor(processor)
         
     | 
| 375 | 
         
            +
             
     | 
| 376 | 
         
            +
             
     | 
| 377 | 
         
            +
                def set_processor(self, processor) -> None:
         
     | 
| 378 | 
         
            +
                    self.processor = processor
         
     | 
| 379 | 
         
            +
             
     | 
| 380 | 
         
            +
                def get_processor(self):
         
     | 
| 381 | 
         
            +
                    return self.processor
         
     | 
| 382 | 
         
            +
             
     | 
| 383 | 
         
            +
                def forward(
         
     | 
| 384 | 
         
            +
                    self,
         
     | 
| 385 | 
         
            +
                    x: Tensor,
         
     | 
| 386 | 
         
            +
                    vec: Tensor,
         
     | 
| 387 | 
         
            +
                    pe: Tensor,
         
     | 
| 388 | 
         
            +
                    image_proj: Tensor | None = None,
         
     | 
| 389 | 
         
            +
                    ip_scale: float = 1.0
         
     | 
| 390 | 
         
            +
                ) -> Tensor:
         
     | 
| 391 | 
         
            +
                    if image_proj is None:
         
     | 
| 392 | 
         
            +
                        return self.processor(self, x, vec, pe)
         
     | 
| 393 | 
         
            +
                    else:
         
     | 
| 394 | 
         
            +
                        return self.processor(self, x, vec, pe, image_proj, ip_scale)
         
     | 
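             # Usage sketch (illustrative, not part of the original file; sizes arbitrary):
             #     blk = SingleStreamBlock(hidden_size=3072, num_heads=24)
             #     x = torch.randn(1, 128, 3072)       # concatenated txt+img tokens
             #     vec = torch.randn(1, 3072)          # pooled conditioning vector
             #     out = blk(x, vec, pe)               # pe from EmbedND; out has x's shape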
| 395 | 
         
            +
             
     | 
| 396 | 
         
            +
             
     | 
| 397 | 
         
            +
            class LastLayer(nn.Module):
         
     | 
| 398 | 
         
            +
                def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
         
     | 
| 399 | 
         
            +
                    super().__init__()
         
     | 
| 400 | 
         
            +
                    self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         
     | 
| 401 | 
         
            +
                    self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
         
     | 
| 402 | 
         
            +
                    self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
         
     | 
| 403 | 
         
            +
             
     | 
| 404 | 
         
            +
                def forward(self, x: Tensor, vec: Tensor) -> Tensor:
         
     | 
| 405 | 
         
            +
                    shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
         
     | 
| 406 | 
         
            +
                    x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
         
     | 
| 407 | 
         
            +
                    x = self.linear(x)
         
     | 
| 408 | 
         
            +
                    return x
         
     | 
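             # Usage sketch (illustrative, not part of the original file): applies an adaLN
             # shift/scale derived from `vec`, then projects to patch pixels.
             #     last = LastLayer(hidden_size=3072, patch_size=2, out_channels=16)
             #     last(torch.randn(1, 128, 3072), torch.randn(1, 3072)).shape
             #     # -> torch.Size([1, 128, 64])       # 64 = patch_size**2 * out_channels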
| 409 | 
         
            +
             
     | 
| 410 | 
         
            +
                
         
     | 
| 411 | 
         
            +
             
     | 
| 412 | 
         
            +
            def get_norm_layer(norm_layer):
         
     | 
| 413 | 
         
            +
                """
         
     | 
| 414 | 
         
            +
                Get the normalization layer.
         
     | 
| 415 | 
         
            +
             
     | 
| 416 | 
         
            +
                Args:
         
     | 
| 417 | 
         
            +
                    norm_layer (str): The type of normalization layer.
         
     | 
| 418 | 
         
            +
             
     | 
| 419 | 
         
            +
                Returns:
         
     | 
| 420 | 
         
            +
                    norm_layer (nn.Module): The normalization layer.
         
     | 
| 421 | 
         
            +
                """
         
     | 
| 422 | 
         
            +
                if norm_layer == "layer":
         
     | 
| 423 | 
         
            +
                    return nn.LayerNorm
         
     | 
| 424 | 
         
            +
                elif norm_layer == "rms":
         
     | 
| 425 | 
         
            +
                    return RMSNorm
         
     | 
| 426 | 
         
            +
                else:
         
     | 
| 427 | 
         
            +
                    raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")   
         
     | 
| 428 | 
         
            +
              
         
     | 
| 429 | 
         
            +
            def get_activation_layer(act_type):
         
     | 
| 430 | 
         
            +
                """get activation layer
         
     | 
| 431 | 
         
            +
             
     | 
| 432 | 
         
            +
                Args:
         
     | 
| 433 | 
         
            +
                    act_type (str): the activation type
         
     | 
| 434 | 
         
            +
             
     | 
| 435 | 
         
            +
                Returns:
         
     | 
| 436 | 
         
            +
                    Callable[[], nn.Module]: a factory that builds the activation layer.
         
     | 
| 437 | 
         
            +
                """
         
     | 
| 438 | 
         
            +
                if act_type == "gelu":
         
     | 
| 439 | 
         
            +
                    return lambda: nn.GELU()
         
     | 
| 440 | 
         
            +
                elif act_type == "gelu_tanh":
         
     | 
| 441 | 
         
            +
                    # Approximate `tanh` requires torch >= 1.13
         
     | 
| 442 | 
         
            +
                    return lambda: nn.GELU(approximate="tanh")
         
     | 
| 443 | 
         
            +
                elif act_type == "relu":
         
     | 
| 444 | 
         
            +
                    return nn.ReLU
         
     | 
| 445 | 
         
            +
                elif act_type == "silu":
         
     | 
| 446 | 
         
            +
                    return nn.SiLU
         
     | 
| 447 | 
         
            +
                else:
         
     | 
| 448 | 
         
            +
                    raise ValueError(f"Unknown activation type: {act_type}")
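            # Illustrative usage sketch (not part of the original file; sizes are arbitrary): both
            # helpers return factories, so the result is called again to construct the module.
            #
            #   norm = get_norm_layer("layer")(64, elementwise_affine=True, eps=1e-6)  # nn.LayerNorm(64, ...)
            #   act = get_activation_layer("gelu_tanh")()                              # nn.GELU(approximate="tanh")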
         
     | 
| 449 | 
         
            +
             
     | 
| 450 | 
         
            +
            def modulate(x, shift=None, scale=None):
         
     | 
| 451 | 
         
            +
                """modulate by shift and scale
         
     | 
| 452 | 
         
            +
             
     | 
| 453 | 
         
            +
                Args:
         
     | 
| 454 | 
         
            +
                    x (torch.Tensor): input tensor.
         
     | 
| 455 | 
         
            +
                    shift (torch.Tensor, optional): shift tensor. Defaults to None.
         
     | 
| 456 | 
         
            +
                    scale (torch.Tensor, optional): scale tensor. Defaults to None.
         
     | 
| 457 | 
         
            +
             
     | 
| 458 | 
         
            +
                Returns:
         
     | 
| 459 | 
         
            +
                    torch.Tensor: the output tensor after modulate.
         
     | 
| 460 | 
         
            +
                """
         
     | 
| 461 | 
         
            +
                if scale is None and shift is None:
         
     | 
| 462 | 
         
            +
                    return x
         
     | 
| 463 | 
         
            +
                elif shift is None:
         
     | 
| 464 | 
         
            +
                    return x * (1 + scale.unsqueeze(1))
         
     | 
| 465 | 
         
            +
                elif scale is None:
         
     | 
| 466 | 
         
            +
                    return x + shift.unsqueeze(1)
         
     | 
| 467 | 
         
            +
                else:
         
     | 
| 468 | 
         
            +
                    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
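            # Illustrative sketch (not part of the original file): `modulate` broadcasts per-sample
            # AdaLN shift/scale over the sequence dimension, i.e. x is [B, L, D] and shift/scale are [B, D].
            #
            #   x = torch.randn(2, 16, 8)
            #   shift, scale = torch.zeros(2, 8), torch.zeros(2, 8)
            #   assert torch.allclose(modulate(x, shift=shift, scale=scale), x)  # zero params -> identity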
         
     | 
| 469 | 
         
            +
             
     | 
| 470 | 
         
            +
            def apply_gate(x, gate=None, tanh=False):
         
     | 
| 471 | 
         
            +
                """AI is creating summary for apply_gate
         
     | 
| 472 | 
         
            +
             
     | 
| 473 | 
         
            +
                Args:
         
     | 
| 474 | 
         
            +
                    x (torch.Tensor): input tensor.
         
     | 
| 475 | 
         
            +
                    gate (torch.Tensor, optional): gate tensor. Defaults to None.
         
     | 
| 476 | 
         
            +
                    tanh (bool, optional): whether to use tanh function. Defaults to False.
         
     | 
| 477 | 
         
            +
             
     | 
| 478 | 
         
            +
                Returns:
         
     | 
| 479 | 
         
            +
                    torch.Tensor: the output tensor after apply gate.
         
     | 
| 480 | 
         
            +
                """
         
     | 
| 481 | 
         
            +
                if gate is None:
         
     | 
| 482 | 
         
            +
                    return x
         
     | 
| 483 | 
         
            +
                if tanh:
         
     | 
| 484 | 
         
            +
                    return x * gate.unsqueeze(1).tanh()
         
     | 
| 485 | 
         
            +
                else:
         
     | 
| 486 | 
         
            +
                    return x * gate.unsqueeze(1)
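            # Illustrative sketch (not part of the original file): a per-sample gate of shape [B, D]
            # is unsqueezed to [B, 1, D] and multiplied into the branch output x of shape [B, L, D].
            #
            #   x, gate = torch.ones(2, 4, 8), torch.zeros(2, 8)
            #   assert apply_gate(x, gate).abs().sum() == 0  # zero gate suppresses the branch
            #   assert torch.equal(apply_gate(x, None), x)   # no gate -> identity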
         
     | 
| 487 | 
         
            +
             
     | 
| 488 | 
         
            +
            class MLP(nn.Module):
         
     | 
| 489 | 
         
            +
                """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
         
     | 
| 490 | 
         
            +
             
     | 
| 491 | 
         
            +
                def __init__(
         
     | 
| 492 | 
         
            +
                    self,
         
     | 
| 493 | 
         
            +
                    in_channels,
         
     | 
| 494 | 
         
            +
                    hidden_channels=None,
         
     | 
| 495 | 
         
            +
                    out_features=None,
         
     | 
| 496 | 
         
            +
                    act_layer=nn.GELU,
         
     | 
| 497 | 
         
            +
                    norm_layer=None,
         
     | 
| 498 | 
         
            +
                    bias=True,
         
     | 
| 499 | 
         
            +
                    drop=0.0,
         
     | 
| 500 | 
         
            +
                    use_conv=False,
         
     | 
| 501 | 
         
            +
                    device=None,
         
     | 
| 502 | 
         
            +
                    dtype=None,
         
     | 
| 503 | 
         
            +
                ):
         
     | 
| 504 | 
         
            +
                    factory_kwargs = {"device": device, "dtype": dtype}
         
     | 
| 505 | 
         
            +
                    super().__init__()
         
     | 
| 506 | 
         
            +
                    out_features = out_features or in_channels
         
     | 
| 507 | 
         
            +
                    hidden_channels = hidden_channels or in_channels
         
     | 
| 508 | 
         
            +
                    bias = to_2tuple(bias)
         
     | 
| 509 | 
         
            +
                    drop_probs = to_2tuple(drop)
         
     | 
| 510 | 
         
            +
                    linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
         
     | 
| 511 | 
         
            +
             
     | 
| 512 | 
         
            +
                    self.fc1 = linear_layer(
         
     | 
| 513 | 
         
            +
                        in_channels, hidden_channels, bias=bias[0], **factory_kwargs
         
     | 
| 514 | 
         
            +
                    )
         
     | 
| 515 | 
         
            +
                    self.act = act_layer()
         
     | 
| 516 | 
         
            +
                    self.drop1 = nn.Dropout(drop_probs[0])
         
     | 
| 517 | 
         
            +
                    self.norm = (
         
     | 
| 518 | 
         
            +
                        norm_layer(hidden_channels, **factory_kwargs)
         
     | 
| 519 | 
         
            +
                        if norm_layer is not None
         
     | 
| 520 | 
         
            +
                        else nn.Identity()
         
     | 
| 521 | 
         
            +
                    )
         
     | 
| 522 | 
         
            +
                    self.fc2 = linear_layer(
         
     | 
| 523 | 
         
            +
                        hidden_channels, out_features, bias=bias[1], **factory_kwargs
         
     | 
| 524 | 
         
            +
                    )
         
     | 
| 525 | 
         
            +
                    self.drop2 = nn.Dropout(drop_probs[1])
         
     | 
| 526 | 
         
            +
             
     | 
| 527 | 
         
            +
                def forward(self, x):
         
     | 
| 528 | 
         
            +
                    x = self.fc1(x)
         
     | 
| 529 | 
         
            +
                    x = self.act(x)
         
     | 
| 530 | 
         
            +
                    x = self.drop1(x)
         
     | 
| 531 | 
         
            +
                    x = self.norm(x)
         
     | 
| 532 | 
         
            +
                    x = self.fc2(x)
         
     | 
| 533 | 
         
            +
                    x = self.drop2(x)
         
     | 
| 534 | 
         
            +
                    return x
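            # Illustrative usage sketch (not part of the original file; sizes are arbitrary):
            #
            #   mlp = MLP(in_channels=64, hidden_channels=256, act_layer=get_activation_layer("gelu_tanh"))
            #   y = mlp(torch.randn(2, 10, 64))  # -> [2, 10, 64]; out_features defaults to in_channels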
         
     | 
| 535 | 
         
            +
             
     | 
| 536 | 
         
            +
             
     | 
| 537 | 
         
            +
             
     | 
| 538 | 
         
            +
             
     | 
| 539 | 
         
            +
             
     | 
| 540 | 
         
            +
             
     | 
| 541 | 
         
            +
             
     | 
| 542 | 
         
            +
             
     | 
| 543 | 
         
            +
             
     | 
| 544 | 
         
            +
             
     | 
| 545 | 
         
            +
             
     | 
| 546 | 
         
            +
             
     | 
| 547 | 
         
            +
             
     | 
| 548 | 
         
            +
             
     | 
| 549 | 
         
            +
             
     | 
| 550 | 
         
            +
             
     | 
| 551 | 
         
            +
             
     | 
| 552 | 
         
            +
             
     | 
| 553 | 
         
            +
            class TextProjection(nn.Module):
         
     | 
| 554 | 
         
            +
                """
         
     | 
| 555 | 
         
            +
                Projects text embeddings. Also handles dropout for classifier-free guidance.
         
     | 
| 556 | 
         
            +
             
     | 
| 557 | 
         
            +
                Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
         
     | 
| 558 | 
         
            +
                """
         
     | 
| 559 | 
         
            +
             
     | 
| 560 | 
         
            +
                def __init__(self, in_channels, hidden_size, act_layer):
         
     | 
| 561 | 
         
            +
                    super().__init__()
         
     | 
| 562 | 
         
            +
                    self.linear_1 = nn.Linear(
         
     | 
| 563 | 
         
            +
                        in_features=in_channels,
         
     | 
| 564 | 
         
            +
                        out_features=hidden_size,
         
     | 
| 565 | 
         
            +
                        bias=True,    
         
     | 
| 566 | 
         
            +
                    )
         
     | 
| 567 | 
         
            +
                    self.act_1 = act_layer()
         
     | 
| 568 | 
         
            +
                    self.linear_2 = nn.Linear(
         
     | 
| 569 | 
         
            +
                        in_features=hidden_size,
         
     | 
| 570 | 
         
            +
                        out_features=hidden_size,
         
     | 
| 571 | 
         
            +
                        bias=True,
         
     | 
| 572 | 
         
            +
                    )
         
     | 
| 573 | 
         
            +
             
     | 
| 574 | 
         
            +
                def forward(self, caption):
         
     | 
| 575 | 
         
            +
                    hidden_states = self.linear_1(caption)
         
     | 
| 576 | 
         
            +
                    hidden_states = self.act_1(hidden_states)
         
     | 
| 577 | 
         
            +
                    hidden_states = self.linear_2(hidden_states)
         
     | 
| 578 | 
         
            +
                    return hidden_states
         
     | 
| 579 | 
         
            +
             
     | 
| 580 | 
         
            +
             
     | 
| 581 | 
         
            +
            def timestep_embedding_refiner(t, dim, max_period=10000):
         
     | 
| 582 | 
         
            +
                """
         
     | 
| 583 | 
         
            +
                Create sinusoidal timestep embeddings.
         
     | 
| 584 | 
         
            +
             
     | 
| 585 | 
         
            +
                Args:
         
     | 
| 586 | 
         
            +
                    t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
         
     | 
| 587 | 
         
            +
                    dim (int): the dimension of the output.
         
     | 
| 588 | 
         
            +
                    max_period (int): controls the minimum frequency of the embeddings.
         
     | 
| 589 | 
         
            +
             
     | 
| 590 | 
         
            +
                Returns:
         
     | 
| 591 | 
         
            +
                    embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
         
     | 
| 592 | 
         
            +
             
     | 
| 593 | 
         
            +
                .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
         
     | 
| 594 | 
         
            +
                """
         
     | 
| 595 | 
         
            +
                half = dim // 2
         
     | 
| 596 | 
         
            +
                freqs = torch.exp(
         
     | 
| 597 | 
         
            +
                    -math.log(max_period)
         
     | 
| 598 | 
         
            +
                    * torch.arange(start=0, end=half, dtype=torch.float32)
         
     | 
| 599 | 
         
            +
                    / half
         
     | 
| 600 | 
         
            +
                ).to(device=t.device)
         
     | 
| 601 | 
         
            +
                args = t[:, None].float() * freqs[None]
         
     | 
| 602 | 
         
            +
                embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         
     | 
| 603 | 
         
            +
                if dim % 2:
         
     | 
| 604 | 
         
            +
                    embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
         
     | 
| 605 | 
         
            +
                return embedding
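            # Illustrative sketch (not part of the original file): the first dim // 2 channels are
            # cosine terms and the remaining channels are sine terms over geometrically spaced frequencies.
            #
            #   t = torch.tensor([0.0, 250.0, 999.0])
            #   emb = timestep_embedding_refiner(t, dim=256)  # -> [3, 256]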
         
     | 
| 606 | 
         
            +
             
     | 
| 607 | 
         
            +
             
     | 
| 608 | 
         
            +
            class TimestepEmbedder(nn.Module):
         
     | 
| 609 | 
         
            +
                """
         
     | 
| 610 | 
         
            +
                Embeds scalar timesteps into vector representations.
         
     | 
| 611 | 
         
            +
                """
         
     | 
| 612 | 
         
            +
             
     | 
| 613 | 
         
            +
                def __init__(
         
     | 
| 614 | 
         
            +
                    self,
         
     | 
| 615 | 
         
            +
                    hidden_size,
         
     | 
| 616 | 
         
            +
                    act_layer,
         
     | 
| 617 | 
         
            +
                    frequency_embedding_size=256,
         
     | 
| 618 | 
         
            +
                    max_period=10000,
         
     | 
| 619 | 
         
            +
                    out_size=None,
         
     | 
| 620 | 
         
            +
                ):
         
     | 
| 621 | 
         
            +
                    super().__init__()
         
     | 
| 622 | 
         
            +
                    self.frequency_embedding_size = frequency_embedding_size
         
     | 
| 623 | 
         
            +
                    self.max_period = max_period
         
     | 
| 624 | 
         
            +
                    if out_size is None:
         
     | 
| 625 | 
         
            +
                        out_size = hidden_size
         
     | 
| 626 | 
         
            +
             
     | 
| 627 | 
         
            +
                    self.mlp = nn.Sequential(
         
     | 
| 628 | 
         
            +
                        nn.Linear(
         
     | 
| 629 | 
         
            +
                            frequency_embedding_size, hidden_size, bias=True, 
         
     | 
| 630 | 
         
            +
                        ),
         
     | 
| 631 | 
         
            +
                        act_layer(),
         
     | 
| 632 | 
         
            +
                        nn.Linear(hidden_size, out_size, bias=True, ),
         
     | 
| 633 | 
         
            +
                    )
         
     | 
| 634 | 
         
            +
                    nn.init.normal_(self.mlp[0].weight, std=0.02)
         
     | 
| 635 | 
         
            +
                    nn.init.normal_(self.mlp[2].weight, std=0.02)
         
     | 
| 636 | 
         
            +
             
     | 
| 637 | 
         
            +
                def forward(self, t):
         
     | 
| 638 | 
         
            +
                    t_freq = timestep_embedding_refiner(
         
     | 
| 639 | 
         
            +
                        t, self.frequency_embedding_size, self.max_period
         
     | 
| 640 | 
         
            +
                    ).type(self.mlp[0].weight.dtype)
         
     | 
| 641 | 
         
            +
                    t_emb = self.mlp(t_freq)
         
     | 
| 642 | 
         
            +
                    return t_emb
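            # Illustrative usage sketch (not part of the original file; sizes are arbitrary):
            #
            #   embedder = TimestepEmbedder(hidden_size=512, act_layer=get_activation_layer("silu"))
            #   vec = embedder(torch.tensor([0.0, 500.0]))  # -> [2, 512]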
         
     | 
| 643 | 
         
            +
             
     | 
| 644 | 
         
            +
             
     | 
| 645 | 
         
            +
            class IndividualTokenRefinerBlock(nn.Module):
         
     | 
| 646 | 
         
            +
                def __init__(
         
     | 
| 647 | 
         
            +
                    self,
         
     | 
| 648 | 
         
            +
                    hidden_size,
         
     | 
| 649 | 
         
            +
                    heads_num,
         
     | 
| 650 | 
         
            +
                    mlp_width_ratio: float = 4.0,
         
     | 
| 651 | 
         
            +
                    mlp_drop_rate: float = 0.0,
         
     | 
| 652 | 
         
            +
                    act_type: str = "silu",
         
     | 
| 653 | 
         
            +
                    qk_norm: bool = False,
         
     | 
| 654 | 
         
            +
                    qk_norm_type: str = "layer",
         
     | 
| 655 | 
         
            +
                    qkv_bias: bool = True,
         
     | 
| 656 | 
         
            +
                ):
         
     | 
| 657 | 
         
            +
                    super().__init__()
         
     | 
| 658 | 
         
            +
                    self.heads_num = heads_num
         
     | 
| 659 | 
         
            +
                    head_dim = hidden_size // heads_num
         
     | 
| 660 | 
         
            +
                    mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
         
     | 
| 661 | 
         
            +
             
     | 
| 662 | 
         
            +
                    self.norm1 = nn.LayerNorm(
         
     | 
| 663 | 
         
            +
                        hidden_size, elementwise_affine=True, eps=1e-6, 
         
     | 
| 664 | 
         
            +
                    )
         
     | 
| 665 | 
         
            +
                    self.self_attn_qkv = nn.Linear(
         
     | 
| 666 | 
         
            +
                        hidden_size, hidden_size * 3, bias=qkv_bias, 
         
     | 
| 667 | 
         
            +
                    )
         
     | 
| 668 | 
         
            +
                    qk_norm_layer = get_norm_layer(qk_norm_type)
         
     | 
| 669 | 
         
            +
                    self.self_attn_q_norm = (
         
     | 
| 670 | 
         
            +
                        qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, )
         
     | 
| 671 | 
         
            +
                        if qk_norm
         
     | 
| 672 | 
         
            +
                        else nn.Identity()
         
     | 
| 673 | 
         
            +
                    )
         
     | 
| 674 | 
         
            +
                    self.self_attn_k_norm = (
         
     | 
| 675 | 
         
            +
                        qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, )
         
     | 
| 676 | 
         
            +
                        if qk_norm
         
     | 
| 677 | 
         
            +
                        else nn.Identity()
         
     | 
| 678 | 
         
            +
                    )
         
     | 
| 679 | 
         
            +
                    self.self_attn_proj = nn.Linear(
         
     | 
| 680 | 
         
            +
                        hidden_size, hidden_size, bias=qkv_bias, 
         
     | 
| 681 | 
         
            +
                    )
         
     | 
| 682 | 
         
            +
             
     | 
| 683 | 
         
            +
                    self.norm2 = nn.LayerNorm(
         
     | 
| 684 | 
         
            +
                        hidden_size, elementwise_affine=True, eps=1e-6, 
         
     | 
| 685 | 
         
            +
                    )
         
     | 
| 686 | 
         
            +
                    act_layer = get_activation_layer(act_type)
         
     | 
| 687 | 
         
            +
                    self.mlp = MLP(
         
     | 
| 688 | 
         
            +
                        in_channels=hidden_size,
         
     | 
| 689 | 
         
            +
                        hidden_channels=mlp_hidden_dim,
         
     | 
| 690 | 
         
            +
                        act_layer=act_layer,
         
     | 
| 691 | 
         
            +
                        drop=mlp_drop_rate,
         
     | 
| 692 | 
         
            +
                    )
         
     | 
| 693 | 
         
            +
             
     | 
| 694 | 
         
            +
                    self.adaLN_modulation = nn.Sequential(
         
     | 
| 695 | 
         
            +
                        act_layer(),
         
     | 
| 696 | 
         
            +
                        nn.Linear(hidden_size, 2 * hidden_size, bias=True, ),
         
     | 
| 697 | 
         
            +
                    )
         
     | 
| 698 | 
         
            +
                    # Zero-initialize the modulation
         
     | 
| 699 | 
         
            +
                    nn.init.zeros_(self.adaLN_modulation[1].weight)
         
     | 
| 700 | 
         
            +
                    nn.init.zeros_(self.adaLN_modulation[1].bias)
         
     | 
| 701 | 
         
            +
             
     | 
| 702 | 
         
            +
                def forward(
         
     | 
| 703 | 
         
            +
                    self,
         
     | 
| 704 | 
         
            +
                    x: torch.Tensor,
         
     | 
| 705 | 
         
            +
                    c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
         
     | 
| 706 | 
         
            +
                    attn_mask: torch.Tensor = None,
         
     | 
| 707 | 
         
            +
                ):
         
     | 
| 708 | 
         
            +
                    gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
         
     | 
| 709 | 
         
            +
             
     | 
| 710 | 
         
            +
                    norm_x = self.norm1(x)
         
     | 
| 711 | 
         
            +
                    qkv = self.self_attn_qkv(norm_x)
         
     | 
| 712 | 
         
            +
                    q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
         
     | 
| 713 | 
         
            +
                    # Apply QK-Norm if needed
         
     | 
| 714 | 
         
            +
                    q = self.self_attn_q_norm(q).to(v)
         
     | 
| 715 | 
         
            +
                    k = self.self_attn_k_norm(k).to(v)
         
     | 
| 716 | 
         
            +
             
     | 
| 717 | 
         
            +
                    # Self-Attention
         
     | 
| 718 | 
         
            +
                    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
         
     | 
| 719 | 
         
            +
                    attn = attention(q, k, v, attn_mask=attn_mask)
         
     | 
| 720 | 
         
            +
                    x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
         
     | 
| 721 | 
         
            +
             
     | 
| 722 | 
         
            +
                    # FFN Layer
         
     | 
| 723 | 
         
            +
                    x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
         
     | 
| 724 | 
         
            +
             
     | 
| 725 | 
         
            +
                    return x
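            # Note (added for clarity): adaLN_modulation is zero-initialized above, so gate_msa and
            # gate_mlp start at zero and apply_gate() multiplies both residual branches by zero.
            # Each refiner block therefore begins training as an identity mapping, with the gates learned.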
         
     | 
| 726 | 
         
            +
             
     | 
| 727 | 
         
            +
             
     | 
| 728 | 
         
            +
            class CrossTokenRefinerBlock(nn.Module):
         
     | 
| 729 | 
         
            +
                def __init__(
         
     | 
| 730 | 
         
            +
                    self,
         
     | 
| 731 | 
         
            +
                    hidden_size,
         
     | 
| 732 | 
         
            +
                    heads_num,
         
     | 
| 733 | 
         
            +
                    mlp_width_ratio: float = 4.0,
         
     | 
| 734 | 
         
            +
                    mlp_drop_rate: float = 0.0,
         
     | 
| 735 | 
         
            +
                    act_type: str = "silu",
         
     | 
| 736 | 
         
            +
                    qk_norm: bool = False,
         
     | 
| 737 | 
         
            +
                    qk_norm_type: str = "layer",
         
     | 
| 738 | 
         
            +
                    qkv_bias: bool = True,
         
     | 
| 739 | 
         
            +
                ):
         
     | 
| 740 | 
         
            +
                    super().__init__()
         
     | 
| 741 | 
         
            +
                    self.heads_num = heads_num
         
     | 
| 742 | 
         
            +
                    head_dim = hidden_size // heads_num
         
     | 
| 743 | 
         
            +
                    mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
         
     | 
| 744 | 
         
            +
             
     | 
| 745 | 
         
            +
                    self.norm1 = nn.LayerNorm(
         
     | 
| 746 | 
         
            +
                        hidden_size, elementwise_affine=True, eps=1e-6, 
         
     | 
| 747 | 
         
            +
                    )
         
     | 
| 748 | 
         
            +
                    self.self_attn_q = nn.Linear(
         
     | 
| 749 | 
         
            +
                        hidden_size, hidden_size, bias=qkv_bias, 
         
     | 
| 750 | 
         
            +
                    )
         
     | 
| 751 | 
         
            +
                    self.norm_y = nn.LayerNorm(
         
     | 
| 752 | 
         
            +
                        hidden_size, elementwise_affine=True, eps=1e-6, 
         
     | 
| 753 | 
         
            +
                    )
         
     | 
| 754 | 
         
            +
                    self.self_attn_kv = nn.Linear(
         
     | 
| 755 | 
         
            +
                        hidden_size, hidden_size*2, bias=qkv_bias, 
         
     | 
| 756 | 
         
            +
                    )
         
     | 
| 757 | 
         
            +
                    qk_norm_layer = get_norm_layer(qk_norm_type)
         
     | 
| 758 | 
         
            +
                    self.self_attn_q_norm = (
         
     | 
| 759 | 
         
            +
                        qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, )
         
     | 
| 760 | 
         
            +
                        if qk_norm
         
     | 
| 761 | 
         
            +
                        else nn.Identity()
         
     | 
| 762 | 
         
            +
                    )
         
     | 
| 763 | 
         
            +
                    self.self_attn_k_norm = (
         
     | 
| 764 | 
         
            +
                        qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, )
         
     | 
| 765 | 
         
            +
                        if qk_norm
         
     | 
| 766 | 
         
            +
                        else nn.Identity()
         
     | 
| 767 | 
         
            +
                    )
         
     | 
| 768 | 
         
            +
                    self.self_attn_proj = nn.Linear(
         
     | 
| 769 | 
         
            +
                        hidden_size, hidden_size, bias=qkv_bias, 
         
     | 
| 770 | 
         
            +
                    )
         
     | 
| 771 | 
         
            +
             
     | 
| 772 | 
         
            +
                    self.norm2 = nn.LayerNorm(
         
     | 
| 773 | 
         
            +
                        hidden_size, elementwise_affine=True, eps=1e-6, 
         
     | 
| 774 | 
         
            +
                    )
         
     | 
| 775 | 
         
            +
                    act_layer = get_activation_layer(act_type)
         
     | 
| 776 | 
         
            +
                    self.mlp = MLP(
         
     | 
| 777 | 
         
            +
                        in_channels=hidden_size,
         
     | 
| 778 | 
         
            +
                        hidden_channels=mlp_hidden_dim,
         
     | 
| 779 | 
         
            +
                        act_layer=act_layer,
         
     | 
| 780 | 
         
            +
                        drop=mlp_drop_rate,
         
     | 
| 781 | 
         
            +
                    )
         
     | 
| 782 | 
         
            +
             
     | 
| 783 | 
         
            +
                    self.adaLN_modulation = nn.Sequential(
         
     | 
| 784 | 
         
            +
                        act_layer(),
         
     | 
| 785 | 
         
            +
                        nn.Linear(hidden_size, 2 * hidden_size, bias=True, ),
         
     | 
| 786 | 
         
            +
                    )
         
     | 
| 787 | 
         
            +
                    # Zero-initialize the modulation
         
     | 
| 788 | 
         
            +
                    nn.init.zeros_(self.adaLN_modulation[1].weight)
         
     | 
| 789 | 
         
            +
                    nn.init.zeros_(self.adaLN_modulation[1].bias)
         
     | 
| 790 | 
         
            +
             
     | 
| 791 | 
         
            +
                def forward(
         
     | 
| 792 | 
         
            +
                    self,
         
     | 
| 793 | 
         
            +
                    x: torch.Tensor,
         
     | 
| 794 | 
         
            +
                    y: torch.Tensor,
         
     | 
| 795 | 
         
            +
                    c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
         
     | 
| 796 | 
         
            +
                    attn_mask: torch.Tensor = None,
         
     | 
| 797 | 
         
            +
                ):
         
     | 
| 798 | 
         
            +
                    gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
         
     | 
| 799 | 
         
            +
             
     | 
| 800 | 
         
            +
                    norm_x = self.norm1(x)
         
     | 
| 801 | 
         
            +
                    q = self.self_attn_q(norm_x)
         
     | 
| 802 | 
         
            +
                    q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
         
     | 
| 803 | 
         
            +
                    norm_y = self.norm_y(y)
         
     | 
| 804 | 
         
            +
                    kv = self.self_attn_kv(norm_y)
         
     | 
| 805 | 
         
            +
                    k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
         
     | 
| 806 | 
         
            +
                    # Apply QK-Norm if needed
         
     | 
| 807 | 
         
            +
                    q = self.self_attn_q_norm(q).to(v)
         
     | 
| 808 | 
         
            +
                    k = self.self_attn_k_norm(k).to(v)
         
     | 
| 809 | 
         
            +
             
     | 
| 810 | 
         
            +
                    # Cross-attention between x (queries) and y (keys/values); transpose to the
                    # [B, H, L, D] layout used by the attention helper, mirroring IndividualTokenRefinerBlock
                    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
         
     | 
| 811 | 
         
            +
                    attn = attention(q, k, v, attn_mask=attn_mask)
         
     | 
| 812 | 
         
            +
                    x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
         
     | 
| 813 | 
         
            +
             
     | 
| 814 | 
         
            +
                    # FFN Layer
         
     | 
| 815 | 
         
            +
                    x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
         
     | 
| 816 | 
         
            +
             
     | 
| 817 | 
         
            +
                    return x
         
     | 
| 818 | 
         
            +
             
     | 
| 819 | 
         
            +
            class IndividualTokenRefiner(nn.Module):
         
     | 
| 820 | 
         
            +
                def __init__(
         
     | 
| 821 | 
         
            +
                    self,
         
     | 
| 822 | 
         
            +
                    hidden_size,
         
     | 
| 823 | 
         
            +
                    heads_num,
         
     | 
| 824 | 
         
            +
                    depth,
         
     | 
| 825 | 
         
            +
                    mlp_width_ratio: float = 4.0,
         
     | 
| 826 | 
         
            +
                    mlp_drop_rate: float = 0.0,
         
     | 
| 827 | 
         
            +
                    act_type: str = "silu",
         
     | 
| 828 | 
         
            +
                    qk_norm: bool = False,
         
     | 
| 829 | 
         
            +
                    qk_norm_type: str = "layer",
         
     | 
| 830 | 
         
            +
                    qkv_bias: bool = True,
         
     | 
| 831 | 
         
            +
                ):
         
     | 
| 832 | 
         
            +
                    super().__init__()
         
     | 
| 833 | 
         
            +
                    self.blocks = nn.ModuleList(
         
     | 
| 834 | 
         
            +
                        [
         
     | 
| 835 | 
         
            +
                            IndividualTokenRefinerBlock(
         
     | 
| 836 | 
         
            +
                                hidden_size=hidden_size,
         
     | 
| 837 | 
         
            +
                                heads_num=heads_num,
         
     | 
| 838 | 
         
            +
                                mlp_width_ratio=mlp_width_ratio,
         
     | 
| 839 | 
         
            +
                                mlp_drop_rate=mlp_drop_rate,
         
     | 
| 840 | 
         
            +
                                act_type=act_type,
         
     | 
| 841 | 
         
            +
                                qk_norm=qk_norm,
         
     | 
| 842 | 
         
            +
                                qk_norm_type=qk_norm_type,
         
     | 
| 843 | 
         
            +
                                qkv_bias=qkv_bias,
         
     | 
| 844 | 
         
            +
                            )
         
     | 
| 845 | 
         
            +
                            for _ in range(depth)
         
     | 
| 846 | 
         
            +
                        ]
         
     | 
| 847 | 
         
            +
                    )
         
     | 
| 848 | 
         
            +
             
     | 
| 849 | 
         
            +
                def forward(
         
     | 
| 850 | 
         
            +
                    self,
         
     | 
| 851 | 
         
            +
                    x: torch.Tensor,
         
     | 
| 852 | 
         
            +
                    c: torch.LongTensor,
         
     | 
| 853 | 
         
            +
                    mask: Optional[torch.Tensor] = None,
         
     | 
| 854 | 
         
            +
                ):
         
     | 
| 855 | 
         
            +
                    self_attn_mask = None
         
     | 
| 856 | 
         
            +
                    if mask is not None:
         
     | 
| 857 | 
         
            +
                        batch_size = mask.shape[0]
         
     | 
| 858 | 
         
            +
                        seq_len = mask.shape[1]
         
     | 
| 859 | 
         
            +
                        mask = mask.to(x.device)
         
     | 
| 860 | 
         
            +
                        # batch_size x 1 x seq_len x seq_len
         
     | 
| 861 | 
         
            +
                        self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
         
     | 
| 862 | 
         
            +
                            1, 1, seq_len, 1
         
     | 
| 863 | 
         
            +
                        )
         
     | 
| 864 | 
         
            +
                        # batch_size x 1 x seq_len x seq_len
         
     | 
| 865 | 
         
            +
                        self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
         
     | 
| 866 | 
         
            +
                        # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
         
     | 
| 867 | 
         
            +
                        self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
         
     | 
| 868 | 
         
            +
                        # avoids self-attention weight being NaN for padding tokens
         
     | 
| 869 | 
         
            +
                        self_attn_mask[:, :, :, 0] = True
         
     | 
| 870 | 
         
            +
             
     | 
| 871 | 
         
            +
                    for block in self.blocks:
         
     | 
| 872 | 
         
            +
                        x = block(x, c, self_attn_mask)
         
     | 
| 873 | 
         
            +
                    return x
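            # Illustrative sketch (not part of the original file): a [B, S] padding mask (1 = token,
            # 0 = padding) becomes a boolean [B, 1, S, S] mask that is True only where both the query
            # and key positions are real tokens; column 0 is forced True so fully padded rows stay finite.
            #
            #   mask = torch.tensor([[1, 1, 0]])
            #   m = mask.view(1, 1, 1, 3).repeat(1, 1, 3, 1)
            #   attn_mask = (m & m.transpose(2, 3)).bool()
            #   attn_mask[:, :, :, 0] = True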
         
     | 
| 874 | 
         
            +
             
     | 
| 875 | 
         
            +
             
     | 
| 876 | 
         
            +
            class SingleTokenRefiner(nn.Module):
         
     | 
| 877 | 
         
            +
                """
         
     | 
| 878 | 
         
            +
                A single token refiner block for llm text embedding refine.
         
     | 
| 879 | 
         
            +
                """
         
     | 
| 880 | 
         
            +
                def __init__(
         
     | 
| 881 | 
         
            +
                    self,
         
     | 
| 882 | 
         
            +
                    in_channels,
         
     | 
| 883 | 
         
            +
                    hidden_size,
         
     | 
| 884 | 
         
            +
                    heads_num,
         
     | 
| 885 | 
         
            +
                    depth,
         
     | 
| 886 | 
         
            +
                    mlp_width_ratio: float = 4.0,
         
     | 
| 887 | 
         
            +
                    mlp_drop_rate: float = 0.0,
         
     | 
| 888 | 
         
            +
                    act_type: str = "silu",
         
     | 
| 889 | 
         
            +
                    qk_norm: bool = False,
         
     | 
| 890 | 
         
            +
                    qk_norm_type: str = "layer",
         
     | 
| 891 | 
         
            +
                    qkv_bias: bool = True,
         
     | 
| 892 | 
         
            +
                    attn_mode: str = "torch",
         
     | 
| 893 | 
         
            +
                    enable_cls_token: bool = False,
         
     | 
| 894 | 
         
            +
                    enable_cross_attn: bool = False,
         
     | 
| 895 | 
         
            +
                    length: int = 29,
         
     | 
| 896 | 
         
            +
                ):
         
     | 
| 897 | 
         
            +
                    super().__init__()
         
     | 
| 898 | 
         
            +
                    self.attn_mode = attn_mode
         
     | 
| 899 | 
         
            +
                    assert self.attn_mode == "torch", "Only the 'torch' attention mode is supported for the token refiner."
         
     | 
| 900 | 
         
            +
                    self.in_channels = in_channels
         
     | 
| 901 | 
         
            +
                    self.enable_cross_attn = enable_cross_attn
         
     | 
| 902 | 
         
            +
                    if self.enable_cross_attn:
         
     | 
| 903 | 
         
            +
                        self.length = length
         
     | 
| 904 | 
         
            +
                        self.input_embedder = nn.Linear(
         
     | 
| 905 | 
         
            +
                            in_channels//length, hidden_size, bias=True, 
         
     | 
| 906 | 
         
            +
                        )
         
     | 
| 907 | 
         
            +
                        self.kv_embedder = nn.Linear(
         
     | 
| 908 | 
         
            +
                            in_channels//length*(length-1), hidden_size, bias=True, 
         
     | 
| 909 | 
         
            +
                        )
         
     | 
| 910 | 
         
            +
                        self.fusion = CrossTokenRefinerBlock(
         
     | 
| 911 | 
         
            +
                                hidden_size=hidden_size,
         
     | 
| 912 | 
         
            +
                                heads_num=heads_num,
         
     | 
| 913 | 
         
            +
                                mlp_width_ratio=mlp_width_ratio,
         
     | 
| 914 | 
         
            +
                                mlp_drop_rate=mlp_drop_rate,
         
     | 
| 915 | 
         
            +
                                act_type=act_type,
         
     | 
| 916 | 
         
            +
                                qk_norm=qk_norm,
         
     | 
| 917 | 
         
            +
                                qk_norm_type=qk_norm_type,
         
     | 
| 918 | 
         
            +
                                qkv_bias=qkv_bias,
         
     | 
| 919 | 
         
            +
                            )
         
     | 
| 920 | 
         
            +
                    else:
         
     | 
| 921 | 
         
            +
                        self.input_embedder = nn.Linear(
         
     | 
| 922 | 
         
            +
                            in_channels, hidden_size, bias=True, 
         
     | 
| 923 | 
         
            +
                        )
         
     | 
| 924 | 
         
            +
             
     | 
| 925 | 
         
            +
                    act_layer = get_activation_layer(act_type)
         
     | 
| 926 | 
         
            +
                    # Build timestep embedding layer
         
     | 
| 927 | 
         
            +
                    # self.t_embedder = TimestepEmbedder(hidden_size, act_layer,)
         
     | 
| 928 | 
         
            +
                    # Build context embedding layer
         
     | 
| 929 | 
         
            +
                    self.c_embedder = TextProjection(
         
     | 
| 930 | 
         
            +
                        in_channels, hidden_size, act_layer, 
         
     | 
| 931 | 
         
            +
                    )
         
     | 
| 932 | 
         
            +
             
     | 
| 933 | 
         
            +
                    self.individual_token_refiner = IndividualTokenRefiner(
         
     | 
| 934 | 
         
            +
                        hidden_size=hidden_size,
         
     | 
| 935 | 
         
            +
                        heads_num=heads_num,
         
     | 
| 936 | 
         
            +
                        depth=depth,
         
     | 
| 937 | 
         
            +
                        mlp_width_ratio=mlp_width_ratio,
         
     | 
| 938 | 
         
            +
                        mlp_drop_rate=mlp_drop_rate,
         
     | 
| 939 | 
         
            +
                        act_type=act_type,
         
     | 
| 940 | 
         
            +
                        qk_norm=qk_norm,
         
     | 
| 941 | 
         
            +
                        qk_norm_type=qk_norm_type,
         
     | 
| 942 | 
         
            +
                        qkv_bias=qkv_bias,
         
     | 
| 943 | 
         
            +
                    )
         
     | 
| 944 | 
         
            +
             
     | 
| 945 | 
         
            +
                    self.enable_cls_token = enable_cls_token
         
     | 
| 946 | 
         
            +
                    if self.enable_cls_token:
         
     | 
| 947 | 
         
            +
                        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
         
     | 
| 948 | 
         
            +
                        nn.init.normal_(self.cls_token, std=1e-6)
         
     | 
| 949 | 
         
            +
             
     | 
| 950 | 
         
            +
                def forward(
         
     | 
| 951 | 
         
            +
                    self,
         
     | 
| 952 | 
         
            +
                    x: torch.Tensor,
         
     | 
| 953 | 
         
            +
                    mask: Optional[torch.LongTensor] = None,
         
     | 
| 954 | 
         
            +
                ):
         
     | 
| 955 | 
         
            +
                    if mask is None:
         
     | 
| 956 | 
         
            +
                        context_aware_representations = x.mean(dim=1)
         
     | 
| 957 | 
         
            +
                    else:
         
     | 
| 958 | 
         
            +
                        mask_float = mask.float().unsqueeze(-1)  # [b, s1, 1]
         
     | 
| 959 | 
         
            +
                        context_aware_representations = (x * mask_float).sum(
         
     | 
| 960 | 
         
            +
                            dim=1
         
     | 
| 961 | 
         
            +
                        ) / mask_float.sum(dim=1)
         
     | 
| 962 | 
         
            +
                    c = self.c_embedder(context_aware_representations)
         
     | 
| 963 | 
         
            +
                    if self.enable_cross_attn:
         
     | 
| 964 | 
         
            +
                        single_channels = self.in_channels // self.length
         
     | 
| 965 | 
         
            +
                        x, y = torch.split(x, [single_channels, single_channels*(self.length-1)], dim=-1)
         
     | 
| 966 | 
         
            +
                        x = self.input_embedder(x)
         
     | 
| 967 | 
         
            +
                        y = self.kv_embedder(y)
         
     | 
| 968 | 
         
            +
                    else:
         
     | 
| 969 | 
         
            +
                        x = self.input_embedder(x)
         
     | 
| 970 | 
         
            +
                    if self.enable_cls_token:
         
     | 
| 971 | 
         
            +
                        B, L, C = x.shape
         
     | 
| 972 | 
         
            +
                        x = torch.cat([self.cls_token.expand(B, -1, -1), x], dim=1)
         
     | 
| 973 | 
         
            +
                    
         
     | 
| 974 | 
         
            +
                    if self.enable_cross_attn:
         
     | 
| 975 | 
         
            +
                        x = self.fusion(x, y, c)
         
     | 
| 976 | 
         
            +
                    x = self.individual_token_refiner(x, c, mask)
         
     | 
| 977 | 
         
            +
                    if self.enable_cls_token:
         
     | 
| 978 | 
         
            +
                        x_global = x[:, 0]
         
     | 
| 979 | 
         
            +
                        x = x[:, 1:]
         
     | 
| 980 | 
         
            +
                    else:
         
     | 
| 981 | 
         
            +
                        x_global = x.mean(dim=1)
         
     | 
| 982 | 
         
            +
                    return dict(
         
     | 
| 983 | 
         
            +
                        txt_fea=x,
         
     | 
| 984 | 
         
            +
                        txt_fea_avg=x_global
         
     | 
| 985 | 
         
            +
                    )
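            # Illustrative usage sketch (not part of the original file; sizes are arbitrary):
            #
            #   refiner = SingleTokenRefiner(in_channels=2048, hidden_size=1024, heads_num=16,
            #                                depth=2, enable_cls_token=True)
            #   out = refiner(torch.randn(2, 77, 2048))
            #   out["txt_fea"]      # refined per-token features, [2, 77, 1024]
            #   out["txt_fea_avg"]  # global feature taken from the CLS token, [2, 1024]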
         
     | 
| 986 | 
         
            +
             
     | 
| 987 | 
         
            +
             
     | 
| 988 | 
         
            +
             
     | 
| 989 | 
         
            +
             
     | 
| 990 | 
         
            +
             
     | 
| 991 | 
         
            +
             
     | 
| 992 | 
         
            +
             
     | 
| 993 | 
         
            +
             
     | 
| 994 | 
         
            +
             
     | 
| 995 | 
         
            +
             
     | 
| 996 | 
         
            +
             
     | 
| 997 | 
         
            +
             
     | 
| 998 | 
         
            +
             
     | 
| 999 | 
         
            +
             
     | 
| 1000 | 
         
            +
             
     | 
| 1001 | 
         
            +
             
     | 
| 1002 | 
         
            +
            __all__ = ["YakModel"]
         
     | 
| 1003 | 
         
            +
             
     | 
| 1004 | 
         
            +
            @dataclass
         
     | 
| 1005 | 
         
            +
            class VisualGeneratorOutput(ModelOutput):
         
     | 
| 1006 | 
         
            +
                loss: Optional[torch.FloatTensor] = None
         
     | 
| 1007 | 
         
            +
             
     | 
| 1008 | 
         
            +
             
     | 
| 1009 | 
         
            +
            class YakTransformer(nn.Module):
         
     | 
| 1010 | 
         
            +
                def __init__(self, config: YakConfig):
         
     | 
| 1011 | 
         
            +
                    super().__init__()
         
     | 
| 1012 | 
         
            +
        self.config = config
        self.in_channels = config.in_channels
        self.out_channels = config.out_channels
        if config.hidden_size % config.num_heads != 0:
            raise ValueError(
                f"Hidden size {config.hidden_size} must be divisible by num_heads {config.num_heads}"
            )
        pe_dim = config.hidden_size // config.num_heads
        if sum(config.axes_dim) != pe_dim:
            raise ValueError(f"Got {config.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=config.theta, axes_dim=config.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(config.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if config.guidance_embed else nn.Identity()
        )
        self.txt_type = config.txt_type
        self.txt_in = SingleTokenRefiner(
            config.context_in_dim,
            self.hidden_size,
            heads_num=config.num_heads * 2,
            depth=2,
            enable_cls_token=True
        )

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamXBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=config.mlp_ratio,
                    qkv_bias=config.qkv_bias,
                )
                for _ in range(config.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=config.mlp_ratio)
                for _ in range(config.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
        self.gradient_checkpointing = False

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        guidance: Tensor | None = None,
        cond_img: Tensor = None,
        cond_img_ids: Tensor = None,
    ):
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img_tokens = img.shape[1]
        if cond_img is not None:
            img = torch.cat([img, cond_img], dim=1)
            img_ids = torch.cat([img_ids, cond_img_ids], dim=1)
        img = self.img_in(img)

        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.config.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        txt_dict = self.txt_in(txt)
        txt = txt_dict["txt_fea"]
        y = txt_dict["txt_fea_avg"]
        vec = vec + self.vector_in(y)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            if self.training and self.gradient_checkpointing:
                img, txt = self._gradient_checkpointing_func(
                    block.__call__,
                    img,
                    txt,
                    vec,
                    pe,
                )
            else:
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            if self.training and self.gradient_checkpointing:
                img = self._gradient_checkpointing_func(
                    block.__call__,
                    img,
                    vec,
                    pe,
                )
            else:
                img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        if cond_img is not None:
            img = torch.split(img, img_tokens, dim=1)[0]
        return img

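# The two helpers below implement the resolution-dependent timestep shift used by
# YakModel.get_schedule(): get_lin_function() linearly maps the image token count to a
# shift value mu, and time_shift() warps a uniform schedule t in (0, 1] via
# exp(mu) / (exp(mu) + (1 / t - 1) ** sigma), spending more steps at high noise for
# longer token sequences.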
def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_noise(
    num_samples: int,
    channel: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    return torch.randn(
        num_samples,
        channel,
        # allow for packing
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=device,
        dtype=dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )


def unpack(x: Tensor, height: int, width: int) -> Tensor:
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )

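# YakModel below wraps two sub-modules behind the standard PreTrainedModel interface:
# an AutoencoderKL VAE built from config.vae_config and the YakTransformer backbone.
# Images are encoded to latents, packed into 2x2 patch tokens (compute_vae_encodings
# packs, unpack() above reverses it), and denoised under classifier-free guidance.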
class YakPretrainedModel(PreTrainedModel):
    config_class = YakConfig
    base_model_prefix = "yak"
    supports_gradient_checkpointing = True
    main_input_name = "pixel_values"
    _supports_sdpa = True


class YakModel(YakPretrainedModel):
    def __init__(self, config: YakConfig):
        super().__init__(config)
        self.vae = AutoencoderKL.from_config(config.vae_config)
        self.backbone = YakTransformer(config)

    def get_refiner(self):
        return self.backbone.txt_in

    def get_cls_refiner(self):
        return self.backbone.vector_in

    def get_backbone(self):
        return self.backbone

    def get_vae(self):
        return self.vae

    def preprocess_image(self, image: Image.Image, size, convert_to_rgb=True, Norm=True, output_type="tensor"):
        image = exif_transpose(image)
        if not image.mode == "RGB" and convert_to_rgb:
            image = image.convert("RGB")

        image = torchvision.transforms.functional.resize(
            image, size, interpolation=transforms.InterpolationMode.BICUBIC
        )

        arr = np.array(image)
        h = arr.shape[0]
        w = arr.shape[1]
        crop_y = (h - size) // 2
        crop_x = (w - size) // 2
        pil_image = image.crop([crop_x, crop_y, crop_x + size, crop_y + size])
        if output_type == "pil_image":
            return pil_image

        image_np = arr[crop_y : crop_y + size, crop_x : crop_x + size]
        hidden_h = h // 16
        hidden_w = w // 16
        hidden_size = size // 16
        img_ids = torch.zeros(hidden_h, hidden_w, 3)

        img_ids[..., 1] = img_ids[..., 1] + torch.arange(hidden_h)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.arange(hidden_w)[None, :]
        crop_y = (hidden_h - hidden_size) // 2
        crop_x = (hidden_w - hidden_size) // 2
        img_ids = img_ids[crop_y : crop_y + hidden_size, crop_x : crop_x + hidden_size]
        img_ids = rearrange(img_ids, "h w c -> (h w) c")

        image_tensor = torchvision.transforms.functional.to_tensor(image_np)
        if Norm:
            image_tensor = torchvision.transforms.functional.normalize(
                image_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
            )
        return pil_image, image_tensor, img_ids

    def process_image_aspectratio(self, image, size):
        w, h = image.size
        t_w, t_h = size
        resize_r = max(float(t_w) / w, float(t_h) / h)
        resize_size = (int(resize_r * h), int(resize_r * w))
        image = torchvision.transforms.functional.resize(
            image, resize_size, interpolation=transforms.InterpolationMode.BICUBIC
        )
        pil_image = torchvision.transforms.functional.center_crop(
            image, (t_h, t_w)
        )
        hidden_h = t_h // 16
        hidden_w = t_w // 16
        img_ids = torch.zeros(hidden_h, hidden_w, 3)

        img_ids[..., 1] = img_ids[..., 1] + torch.arange(hidden_h)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.arange(hidden_w)[None, :]
        img_ids = rearrange(img_ids, "h w c -> (h w) c")
        image_tensor = torchvision.transforms.functional.to_tensor(pil_image)
        image_tensor = torchvision.transforms.functional.normalize(
            image_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
        )
        return pil_image, image_tensor, img_ids

    def compute_vae_encodings(self, pixel_values, with_ids=True, time=0):
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        pixel_values = pixel_values.to(self.vae.device, dtype=self.vae.dtype)
        with torch.no_grad():
            model_input = self.vae.encode(pixel_values).latent_dist.sample()
            if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor is not None:
                model_input = model_input - self.vae.config.shift_factor
            if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor is not None:
                model_input = model_input * self.vae.config.scaling_factor
        # patch for transformer
        bs, c, h, w = model_input.shape
        model_input = rearrange(model_input, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        if with_ids:
            img_ids = torch.zeros(h // 2, w // 2, 3)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
            img_ids[..., 0] = time
            img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
            return model_input, img_ids
        else:
            return model_input

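    # generate_image() runs the full sampling pipeline: draw packed Gaussian noise with
    # get_noise(), build 3-axis position ids for the token grid, compute the (optionally
    # shifted) timestep schedule, run denoise() for text-to-image or edit_denoise() when a
    # VAE-encoded conditioning image is present, then unpack the tokens and decode with the VAE.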
    def generate_image(
        self,
        cond,
        height,
        width,
        num_steps,
        seed,
        no_both_cond=None,
        no_txt_cond=None,
        img_cfg=1.0,
        txt_cfg=1.0,
        output_type="pil"
    ):
        txt = cond["txt"]
        bs = len(txt)
        channel = self.vae.config.latent_channels
        height = 16 * (height // 16)
        width = 16 * (width // 16)
        torch_device = next(self.backbone.parameters()).device
        x = get_noise(
            bs,
            channel,
            height,
            width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=seed,
        )
        # prepare inputs
        img = x
        bs, c, h, w = img.shape

        img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        if img.shape[0] == 1 and bs > 1:
            img = repeat(img, "1 ... -> bs ...", bs=bs)

        img_ids = torch.zeros(h // 2, w // 2, 3)
        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(img.device)

        if "vae_pixel_values" in cond:
            img_vae_cond, cond_ids = self.compute_vae_encodings(
                pixel_values=cond["vae_pixel_values"], with_ids=True, time=1.0)
            cond_ids = cond_ids.to(img.device)

        if txt.shape[0] == 1 and bs > 1:
            txt = repeat(txt, "1 ... -> bs ...", bs=bs)
        txt_ids = torch.zeros(bs, txt.shape[1], 3).to(img.device)

        timesteps = self.get_schedule(
            num_steps, img.shape[1], shift=self.config.timestep_shift,
            base_shift=self.config.base_shift, max_shift=self.config.max_shift)
        no_both_txt = no_both_cond["txt"]
        if no_txt_cond is not None:
            no_txt_txt = no_txt_cond["txt"]
            x = self.edit_denoise(img, img_ids,
                                  txt, txt_ids,
                                  no_txt_txt,
                                  no_both_txt,
                                  img_vae_cond, cond_ids.to(img.device),
                                  timesteps=timesteps,
                                  img_cfg=img_cfg, txt_cfg=txt_cfg)
        else:
            x = self.denoise(img, img_ids, txt, txt_ids,
                             timesteps=timesteps, cfg=txt_cfg,
                             neg_txt=no_both_txt)
        x = unpack(x.float(), height, width)

        with torch.autocast(device_type=torch_device.type, dtype=torch.float32):
            if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor is not None:
                x = x / self.vae.config.scaling_factor
            if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor is not None:
                x = x + self.vae.config.shift_factor
            x = self.vae.decode(x, return_dict=False)[0]
        # bring into PIL format and save
        x = x.clamp(-1, 1)
        x = rearrange(x, "b c h w -> b h w c")
        x = (127.5 * (x + 1.0)).cpu().byte().numpy()
        if output_type == "np":
            return x
        images = []
        for i in range(bs):
            img = Image.fromarray(x[i])
            images.append(img)
        return images


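    # get_schedule() starts from a uniform schedule on [1, 0] and, when shift is enabled,
    # warps it with time_shift(). mu comes from a linear fit between (256 tokens, base_shift)
    # and (4096 tokens, max_shift); e.g. with the defaults base_shift=0.5 and max_shift=1.15,
    # a 1024-token image gives mu ≈ 0.5 + (1024 - 256) * (1.15 - 0.5) / (4096 - 256) ≈ 0.63.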
    def get_schedule(self,
        num_steps: int,
        image_seq_len: int,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        shift: bool = True,
    ) -> list[float]:
        # extra step for zero
        timesteps = torch.linspace(1, 0, num_steps + 1)
        # shifting the schedule to favor high timesteps for higher signal images
        if shift:
            # estimate mu based on linear interpolation between two points
            mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
            timesteps = time_shift(mu, 1.0, timesteps)

        return timesteps.tolist()

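    # denoise() integrates the flow with a plain Euler update: at each step the backbone is
    # evaluated with the positive prompt and with neg_txt, the two predictions are combined
    # as uncond + cfg * (cond - uncond), and the latent advances by (t_prev - t_curr) * pred.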
    def denoise(self,
                input_img: Tensor,
                img_ids: Tensor,
                txt: Tensor,
                txt_ids: Tensor,
                # sampling parameters
                timesteps: list[float],
                cfg: float = 1.0,
                neg_txt=None):
        bs = input_img.shape[0]
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full((bs,), t_curr, dtype=input_img.dtype, device=input_img.device)
            txt_ids = torch.zeros(bs, txt.shape[1], 3).to(txt.device)
            cond_eps = self.backbone(
                img=input_img,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                timesteps=t_vec,
            )
            txt_ids = torch.zeros(bs, neg_txt.shape[1], 3).to(neg_txt.device)
            uncond_eps = self.backbone(
                img=input_img,
                img_ids=img_ids,
                txt=neg_txt,
                txt_ids=txt_ids,
                timesteps=t_vec,
            )
            pred = uncond_eps + cfg * (cond_eps - uncond_eps)
            input_img = input_img + (t_prev - t_curr) * pred
        return input_img

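    # edit_denoise() applies two nested levels of classifier-free guidance for image editing:
    # starting from the fully unconditional prediction, img_cfg scales the contribution of the
    # conditioning image (no_txt_eps - no_both_eps) and txt_cfg scales the contribution of the
    # text prompt on top of it (cond_eps - no_txt_eps), before the same Euler update as denoise().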
    def edit_denoise(self,
                input_img: Tensor,
                img_ids: Tensor,
                txt: Tensor,
                txt_ids: Tensor,
                no_txt_txt: Tensor,
                no_both_txt: Tensor,
                img_cond,
                cond_img_ids,
                # sampling parameters
                timesteps: list[float],
                img_cfg: float = 1.0,
                txt_cfg: float = 1.0):
        bs = input_img.shape[0]
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full((bs,), t_curr, dtype=input_img.dtype, device=input_img.device)
            txt_ids = torch.zeros(bs, txt.shape[1], 3).to(txt.device)
            cond_eps = self.backbone(
                img=input_img,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                timesteps=t_vec,
                cond_img=img_cond,
                cond_img_ids=cond_img_ids,
            )
            txt_ids = torch.zeros(bs, no_both_txt.shape[1], 3).to(no_both_txt.device)
            no_both_eps = self.backbone(
                img=input_img,
                img_ids=img_ids,
                txt=no_both_txt,
                txt_ids=txt_ids,
                timesteps=t_vec,
            )
            txt_ids = torch.zeros(bs, no_txt_txt.shape[1], 3).to(no_txt_txt.device)
            no_txt_eps = self.backbone(
                img=input_img,
                img_ids=img_ids,
                txt=no_txt_txt,
                txt_ids=txt_ids,
                timesteps=t_vec,
                cond_img=img_cond,
                cond_img_ids=cond_img_ids,
            )
            pred = no_both_eps
            pred += img_cfg * (no_txt_eps - no_both_eps)
            pred += txt_cfg * (cond_eps - no_txt_eps)
            input_img = input_img + (t_prev - t_curr) * pred
        return input_img

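A minimal sketch of how the module-level packing helpers above fit together; the 16 latent channels and the 1024x1024 size are illustrative assumptions rather than values read from this checkpoint, and the import assumes modeling_yak.py is on the Python path:

    import torch
    from einops import rearrange
    from modeling_yak import get_noise, unpack

    height, width = 1024, 1024
    # get_noise() already allocates the packed latent grid: (1, 16, 2*ceil(1024/16), 2*ceil(1024/16))
    x = get_noise(num_samples=1, channel=16, height=height, width=width,
                  device=torch.device("cpu"), dtype=torch.float32, seed=0)
    assert x.shape == (1, 16, 128, 128)
    # pack 2x2 latent patches into tokens, exactly as generate_image() does
    img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    assert img.shape == (1, 4096, 64)
    # unpack() reverses the packing before VAE decoding
    assert unpack(img, height, width).shape == (1, 16, 128, 128)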
    	
        preprocessor_config.json
ADDED
@@ -0,0 +1,32 @@
{
  "crop_size": {
    "height": -1,
    "width": -1
  },
  "do_center_crop": false,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "hidden_stride": 2,
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "max_pixels": 2408448,
  "min_pixels": 200704,
  "patch_size": 14,
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": -1
  },
  "temporal_patch_size": 1
}
    	
        special_tokens_map.json
ADDED
@@ -0,0 +1,31 @@
{
  "additional_special_tokens": [
    "<|im_start|>",
    "<|im_end|>",
    "<|object_ref_start|>",
    "<|object_ref_end|>",
    "<|box_start|>",
    "<|box_end|>",
    "<|quad_start|>",
    "<|quad_end|>",
    "<|vision_start|>",
    "<|vision_end|>",
    "<|vision_pad|>",
    "<|image_pad|>",
    "<|video_pad|>"
  ],
  "eos_token": {
    "content": "<|im_end|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
    	
        tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
size 11422654
    	
        tokenizer_config.json
ADDED
@@ -0,0 +1,240 @@
{
  "add_bos_token": false,
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "151643": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151644": {
      "content": "<|im_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151645": {
      "content": "<|im_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151646": {
      "content": "<|object_ref_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151647": {
      "content": "<|object_ref_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151648": {
      "content": "<|box_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151649": {
      "content": "<|box_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151650": {
      "content": "<|quad_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151651": {
      "content": "<|quad_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151652": {
      "content": "<|vision_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151653": {
      "content": "<|vision_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
         
            +
                  "single_word": false,
         
     | 
| 91 | 
         
            +
                  "special": true
         
     | 
| 92 | 
         
            +
                },
         
     | 
| 93 | 
         
            +
                "151654": {
         
     | 
| 94 | 
         
            +
                  "content": "<|vision_pad|>",
         
     | 
| 95 | 
         
            +
                  "lstrip": false,
         
     | 
| 96 | 
         
            +
                  "normalized": false,
         
     | 
| 97 | 
         
            +
                  "rstrip": false,
         
     | 
| 98 | 
         
            +
                  "single_word": false,
         
     | 
| 99 | 
         
            +
                  "special": true
         
     | 
| 100 | 
         
            +
                },
         
     | 
| 101 | 
         
            +
                "151655": {
         
     | 
| 102 | 
         
            +
                  "content": "<|image_pad|>",
         
     | 
| 103 | 
         
            +
                  "lstrip": false,
         
     | 
| 104 | 
         
            +
                  "normalized": false,
         
     | 
| 105 | 
         
            +
                  "rstrip": false,
         
     | 
| 106 | 
         
            +
                  "single_word": false,
         
     | 
| 107 | 
         
            +
                  "special": true
         
     | 
| 108 | 
         
            +
                },
         
     | 
| 109 | 
         
            +
                "151656": {
         
     | 
| 110 | 
         
            +
                  "content": "<|video_pad|>",
         
     | 
| 111 | 
         
            +
                  "lstrip": false,
         
     | 
| 112 | 
         
            +
                  "normalized": false,
         
     | 
| 113 | 
         
            +
                  "rstrip": false,
         
     | 
| 114 | 
         
            +
                  "single_word": false,
         
     | 
| 115 | 
         
            +
                  "special": true
         
     | 
| 116 | 
         
            +
                },
         
     | 
| 117 | 
         
            +
                "151657": {
         
     | 
| 118 | 
         
            +
                  "content": "<tool_call>",
         
     | 
| 119 | 
         
            +
                  "lstrip": false,
         
     | 
| 120 | 
         
            +
                  "normalized": false,
         
     | 
| 121 | 
         
            +
                  "rstrip": false,
         
     | 
| 122 | 
         
            +
                  "single_word": false,
         
     | 
| 123 | 
         
            +
                  "special": false
         
     | 
| 124 | 
         
            +
                },
         
     | 
| 125 | 
         
            +
                "151658": {
         
     | 
| 126 | 
         
            +
                  "content": "</tool_call>",
         
     | 
| 127 | 
         
            +
                  "lstrip": false,
         
     | 
| 128 | 
         
            +
                  "normalized": false,
         
     | 
| 129 | 
         
            +
                  "rstrip": false,
         
     | 
| 130 | 
         
            +
                  "single_word": false,
         
     | 
| 131 | 
         
            +
                  "special": false
         
     | 
| 132 | 
         
            +
                },
         
     | 
| 133 | 
         
            +
                "151659": {
         
     | 
| 134 | 
         
            +
                  "content": "<|fim_prefix|>",
         
     | 
| 135 | 
         
            +
                  "lstrip": false,
         
     | 
| 136 | 
         
            +
                  "normalized": false,
         
     | 
| 137 | 
         
            +
                  "rstrip": false,
         
     | 
| 138 | 
         
            +
                  "single_word": false,
         
     | 
| 139 | 
         
            +
                  "special": false
         
     | 
| 140 | 
         
            +
                },
         
     | 
| 141 | 
         
            +
                "151660": {
         
     | 
| 142 | 
         
            +
                  "content": "<|fim_middle|>",
         
     | 
| 143 | 
         
            +
                  "lstrip": false,
         
     | 
| 144 | 
         
            +
                  "normalized": false,
         
     | 
| 145 | 
         
            +
                  "rstrip": false,
         
     | 
| 146 | 
         
            +
                  "single_word": false,
         
     | 
| 147 | 
         
            +
                  "special": false
         
     | 
| 148 | 
         
            +
                },
         
     | 
| 149 | 
         
            +
                "151661": {
         
     | 
| 150 | 
         
            +
                  "content": "<|fim_suffix|>",
         
     | 
| 151 | 
         
            +
                  "lstrip": false,
         
     | 
| 152 | 
         
            +
                  "normalized": false,
         
     | 
| 153 | 
         
            +
                  "rstrip": false,
         
     | 
| 154 | 
         
            +
                  "single_word": false,
         
     | 
| 155 | 
         
            +
                  "special": false
         
     | 
| 156 | 
         
            +
                },
         
     | 
| 157 | 
         
            +
                "151662": {
         
     | 
| 158 | 
         
            +
                  "content": "<|fim_pad|>",
         
     | 
| 159 | 
         
            +
                  "lstrip": false,
         
     | 
| 160 | 
         
            +
                  "normalized": false,
         
     | 
| 161 | 
         
            +
                  "rstrip": false,
         
     | 
| 162 | 
         
            +
                  "single_word": false,
         
     | 
| 163 | 
         
            +
                  "special": false
         
     | 
| 164 | 
         
            +
                },
         
     | 
| 165 | 
         
            +
                "151663": {
         
     | 
| 166 | 
         
            +
                  "content": "<|repo_name|>",
         
     | 
| 167 | 
         
            +
                  "lstrip": false,
         
     | 
| 168 | 
         
            +
                  "normalized": false,
         
     | 
| 169 | 
         
            +
                  "rstrip": false,
         
     | 
| 170 | 
         
            +
                  "single_word": false,
         
     | 
| 171 | 
         
            +
                  "special": false
         
     | 
| 172 | 
         
            +
                },
         
     | 
| 173 | 
         
            +
                "151664": {
         
     | 
| 174 | 
         
            +
                  "content": "<|file_sep|>",
         
     | 
| 175 | 
         
            +
                  "lstrip": false,
         
     | 
| 176 | 
         
            +
                  "normalized": false,
         
     | 
| 177 | 
         
            +
                  "rstrip": false,
         
     | 
| 178 | 
         
            +
                  "single_word": false,
         
     | 
| 179 | 
         
            +
                  "special": false
         
     | 
| 180 | 
         
            +
                },
         
     | 
| 181 | 
         
            +
                "151665": {
         
     | 
| 182 | 
         
            +
                  "content": "<tool_response>",
         
     | 
| 183 | 
         
            +
                  "lstrip": false,
         
     | 
| 184 | 
         
            +
                  "normalized": false,
         
     | 
| 185 | 
         
            +
                  "rstrip": false,
         
     | 
| 186 | 
         
            +
                  "single_word": false,
         
     | 
| 187 | 
         
            +
                  "special": false
         
     | 
| 188 | 
         
            +
                },
         
     | 
| 189 | 
         
            +
                "151666": {
         
     | 
| 190 | 
         
            +
                  "content": "</tool_response>",
         
     | 
| 191 | 
         
            +
                  "lstrip": false,
         
     | 
| 192 | 
         
            +
                  "normalized": false,
         
     | 
| 193 | 
         
            +
                  "rstrip": false,
         
     | 
| 194 | 
         
            +
                  "single_word": false,
         
     | 
| 195 | 
         
            +
                  "special": false
         
     | 
| 196 | 
         
            +
                },
         
     | 
| 197 | 
         
            +
                "151667": {
         
     | 
| 198 | 
         
            +
                  "content": "<think>",
         
     | 
| 199 | 
         
            +
                  "lstrip": false,
         
     | 
| 200 | 
         
            +
                  "normalized": false,
         
     | 
| 201 | 
         
            +
                  "rstrip": false,
         
     | 
| 202 | 
         
            +
                  "single_word": false,
         
     | 
| 203 | 
         
            +
                  "special": false
         
     | 
| 204 | 
         
            +
                },
         
     | 
| 205 | 
         
            +
                "151668": {
         
     | 
| 206 | 
         
            +
                  "content": "</think>",
         
     | 
| 207 | 
         
            +
                  "lstrip": false,
         
     | 
| 208 | 
         
            +
                  "normalized": false,
         
     | 
| 209 | 
         
            +
                  "rstrip": false,
         
     | 
| 210 | 
         
            +
                  "single_word": false,
         
     | 
| 211 | 
         
            +
                  "special": false
         
     | 
| 212 | 
         
            +
                }
         
     | 
| 213 | 
         
            +
              },
         
     | 
| 214 | 
         
            +
              "additional_special_tokens": [
         
     | 
| 215 | 
         
            +
                "<|im_start|>",
         
     | 
| 216 | 
         
            +
                "<|im_end|>",
         
     | 
| 217 | 
         
            +
                "<|object_ref_start|>",
         
     | 
| 218 | 
         
            +
                "<|object_ref_end|>",
         
     | 
| 219 | 
         
            +
                "<|box_start|>",
         
     | 
| 220 | 
         
            +
                "<|box_end|>",
         
     | 
| 221 | 
         
            +
                "<|quad_start|>",
         
     | 
| 222 | 
         
            +
                "<|quad_end|>",
         
     | 
| 223 | 
         
            +
                "<|vision_start|>",
         
     | 
| 224 | 
         
            +
                "<|vision_end|>",
         
     | 
| 225 | 
         
            +
                "<|vision_pad|>",
         
     | 
| 226 | 
         
            +
                "<|image_pad|>",
         
     | 
| 227 | 
         
            +
                "<|video_pad|>"
         
     | 
| 228 | 
         
            +
              ],
         
     | 
| 229 | 
         
            +
              "bos_token": null,
         
     | 
| 230 | 
         
            +
              "chat_template": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set content = message.content %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in message.content %}\n                {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n                {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if loop.last or (not loop.last and reasoning_content) %}\n                {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\n</tool_call>' }}\n   
         {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is defined and enable_thinking is false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}",
         
     | 
| 231 | 
         
            +
              "clean_up_tokenization_spaces": false,
         
     | 
| 232 | 
         
            +
              "eos_token": "<|im_end|>",
         
     | 
| 233 | 
         
            +
              "errors": "replace",
         
     | 
| 234 | 
         
            +
              "extra_special_tokens": {},
         
     | 
| 235 | 
         
            +
              "model_max_length": 131072,
         
     | 
| 236 | 
         
            +
              "pad_token": "<|endoftext|>",
         
     | 
| 237 | 
         
            +
              "split_special_tokens": false,
         
     | 
| 238 | 
         
            +
              "tokenizer_class": "Qwen2Tokenizer",
         
     | 
| 239 | 
         
            +
              "unk_token": null
         
     | 
| 240 | 
         
            +
            }
         
     | 
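For reference, a minimal usage sketch of the tokenizer settings above: it loads the tokenizer and renders a conversation through the chat_template, including the `enable_thinking` variable the template checks before emitting an empty `<think>...</think>` block. The repository id below is a placeholder (substitute the actual repo id or a local checkout of these files), and `trust_remote_code` is assumed only because this upload also ships custom modeling/configuration code.

```python
# Sketch only: load the tokenizer described by tokenizer_config.json above
# and render a chat prompt with its template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "path/or/repo-id",       # placeholder, not confirmed by this diff
    trust_remote_code=True,  # assumed, since the repo includes custom code files
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Describe this image."},
]

# Extra keyword arguments to apply_chat_template are passed to the Jinja
# template, so enable_thinking=False triggers the empty <think> block above.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)
print(prompt)

# eos/pad come straight from this config: <|im_end|> and <|endoftext|>.
print(tokenizer.eos_token, tokenizer.pad_token)
```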
    	
        vocab.json
    ADDED

        The diff for this file is too large to render. See raw diff.