Upload folder using huggingface_hub
- README.md +72 -0
- added_tokens.json +27 -0
- chat_template.jinja +54 -0
- config.json +35 -0
- configuration_dream.py +86 -0
- generation_config.json +16 -0
- generation_utils.py +463 -0
- merges.txt +0 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +346 -0
- modeling_dream.py +824 -0
- special_tokens_map.json +35 -0
- tokenization_dream.py +340 -0
- tokenizer_config.json +228 -0
- train_results.json +8 -0
- trainer_state.json +0 -0
- training_args.bin +3 -0
- vocab.json +0 -0
    	
        README.md
    ADDED
    
@@ -0,0 +1,72 @@
---
license: unknown
base_model:
- apple/DiffuCoder-7B-Instruct
tags:
- code
- text-diffusion-model
- diffusion large language model
---

### DiffuCoder-7B-cpGRPO

The DiffuCoder-7B-cpGRPO variant further refines DiffuCoder-Instruct with reinforcement learning via Coupled-GRPO.

Training recipe:

- Initialized from DiffuCoder-7B-Instruct, post-trained with coupled-GRPO on 21K code samples (1 epoch).
- Coupled-GRPO significantly improves DiffuCoder's performance on code generation benchmarks (+4.4% on EvalPlus) and reduces reliance on AR bias during decoding.


#### More details and usage examples:

- Paper: [DiffuCoder: Understanding and Improving Masked Diffusion Models for Code Generation](https://arxiv.org/abs/2506.20639)

- GitHub: https://github.com/apple/ml-diffucoder

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = "apple/DiffuCoder-7B-cpGRPO"
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to("cuda").eval()

query = "Write a function to find the shared elements from the given two lists."
prompt = f"""<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
{query.strip()}
<|im_end|>
<|im_start|>assistant
"""  # follows the Qwen chat template; you can also use the apply_chat_template function

TOKEN_PER_STEP = 1  # diffusion timesteps * TOKEN_PER_STEP = total new tokens

inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids.to(device="cuda")
attention_mask = inputs.attention_mask.to(device="cuda")

output = model.diffusion_generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=256,
    output_history=True,
    return_dict_in_generate=True,
    steps=256 // TOKEN_PER_STEP,
    temperature=0.4,
    top_p=0.95,
    alg="entropy",
    alg_temp=0.,
)
generations = [
    tokenizer.decode(g[len(p):].tolist())
    for p, g in zip(input_ids, output.sequences)
]

print(generations[0].split('<|dlm_pad|>')[0])
```

#### Acknowledgement
To power this Hugging Face model release, we reuse [Dream](https://huggingface.co/Dream-org/Dream-v0-Base-7B)'s modeling architecture and generation utils.
    	
        added_tokens.json
    ADDED
    
@@ -0,0 +1,27 @@
{
  "</tool_call>": 151658,
  "<tool_call>": 151657,
  "<|beginoftext|>": 151665,
  "<|box_end|>": 151649,
  "<|box_start|>": 151648,
  "<|dlm_pad|>": 151667,
  "<|endoftext|>": 151643,
  "<|file_sep|>": 151664,
  "<|fim_middle|>": 151660,
  "<|fim_pad|>": 151662,
  "<|fim_prefix|>": 151659,
  "<|fim_suffix|>": 151661,
  "<|im_end|>": 151645,
  "<|im_start|>": 151644,
  "<|image_pad|>": 151655,
  "<|mask|>": 151666,
  "<|object_ref_end|>": 151647,
  "<|object_ref_start|>": 151646,
  "<|quad_end|>": 151651,
  "<|quad_start|>": 151650,
  "<|repo_name|>": 151663,
  "<|video_pad|>": 151656,
  "<|vision_end|>": 151653,
  "<|vision_pad|>": 151654,
  "<|vision_start|>": 151652
}
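The entries above register the extra special tokens (diffusion mask, dlm padding, FIM and tool-call markers) on top of the base vocabulary. A minimal sketch of checking the mapping from the loaded tokenizer; the model id and `trust_remote_code=True` are assumed from the README example:

```python
from transformers import AutoTokenizer

# Sketch: these ids should match the JSON above.
tokenizer = AutoTokenizer.from_pretrained("apple/DiffuCoder-7B-cpGRPO", trust_remote_code=True)
print(tokenizer.convert_tokens_to_ids("<|mask|>"))     # 151666, the diffusion mask token
print(tokenizer.convert_tokens_to_ids("<|dlm_pad|>"))  # 151667, used to trim generations in the README
```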
    	
        chat_template.jinja
    ADDED
    
@@ -0,0 +1,54 @@
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0]['role'] == 'system' %}
        {{- messages[0]['content'] }}
    {%- else %}
        {{- 'You are a helpful assistant.' }}
    {%- endif %}
    {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0]['role'] == 'system' %}
        {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
    {%- else %}
        {{- '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- for message in messages %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {{- '<|im_start|>' + message.role }}
        {%- if message.content %}
            {{- '\n' + message.content }}
        {%- endif %}
        {%- for tool_call in message.tool_calls %}
            {%- if tool_call.function is defined %}
                {%- set tool_call = tool_call.function %}
            {%- endif %}
            {{- '\n<tool_call>\n{"name": "' }}
            {{- tool_call.name }}
            {{- '", "arguments": ' }}
            {{- tool_call.arguments | tojson }}
            {{- '}\n</tool_call>' }}
        {%- endfor %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- message.content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- endif %}
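This is the Qwen-style ChatML template the README refers to; instead of hand-writing the prompt string, it can be rendered with `apply_chat_template`. A minimal sketch, assuming the same model id as in the README:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("apple/DiffuCoder-7B-cpGRPO", trust_remote_code=True)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a function to find the shared elements from the given two lists."},
]
# add_generation_prompt=True appends the trailing '<|im_start|>assistant\n' block from the template.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
```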
    	
        config.json
    ADDED
    
@@ -0,0 +1,35 @@
{
  "architectures": [
    "DreamModel"
  ],
  "attention_dropout": 0.0,
  "auto_map": {
    "AutoConfig": "configuration_dream.DreamConfig",
    "AutoModel": "modeling_dream.DreamModel"
  },
  "bos_token_id": 151643,
  "eos_token_id": 151643,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "mask_token_id": 151666,
  "max_position_embeddings": 131072,
  "max_window_layers": 28,
  "model_type": "Dream",
  "num_attention_heads": 28,
  "num_hidden_layers": 28,
  "num_key_value_heads": 4,
  "pad_token_id": 151643,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.52.0.dev0",
  "use_cache": true,
  "use_mrope": false,
  "use_sliding_window": false,
  "vocab_size": 152064
}
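The `auto_map` block is what lets `AutoConfig`/`AutoModel` resolve the custom Dream classes shipped in this repo when `trust_remote_code=True` is passed. A minimal sketch (model id assumed from the README):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("apple/DiffuCoder-7B-cpGRPO", trust_remote_code=True)
print(type(config).__name__)  # DreamConfig, resolved via auto_map -> configuration_dream.DreamConfig
print(config.hidden_size, config.num_key_value_heads, config.mask_token_id)  # 3584 4 151666
```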
    	
        configuration_dream.py
    ADDED
    
@@ -0,0 +1,86 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dream model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging


logger = logging.get_logger(__name__)


class DreamConfig(PretrainedConfig):
    model_type = "Dream"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=False,  # cache not used in diffusion
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        mask_token_id=151666,
        pad_token_id=151643,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
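A small sketch of constructing `DreamConfig` directly, assuming `configuration_dream.py` is importable from the working directory; it illustrates why `config.json` stores `"sliding_window": null` even though the constructor default is 4096:

```python
from configuration_dream import DreamConfig  # assumes this file is on the Python path

cfg = DreamConfig(hidden_size=3584, num_hidden_layers=28, num_key_value_heads=4)
# use_sliding_window defaults to False, so sliding_window is stored as None.
print(cfg.sliding_window)  # None
print(cfg.rope_theta)      # 10000.0 here; config.json overrides this to 1000000.0
```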
    	
        generation_config.json
    ADDED
    
@@ -0,0 +1,16 @@
{
  "_from_model_config": true,
  "alg": "origin",
  "alg_temp": null,
  "bos_token_id": 151643,
  "eos_token_id": 151643,
  "eps": 0.001,
  "mask_token_id": null,
  "output_history": false,
  "pad_token_id": 151643,
  "steps": 512,
  "temperature": 0.0,
  "top_k": null,
  "top_p": null,
  "transformers_version": "4.52.0.dev0"
}
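These are the default diffusion decoding settings; `diffusion_generate` overrides them per call (the README passes `steps`, `temperature`, `top_p`, `alg`, `alg_temp`). A hypothetical sketch using `DreamGenerationConfig` from `generation_utils.py`, assuming this repo's files are importable:

```python
from generation_utils import DreamGenerationConfig  # assumes this repo's files are on the Python path

# Start from the defaults above and override a few fields, as the README's call does.
gen_cfg = DreamGenerationConfig(steps=256, temperature=0.4, top_p=0.95, alg="entropy", alg_temp=0.0)
print(gen_cfg.steps, gen_cfg.alg, gen_cfg.eps)  # 256 entropy 0.001
```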
    	
        generation_utils.py
    ADDED
    
@@ -0,0 +1,463 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import (
    GenerationConfig
)
from transformers.utils import (
    ModelOutput,
    is_torchdynamo_compiling,
    logging,
)

logger = logging.get_logger(__name__)


def top_p_logits(logits, top_p=None):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits

def top_k_logits(logits, top_k=None):
    top_k = min(top_k, logits.size(-1))  # Safety check
    # Remove all tokens with a probability less than the last token of the top-k
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits


def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):

    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except:
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Extract top1 and top2 probabilities
        top1_probs = sorted_probs[:, 0]
        top2_probs = sorted_probs[:, 1]
        # Calculate confidence as top1 - top2
        confidence = top1_probs - top2_probs

    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0


@dataclass
class DreamModelOutput(ModelOutput):
    sequences: torch.LongTensor = None
    history: Optional[Tuple[torch.FloatTensor]] = None


class DreamGenerationConfig(GenerationConfig):
    def __init__(self, **kwargs):
        self.temperature: float = kwargs.pop("temperature", 0.0)
        self.top_p: Optional[float] = kwargs.pop("top_p", None)
        self.top_k: Optional[int] = kwargs.pop("top_k", None)
        self.max_length = kwargs.pop("max_length", 20)
        self.max_new_tokens = kwargs.pop("max_new_tokens", None)
        # diffusion specific params
        self.eps: float = kwargs.pop("eps", 1e-3)
        self.steps: int = kwargs.pop("steps", 512)
        self.alg: str = kwargs.pop("alg", 'origin')
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)

        # Parameters that define the output variables of `generate`
        self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
        self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
        self.output_history: bool = kwargs.pop("output_history", False)

        # Special tokens that can be used at generation time
        self.mask_token_id = kwargs.pop("mask_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)

        # Wild card
        self.generation_kwargs = kwargs.pop("generation_kwargs", {})

        # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
        # interface.
        self._from_model_config = kwargs.pop("_from_model_config", False)
        self._commit_hash = kwargs.pop("_commit_hash", None)
        self.transformers_version = kwargs.pop("transformers_version", __version__)

        # Additional attributes without default values
        if not self._from_model_config:
            # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
            # model's default configuration file
            for key, value in kwargs.items():
                try:
                    setattr(self, key, value)
                except AttributeError as err:
                    logger.error(f"Can't set {key} with value {value} for {self}")
                    raise err

        # Validate the values of the attributes
        self.validate(is_init=True)

    def validate(self, is_init=False, strict=True):
        pass

class DreamGenerationMixin:
    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
        # Do not call torch.repeat_interleave if expand_size is 1 because it clones
        # the input tensor and thus requires more memory although no change is applied
        if expand_size == 1:
            return input_ids, attention_mask
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
        return input_ids, attention_mask

    def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
        """Performs validation related to the resulting generated length"""

        # Can't throw warnings/exceptions during compilation
        if is_torchdynamo_compiling():
            return

        # 1. Max length warnings related to poor parameterization
        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
            # 20 is the default max_length of the generation config
            warnings.warn(
                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
                "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
                "generation.",
                UserWarning,
            )
        if input_ids_length >= generation_config.max_length:
            input_ids_string = "input_ids"
            raise ValueError(
                f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_length` or, better yet, setting `max_new_tokens`."
            )

    def _prepare_generated_length(
        self,
        generation_config,
        has_default_max_length,
        input_ids_length,
    ):
        """Prepared max and min length in generation configs to avoid clashes between similar attributes"""

        if generation_config.max_new_tokens is not None:
            if not has_default_max_length and generation_config.max_length is not None:
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )
            generation_config.max_length = generation_config.max_new_tokens + input_ids_length

        elif has_default_max_length:
            if generation_config.max_length == DreamGenerationConfig().max_length:
                generation_config.max_length = generation_config.max_length + input_ids_length
                max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
                if max_position_embeddings is not None:
                    generation_config.max_length = min(generation_config.max_length, max_position_embeddings)

        return generation_config

    def _prepare_generation_config(
        self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
    ) -> DreamGenerationConfig:
        """
        Prepares the base generation config, then applies any generation configuration options from kwargs. This
        function handles retrocompatibility with respect to configuration files.
        """
        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
        using_model_generation_config = False
        if generation_config is None:
            generation_config = DreamGenerationConfig.from_model_config(self.config)
            using_model_generation_config = True

        # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
        # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
        # exception will be raised in `_validate_model_kwargs`
        if not is_torchdynamo_compiling():
            generation_config = copy.deepcopy(generation_config)
            _kwargs = generation_config.update(**kwargs)
            # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
            if not using_model_generation_config:
                if generation_config.bos_token_id is None:
                    generation_config.bos_token_id = self.generation_config.bos_token_id
                if generation_config.eos_token_id is None:
                    generation_config.eos_token_id = self.generation_config.eos_token_id
                if generation_config.pad_token_id is None:
                    generation_config.pad_token_id = self.generation_config.pad_token_id
                if generation_config.mask_token_id is None:
                    generation_config.mask_token_id = self.generation_config.mask_token_id

        return generation_config

    def _prepare_special_tokens(
        self,
        generation_config: DreamGenerationConfig,
        device: Optional[Union[torch.device, str]] = None,
    ):
        """
        Prepares the special tokens for generation, overwriting the generation config with their processed versions
        converted to tensor.
        Note that `generation_config` is changed in place and stops being serializable after this method is called.
        That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
        function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
        """

        # Convert special tokens to tensors
        def _tensor_or_none(token, device=None):
            if token is None:
                return token

            device = device if device is not None else self.device
            if isinstance(token, torch.Tensor):
                return token.to(device)
            return torch.tensor(token, device=device, dtype=torch.long)

        bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
        eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
        pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
        mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)

        # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
        if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
            eos_token_tensor = eos_token_tensor.unsqueeze(0)

        # Set pad token if unset (and there are conditions to do so)
        if pad_token_tensor is None and eos_token_tensor is not None:
            pad_token_tensor = eos_token_tensor[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

        # Update generation config with the updated special tokens tensors
        # NOTE: this must be written into a different attribute name than the one holding the original special tokens
        # (in their non-tensor form), in order to enable end-to-end compilation. See
        # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
        generation_config._bos_token_tensor = bos_token_tensor
        generation_config._eos_token_tensor = eos_token_tensor
        generation_config._pad_token_tensor = pad_token_tensor
        generation_config._mask_token_tensor = mask_token_tensor

    @torch.no_grad()
    def diffusion_generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[DreamGenerationConfig] = None,
        **kwargs,
    ) -> Union[DreamModelOutput, torch.LongTensor]:
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        generation_config = self._prepare_generation_config(generation_config, **kwargs)
        generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
        generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)

        # 2. Define model inputs
        assert inputs is not None
        input_ids = inputs
        device = input_ids.device
        attention_mask = kwargs.pop("attention_mask", None)
        self._prepare_special_tokens(generation_config, device=device)

        # 3. Prepare `max_length`.
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            input_ids_length=input_ids_length,
        )

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        # 4. Check input_ids
        if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with the `input_ids` being on a device type different"
                f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
                f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
                " Please make sure that you have put `input_ids` to the
         | 
| 333 | 
            +
                            f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
         | 
| 334 | 
            +
                            " running `.generate()`.",
         | 
| 335 | 
            +
                            UserWarning,
         | 
| 336 | 
            +
                        )
         | 
| 337 | 
            +
                    if (
         | 
| 338 | 
            +
                        hasattr(generation_config, "pad_token_id") and
         | 
| 339 | 
            +
                        torch.any(input_ids == generation_config.pad_token_id) and 
         | 
| 340 | 
            +
                        attention_mask is None
         | 
| 341 | 
            +
                    ):
         | 
| 342 | 
            +
                        warnings.warn(
         | 
| 343 | 
            +
                            "Padding was detected but no attention mask is passed here. For correct "
         | 
| 344 | 
            +
                            "generation results, please set `attention_mask` when batch-padding inputs.",
         | 
| 345 | 
            +
                            UserWarning,
         | 
| 346 | 
            +
                        )
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                    input_ids, attention_mask = self._expand_inputs_for_generation(
         | 
| 349 | 
            +
                        expand_size=generation_config.num_return_sequences,
         | 
| 350 | 
            +
                        input_ids=input_ids,
         | 
| 351 | 
            +
                        attention_mask=attention_mask 
         | 
| 352 | 
            +
                    )
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                    result = self._sample(
         | 
| 355 | 
            +
                        input_ids,
         | 
| 356 | 
            +
                        attention_mask=attention_mask,
         | 
| 357 | 
            +
                        generation_config=generation_config,
         | 
| 358 | 
            +
                        generation_tokens_hook_func=generation_tokens_hook_func,
         | 
| 359 | 
            +
                        generation_logits_hook_func=generation_logits_hook_func
         | 
| 360 | 
            +
                    )
         | 
| 361 | 
            +
                    return result
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                def _sample(
         | 
| 364 | 
            +
                    self,
         | 
| 365 | 
            +
                    input_ids: torch.LongTensor,
         | 
| 366 | 
            +
                    attention_mask: Optional[torch.LongTensor],
         | 
| 367 | 
            +
                    generation_config: DreamGenerationConfig,
         | 
| 368 | 
            +
                    generation_tokens_hook_func,
         | 
| 369 | 
            +
                    generation_logits_hook_func
         | 
| 370 | 
            +
                ) -> Union[DreamModelOutput, torch.LongTensor]:
         | 
| 371 | 
            +
                    # init values
         | 
| 372 | 
            +
                    output_history = generation_config.output_history
         | 
| 373 | 
            +
                    return_dict_in_generate = generation_config.return_dict_in_generate
         | 
| 374 | 
            +
                    max_length = generation_config.max_length
         | 
| 375 | 
            +
                    mask_token_id = generation_config.mask_token_id
         | 
| 376 | 
            +
                    steps = generation_config.steps
         | 
| 377 | 
            +
                    eps = 1e-12
         | 
| 378 | 
            +
                    alg = generation_config.alg
         | 
| 379 | 
            +
                    alg_temp = generation_config.alg_temp
         | 
| 380 | 
            +
                    temperature = generation_config.temperature
         | 
| 381 | 
            +
                    top_p = generation_config.top_p
         | 
| 382 | 
            +
                    top_k = generation_config.top_k
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                    histories = [] if (return_dict_in_generate and output_history) else None
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                    # pad input_ids to max_length
         | 
| 387 | 
            +
                    x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                    if attention_mask is not None and torch.any(attention_mask == 0.0):
         | 
| 390 | 
            +
                        # we do not mask the [MASK] tokens so value = 1.0
         | 
| 391 | 
            +
                        attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
         | 
| 392 | 
            +
                        tok_idx = attention_mask.long().cumsum(-1) - 1
         | 
| 393 | 
            +
                        tok_idx.masked_fill_(attention_mask == 0, 1)
         | 
| 394 | 
            +
                        # attention_mask is of shape [B, N]
         | 
| 395 | 
            +
                        # broadcast to [B, 1, N, N]
         | 
| 396 | 
            +
                        attention_mask = torch.logical_and(
         | 
| 397 | 
            +
                            attention_mask.unsqueeze(1).unsqueeze(-2),
         | 
| 398 | 
            +
                            attention_mask.unsqueeze(1).unsqueeze(-1),
         | 
| 399 | 
            +
                        )
         | 
| 400 | 
            +
                    else:
         | 
| 401 | 
            +
                        tok_idx = None
         | 
| 402 | 
            +
                        attention_mask = "full"
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                    timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
         | 
| 405 | 
            +
             | 
| 406 | 
            +
                    # this allows user-defined token control of the intermediate steps
         | 
| 407 | 
            +
                    x = generation_tokens_hook_func(None, x, None)
         | 
| 408 | 
            +
                    for i in range(steps):
         | 
| 409 | 
            +
                        mask_index = (x == mask_token_id)
         | 
| 410 | 
            +
                        logits = self(x, attention_mask, tok_idx).logits
         | 
| 411 | 
            +
                        logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                        # this allows user-defined logits control of the intermediate steps
         | 
| 414 | 
            +
                        logits = generation_logits_hook_func(i, x, logits)
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                        mask_logits = logits[mask_index]
         | 
| 417 | 
            +
                        t = timesteps[i]
         | 
| 418 | 
            +
                        s = timesteps[i + 1]
         | 
| 419 | 
            +
                    
         | 
| 420 | 
            +
                        if alg == 'origin':
         | 
| 421 | 
            +
                            p_transfer = 1 - s / t if i < steps - 1 else 1
         | 
| 422 | 
            +
                            x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
         | 
| 423 | 
            +
                            transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
         | 
| 424 | 
            +
                            _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
         | 
| 425 | 
            +
                            x[mask_index] = x0.clone()
         | 
| 426 | 
            +
                        else:
         | 
| 427 | 
            +
                            if alg == 'maskgit_plus':
         | 
| 428 | 
            +
                                confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
         | 
| 429 | 
            +
                            elif alg == 'topk_margin':
         | 
| 430 | 
            +
                                confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
         | 
| 431 | 
            +
                            elif alg == 'entropy':
         | 
| 432 | 
            +
                                confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
         | 
| 433 | 
            +
                            else:
         | 
| 434 | 
            +
                                raise RuntimeError(f"Unknown alg: {alg}")
         | 
| 435 | 
            +
                            num_mask_token = mask_index.sum() / mask_index.shape[0]
         | 
| 436 | 
            +
                            number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
         | 
| 437 | 
            +
                            full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
         | 
| 438 | 
            +
                            full_confidence[mask_index] = confidence
         | 
| 439 | 
            +
                            if number_transfer_tokens > 0:
         | 
| 440 | 
            +
                                if alg_temp is None or alg_temp == 0:
         | 
| 441 | 
            +
                                    _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
         | 
| 442 | 
            +
                                else:
         | 
| 443 | 
            +
                                    full_confidence = full_confidence / alg_temp
         | 
| 444 | 
            +
                                    full_confidence = F.softmax(full_confidence, dim=-1)
         | 
| 445 | 
            +
                                    transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
         | 
| 446 | 
            +
                                x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
         | 
| 447 | 
            +
                                x_[mask_index] = x0.clone()
         | 
| 448 | 
            +
                                row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
         | 
| 449 | 
            +
                                x[row_indices,transfer_index] = x_[row_indices,transfer_index]
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                        # this allows user-defined token control of the intermediate steps
         | 
| 452 | 
            +
                        x = generation_tokens_hook_func(i, x, logits)
         | 
| 453 | 
            +
             | 
| 454 | 
            +
                        if histories is not None:
         | 
| 455 | 
            +
                            histories.append(x.clone())
         | 
| 456 | 
            +
                    
         | 
| 457 | 
            +
                    if return_dict_in_generate:
         | 
| 458 | 
            +
                        return DreamModelOutput(
         | 
| 459 | 
            +
                            sequences=x,
         | 
| 460 | 
            +
                            history=histories,
         | 
| 461 | 
            +
                        )
         | 
| 462 | 
            +
                    else:
         | 
| 463 | 
            +
                        return x
         | 
    	
        merges.txt
    ADDED
    
    The diff for this file is too large to render. See raw diff.
    	
        model-00001-of-00004.safetensors
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fabcd922793cfd77755050075fdd1ff493b19e20877c7efa378205f116584a0b
size 4877660776
    	
        model-00002-of-00004.safetensors
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2c4b343fd8d668a40aee7aa824346075a02d1bdb81f354501fc37c88bff59786
size 4932751008
    	
        model-00003-of-00004.safetensors
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:96117035fa26be70e5ef6be645a2e8070390f59857f8d2e3d14407125c983cfd
size 4330865200
    	
        model-00004-of-00004.safetensors
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7473eb9f41ca6c371b9e7784497311ea2eb82591021d40419129d75a1b61a6f1
size 1089994880
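
The shard files above are plain git-lfs pointers; the index file that follows maps every parameter name to the shard that stores it. As a minimal sketch (assuming the `safetensors` package is installed and the shards have been downloaded locally), the weight map can be used to locate and read a single tensor:

```
import json
from safetensors import safe_open

# Hypothetical helper: look up which shard holds a parameter, then read it.
with open("model.safetensors.index.json") as f:
    weight_map = json.load(f)["weight_map"]

name = "lm_head.weight"  # stored in model-00004-of-00004.safetensors per the map below
with safe_open(weight_map[name], framework="pt", device="cpu") as shard:
    tensor = shard.get_tensor(name)
print(name, tuple(tensor.shape))
```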
    	
        model.safetensors.index.json
    ADDED
    
@@ -0,0 +1,346 @@
{
  "metadata": {
    "total_size": 15231233024
  },
  "weight_map": {
    "lm_head.weight": "model-00004-of-00004.safetensors",
    "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.18.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
    "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.19.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.19.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.19.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.20.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.20.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.20.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.21.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.21.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.21.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
    "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
    "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
         | 
| 318 | 
            +
                "model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
         | 
| 319 | 
            +
                "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
         | 
| 320 | 
            +
                "model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
         | 
| 321 | 
            +
                "model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
         | 
| 322 | 
            +
                "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
         | 
| 323 | 
            +
                "model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
         | 
| 324 | 
            +
                "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
         | 
| 325 | 
            +
                "model.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
         | 
| 326 | 
            +
                "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
         | 
| 327 | 
            +
                "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
         | 
| 328 | 
            +
                "model.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
         | 
| 329 | 
            +
                "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
         | 
| 330 | 
            +
                "model.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
         | 
| 331 | 
            +
                "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
         | 
| 332 | 
            +
                "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
         | 
| 333 | 
            +
                "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
         | 
| 334 | 
            +
                "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
         | 
| 335 | 
            +
                "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
         | 
| 336 | 
            +
                "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
         | 
| 337 | 
            +
                "model.layers.9.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
         | 
| 338 | 
            +
                "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
         | 
| 339 | 
            +
                "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
         | 
| 340 | 
            +
                "model.layers.9.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
         | 
| 341 | 
            +
                "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
         | 
| 342 | 
            +
                "model.layers.9.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
         | 
| 343 | 
            +
                "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
         | 
| 344 | 
            +
                "model.norm.weight": "model-00003-of-00004.safetensors"
         | 
| 345 | 
            +
              }
         | 
| 346 | 
            +
            }
         | 
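The `weight_map` above assigns every parameter name to one of the four safetensors shards. As a minimal, illustrative sketch (not part of the repository files), the snippet below shows how such an index can be used to locate and lazily load a single tensor; it assumes `huggingface_hub` and `safetensors` are installed and reuses the repository id from the usage example above.

```
import json
from huggingface_hub import hf_hub_download
from safetensors import safe_open

repo_id = "apple/DiffuCoder-7B-cpGRPO"

# Download the index and find which shard stores a given parameter.
index_path = hf_hub_download(repo_id, "model.safetensors.index.json")
with open(index_path) as f:
    weight_map = json.load(f)["weight_map"]

name = "model.layers.3.mlp.gate_proj.weight"
shard_file = weight_map[name]          # e.g. "model-00001-of-00004.safetensors"

# Open only that shard and read the single tensor (no full-model load).
shard_path = hf_hub_download(repo_id, shard_file)
with safe_open(shard_path, framework="pt") as f:
    tensor = f.get_tensor(name)
print(name, tuple(tensor.shape), tensor.dtype)
```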
    	
        modeling_dream.py
    ADDED
    
    | @@ -0,0 +1,824 @@ | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
         | 
| 5 | 
            +
            # and OPT and Qwen implementations in this library. It has been modified from its
         | 
| 6 | 
            +
            # original forms to accommodate minor architectural differences compared
         | 
| 7 | 
            +
            # to GPT-NeoX and OPT and Qwen used by the Meta AI and Qwen team that trained the model.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 10 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 11 | 
            +
            # You may obtain a copy of the License at
         | 
| 12 | 
            +
            #
         | 
| 13 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 14 | 
            +
            #
         | 
| 15 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 16 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 17 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 18 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 19 | 
            +
            # limitations under the License.
         | 
| 20 | 
            +
            """PyTorch Dream model."""
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            import math
         | 
| 23 | 
            +
            from typing import List, Optional, Tuple, Union
         | 
| 24 | 
            +
            import os
         | 
| 25 | 
            +
            import torch
         | 
| 26 | 
            +
            import torch.utils.checkpoint
         | 
| 27 | 
            +
            from torch import nn
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            from transformers.activations import ACT2FN
         | 
| 30 | 
            +
            from transformers.cache_utils import Cache, DynamicCache
         | 
| 31 | 
            +
            from transformers.modeling_outputs import (
         | 
| 32 | 
            +
                BaseModelOutput,
         | 
| 33 | 
            +
                MaskedLMOutput,
         | 
| 34 | 
            +
            )
         | 
| 35 | 
            +
            from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
         | 
| 36 | 
            +
            from transformers.modeling_utils import PreTrainedModel
         | 
| 37 | 
            +
            from transformers.utils import (
         | 
| 38 | 
            +
                add_start_docstrings,
         | 
| 39 | 
            +
                add_start_docstrings_to_model_forward,
         | 
| 40 | 
            +
                is_flash_attn_2_available,
         | 
| 41 | 
            +
                is_flash_attn_greater_or_equal_2_10,
         | 
| 42 | 
            +
                logging,
         | 
| 43 | 
            +
            )
         | 
| 44 | 
            +
            from transformers import PretrainedConfig
         | 
| 45 | 
            +
            from .configuration_dream import DreamConfig
         | 
| 46 | 
            +
            from .generation_utils import DreamGenerationMixin, DreamGenerationConfig
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            if is_flash_attn_2_available():
         | 
| 49 | 
            +
                from transformers.modeling_flash_attention_utils import _flash_attention_forward
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            logger = logging.get_logger(__name__)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            _CHECKPOINT_FOR_DOC = "Dream-7B"
         | 
| 56 | 
            +
            _CONFIG_FOR_DOC = "DreamConfig"
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Dream
         | 
| 60 | 
            +
            class DreamRMSNorm(nn.Module):
         | 
| 61 | 
            +
                def __init__(self, hidden_size, eps=1e-6):
         | 
| 62 | 
            +
                    """
         | 
| 63 | 
            +
                    DreamRMSNorm is equivalent to T5LayerNorm
         | 
| 64 | 
            +
                    """
         | 
| 65 | 
            +
                    super().__init__()
         | 
| 66 | 
            +
                    self.weight = nn.Parameter(torch.ones(hidden_size))
         | 
| 67 | 
            +
                    self.variance_epsilon = eps
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                def forward(self, hidden_states):
         | 
| 70 | 
            +
                    input_dtype = hidden_states.dtype
         | 
| 71 | 
            +
                    hidden_states = hidden_states.to(torch.float32)
         | 
| 72 | 
            +
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
         | 
| 73 | 
            +
                    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         | 
| 74 | 
            +
                    return self.weight * hidden_states.to(input_dtype)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def extra_repr(self):
         | 
| 77 | 
            +
                    return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
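            # DreamRMSNorm normalizes by the root-mean-square of the features instead of mean/variance:
            #   y = weight * x / sqrt(mean(x**2, dim=-1) + eps)
            # The statistic is computed in float32 and the result is cast back to the input dtype,
            # which keeps the normalization stable under bfloat16 inference.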
| 80 | 
            +
            # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Dream
         | 
| 81 | 
            +
            class DreamRotaryEmbedding(nn.Module):
         | 
| 82 | 
            +
                def __init__(
         | 
| 83 | 
            +
                    self,
         | 
| 84 | 
            +
                    dim=None,
         | 
| 85 | 
            +
                    max_position_embeddings=2048,
         | 
| 86 | 
            +
                    base=10000,
         | 
| 87 | 
            +
                    device=None,
         | 
| 88 | 
            +
                    scaling_factor=1.0,
         | 
| 89 | 
            +
                    rope_type="default",
         | 
| 90 | 
            +
                    config: Optional[DreamConfig] = None,
         | 
| 91 | 
            +
                ):
         | 
| 92 | 
            +
                    super().__init__()
         | 
| 93 | 
            +
                    # TODO (joao): remove the `if` below, only used for BC
         | 
| 94 | 
            +
                    self.rope_kwargs = {}
         | 
| 95 | 
            +
                    if config is None:
         | 
| 96 | 
            +
                        logger.warning_once(
         | 
| 97 | 
            +
                            "`DreamRotaryEmbedding` can now be fully parameterized by passing the model config through the "
         | 
| 98 | 
            +
                            "`config` argument. All other arguments will be removed in v4.46"
         | 
| 99 | 
            +
                        )
         | 
| 100 | 
            +
                        self.rope_kwargs = {
         | 
| 101 | 
            +
                            "rope_type": rope_type,
         | 
| 102 | 
            +
                            "factor": scaling_factor,
         | 
| 103 | 
            +
                            "dim": dim,
         | 
| 104 | 
            +
                            "base": base,
         | 
| 105 | 
            +
                            "max_position_embeddings": max_position_embeddings,
         | 
| 106 | 
            +
                        }
         | 
| 107 | 
            +
                        self.rope_type = rope_type
         | 
| 108 | 
            +
                        self.max_seq_len_cached = max_position_embeddings
         | 
| 109 | 
            +
                        self.original_max_seq_len = max_position_embeddings
         | 
| 110 | 
            +
                    else:
         | 
| 111 | 
            +
                        # BC: "rope_type" was originally "type"
         | 
| 112 | 
            +
                        if config.rope_scaling is not None:
         | 
| 113 | 
            +
                            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
         | 
| 114 | 
            +
                        else:
         | 
| 115 | 
            +
                            self.rope_type = "default"
         | 
| 116 | 
            +
                        self.max_seq_len_cached = config.max_position_embeddings
         | 
| 117 | 
            +
                        self.original_max_seq_len = config.max_position_embeddings
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    self.config = config
         | 
| 120 | 
            +
                    self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         | 
| 123 | 
            +
                    self.register_buffer("inv_freq", inv_freq, persistent=False)
         | 
| 124 | 
            +
                    self.original_inv_freq = self.inv_freq
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                def reset_parameters(self):
         | 
| 127 | 
            +
                    inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device, **self.rope_kwargs)
         | 
| 128 | 
            +
                    self.register_buffer("inv_freq", inv_freq, persistent=False)
         | 
| 129 | 
            +
                    self.original_inv_freq = self.inv_freq
         | 
| 130 | 
            +
                
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                def _dynamic_frequency_update(self, position_ids, device):
         | 
| 133 | 
            +
                    """
         | 
| 134 | 
            +
                    dynamic RoPE layers should recompute `inv_freq` in the following situations:
         | 
| 135 | 
            +
                    1 - growing beyond the cached sequence length (allow scaling)
         | 
| 136 | 
            +
                    2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
         | 
| 137 | 
            +
                    """
         | 
| 138 | 
            +
                    seq_len = torch.max(position_ids) + 1
         | 
| 139 | 
            +
                    if seq_len > self.max_seq_len_cached:  # growth
         | 
| 140 | 
            +
                        inv_freq, self.attention_scaling = self.rope_init_fn(
         | 
| 141 | 
            +
                            self.config, device, seq_len=seq_len, **self.rope_kwargs
         | 
| 142 | 
            +
                        )
         | 
| 143 | 
            +
                        self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
         | 
| 144 | 
            +
                        self.max_seq_len_cached = seq_len
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
         | 
| 147 | 
            +
                        self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
         | 
| 148 | 
            +
                        self.max_seq_len_cached = self.original_max_seq_len
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                @torch.no_grad()
         | 
| 151 | 
            +
                def forward(self, x, position_ids):
         | 
| 152 | 
            +
                    if "dynamic" in self.rope_type:
         | 
| 153 | 
            +
                        self._dynamic_frequency_update(position_ids, device=x.device)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    # Core RoPE block
         | 
| 156 | 
            +
                    inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         | 
| 157 | 
            +
                    position_ids_expanded = position_ids[:, None, :].float()
         | 
| 158 | 
            +
                    # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
         | 
| 159 | 
            +
                    device_type = x.device.type
         | 
| 160 | 
            +
                    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         | 
| 161 | 
            +
                    with torch.autocast(device_type=device_type, enabled=False):
         | 
| 162 | 
            +
                        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
         | 
| 163 | 
            +
                        emb = torch.cat((freqs, freqs), dim=-1)
         | 
| 164 | 
            +
                        cos = emb.cos()
         | 
| 165 | 
            +
                        sin = emb.sin()
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
         | 
| 168 | 
            +
                    cos = cos * self.attention_scaling
         | 
| 169 | 
            +
                    sin = sin * self.attention_scaling
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
             | 
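            # The rotary table is an outer product of inv_freq with the position ids; the frequencies
            # are then concatenated with themselves so cos/sin span the full head_dim. The matmul is
            # forced to float32 with autocast disabled (see the PR linked above), scaled by
            # attention_scaling for extended-context RoPE variants, and cast back to the activation
            # dtype on return.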
| 174 | 
            +
            # Copied from transformers.models.llama.modeling_llama.rotate_half
         | 
| 175 | 
            +
            def rotate_half(x):
         | 
| 176 | 
            +
                """Rotates half the hidden dims of the input."""
         | 
| 177 | 
            +
                x1 = x[..., : x.shape[-1] // 2]
         | 
| 178 | 
            +
                x2 = x[..., x.shape[-1] // 2 :]
         | 
| 179 | 
            +
                return torch.cat((-x2, x1), dim=-1)
         | 
| 180 | 
            +
             | 
| 181 | 
            +
             | 
| 182 | 
            +
            # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
         | 
| 183 | 
            +
            def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
         | 
| 184 | 
            +
                """Applies Rotary Position Embedding to the query and key tensors.
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                Args:
         | 
| 187 | 
            +
                    q (`torch.Tensor`): The query tensor.
         | 
| 188 | 
            +
                    k (`torch.Tensor`): The key tensor.
         | 
| 189 | 
            +
                    cos (`torch.Tensor`): The cosine part of the rotary embedding.
         | 
| 190 | 
            +
                    sin (`torch.Tensor`): The sine part of the rotary embedding.
         | 
| 191 | 
            +
                    position_ids (`torch.Tensor`, *optional*):
         | 
| 192 | 
            +
                        Deprecated and unused.
         | 
| 193 | 
            +
                    unsqueeze_dim (`int`, *optional*, defaults to 1):
         | 
| 194 | 
            +
                        The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
         | 
| 195 | 
            +
                        sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
         | 
| 196 | 
            +
                        that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
         | 
| 197 | 
            +
                        k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
         | 
| 198 | 
            +
                        cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
         | 
| 199 | 
            +
                        the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
         | 
| 200 | 
            +
                Returns:
         | 
| 201 | 
            +
                    `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
         | 
| 202 | 
            +
                """
         | 
| 203 | 
            +
                cos = cos.unsqueeze(unsqueeze_dim)
         | 
| 204 | 
            +
                sin = sin.unsqueeze(unsqueeze_dim)
         | 
| 205 | 
            +
                q_embed = (q * cos) + (rotate_half(q) * sin)
         | 
| 206 | 
            +
                k_embed = (k * cos) + (rotate_half(k) * sin)
         | 
| 207 | 
            +
                return q_embed, k_embed
         | 
| 208 | 
            +
             | 
| 209 | 
            +
             | 
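            # rotate_half pairs feature i with feature i + head_dim // 2, so
            #   q_embed = q * cos + rotate_half(q) * sin
            # rotates each such pair by the position-dependent angle without materializing complex
            # tensors; the identical rotation is applied to the key states.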
| 210 | 
            +
            # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Dream
         | 
| 211 | 
            +
            class DreamMLP(nn.Module):
         | 
| 212 | 
            +
                def __init__(self, config):
         | 
| 213 | 
            +
                    super().__init__()
         | 
| 214 | 
            +
                    self.hidden_size = config.hidden_size
         | 
| 215 | 
            +
                    self.intermediate_size = config.intermediate_size
         | 
| 216 | 
            +
                    self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         | 
| 217 | 
            +
                    self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         | 
| 218 | 
            +
                    self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         | 
| 219 | 
            +
                    self.act_fn = ACT2FN[config.hidden_act]
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                def forward(self, hidden_state):
         | 
| 222 | 
            +
                    return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
         | 
| 223 | 
            +
             | 
| 224 | 
            +
             | 
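            # DreamMLP is the gated feed-forward block: down_proj(act_fn(gate_proj(x)) * up_proj(x)),
            # with no bias on any of the three projections. When config.hidden_act is "silu" (the
            # usual setting for this model family) this is the SwiGLU formulation.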
| 225 | 
            +
            # Copied from transformers.models.llama.modeling_llama.repeat_kv
         | 
| 226 | 
            +
            def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
         | 
| 227 | 
            +
                """
         | 
| 228 | 
            +
                This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
         | 
| 229 | 
            +
                num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
         | 
| 230 | 
            +
                """
         | 
| 231 | 
            +
                batch, num_key_value_heads, slen, head_dim = hidden_states.shape
         | 
| 232 | 
            +
                if n_rep == 1:
         | 
| 233 | 
            +
                    return hidden_states
         | 
| 234 | 
            +
                hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
         | 
| 235 | 
            +
                return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
             | 
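            # repeat_kv implements grouped-query attention: the num_key_value_heads K/V heads are
            # tiled n_rep = num_attention_heads // num_key_value_heads times so every query head has
            # a matching K/V head. It is equivalent to torch.repeat_interleave along the head
            # dimension, written as expand + reshape with an early return when n_rep == 1.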
| 238 | 
            +
            class DreamAttention(nn.Module):
         | 
| 239 | 
            +
                """
         | 
| 240 | 
            +
                Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
         | 
| 241 | 
            +
                and "Generating Long Sequences with Sparse Transformers".
         | 
| 242 | 
            +
                """
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                def __init__(self, config: DreamConfig, layer_idx: Optional[int] = None):
         | 
| 245 | 
            +
                    super().__init__()
         | 
| 246 | 
            +
                    self.config = config
         | 
| 247 | 
            +
                    self.layer_idx = layer_idx
         | 
| 248 | 
            +
                    if layer_idx is None:
         | 
| 249 | 
            +
                        logger.warning_once(
         | 
| 250 | 
            +
                            f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
         | 
| 251 | 
            +
                            "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
         | 
| 252 | 
            +
                            "when creating this class."
         | 
| 253 | 
            +
                        )
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                    self.hidden_size = config.hidden_size
         | 
| 256 | 
            +
                    self.num_heads = config.num_attention_heads
         | 
| 257 | 
            +
                    self.head_dim = self.hidden_size // self.num_heads
         | 
| 258 | 
            +
                    self.num_key_value_heads = config.num_key_value_heads
         | 
| 259 | 
            +
                    self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         | 
| 260 | 
            +
                    self.max_position_embeddings = config.max_position_embeddings
         | 
| 261 | 
            +
                    self.rope_theta = config.rope_theta
         | 
| 262 | 
            +
                    self.is_causal = False
         | 
| 263 | 
            +
                    self.attention_dropout = config.attention_dropout
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                    if (self.head_dim * self.num_heads) != self.hidden_size:
         | 
| 266 | 
            +
                        raise ValueError(
         | 
| 267 | 
            +
                            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
         | 
| 268 | 
            +
                            f" and `num_heads`: {self.num_heads})."
         | 
| 269 | 
            +
                        )
         | 
| 270 | 
            +
                    self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
         | 
| 271 | 
            +
                    self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
         | 
| 272 | 
            +
                    self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
         | 
| 273 | 
            +
                    self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                    self.rotary_emb = DreamRotaryEmbedding(config=self.config)
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                def forward(
         | 
| 278 | 
            +
                    self,
         | 
| 279 | 
            +
                    hidden_states: torch.Tensor,
         | 
| 280 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 281 | 
            +
                    position_ids: Optional[torch.LongTensor] = None,
         | 
| 282 | 
            +
                    past_key_value: Optional[Cache] = None,
         | 
| 283 | 
            +
                    output_attentions: bool = False,
         | 
| 284 | 
            +
                    use_cache: bool = False,
         | 
| 285 | 
            +
                    cache_position: Optional[torch.LongTensor] = None,
         | 
| 286 | 
            +
                    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
         | 
| 287 | 
            +
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         | 
| 288 | 
            +
                    bsz, q_len, _ = hidden_states.size()
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                    query_states = self.q_proj(hidden_states)
         | 
| 291 | 
            +
                    key_states = self.k_proj(hidden_states)
         | 
| 292 | 
            +
                    value_states = self.v_proj(hidden_states)
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         | 
| 295 | 
            +
                    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         | 
| 296 | 
            +
                    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    if position_embeddings is None:
         | 
| 299 | 
            +
                        logger.warning_once(
         | 
| 300 | 
            +
                            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
         | 
| 301 | 
            +
                            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
         | 
| 302 | 
            +
                            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
         | 
| 303 | 
            +
                            "removed and `position_embeddings` will be mandatory."
         | 
| 304 | 
            +
                        )
         | 
| 305 | 
            +
                        cos, sin = self.rotary_emb(value_states, position_ids)
         | 
| 306 | 
            +
                    else:
         | 
| 307 | 
            +
                        cos, sin = position_embeddings
         | 
| 308 | 
            +
                    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                    if past_key_value is not None:
         | 
| 311 | 
            +
                        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
         | 
| 312 | 
            +
                        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                    # repeat k/v heads if n_kv_heads < n_heads
         | 
| 315 | 
            +
                    key_states = repeat_kv(key_states, self.num_key_value_groups)
         | 
| 316 | 
            +
                    value_states = repeat_kv(value_states, self.num_key_value_groups)
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
         | 
| 319 | 
            +
                    if attention_mask is not None:  # no matter the length, we just slice it
         | 
| 320 | 
            +
                        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         | 
| 321 | 
            +
                        attn_weights = attn_weights + causal_mask
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                    # upcast attention to fp32
         | 
| 324 | 
            +
                    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         | 
| 325 | 
            +
                    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         | 
| 326 | 
            +
                    attn_output = torch.matmul(attn_weights, value_states)
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
         | 
| 329 | 
            +
                        raise ValueError(
         | 
| 330 | 
            +
                            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
         | 
| 331 | 
            +
                            f" {attn_output.size()}"
         | 
| 332 | 
            +
                        )
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                    attn_output = attn_output.transpose(1, 2).contiguous()
         | 
| 335 | 
            +
                    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                    attn_output = self.o_proj(attn_output)
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                    if not output_attentions:
         | 
| 340 | 
            +
                        attn_weights = None
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                    return attn_output, attn_weights, past_key_value
         | 
| 343 | 
            +
                
         | 
| 344 | 
            +
             | 
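            # Unlike an autoregressive decoder, this attention is bidirectional: is_causal is False
            # and no causal mask is constructed, so every position may attend to every other position,
            # as expected for a masked-diffusion model that denoises many positions in parallel. Any
            # attention_mask passed in (e.g. for padding) is added to the scores, which are scaled by
            # 1/sqrt(head_dim), softmaxed in float32 and dropout-regularized before being applied to
            # the values.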
| 345 | 
            +
            class DreamSdpaAttention(DreamAttention):
         | 
| 346 | 
            +
                """
         | 
| 347 | 
            +
                Dream attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
         | 
| 348 | 
            +
                `DreamAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
         | 
| 349 | 
            +
                SDPA API.
         | 
| 350 | 
            +
                """
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                # Adapted from DreamAttention.forward
         | 
| 353 | 
            +
                def forward(
         | 
| 354 | 
            +
                    self,
         | 
| 355 | 
            +
                    hidden_states: torch.Tensor,
         | 
| 356 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 357 | 
            +
                    position_ids: Optional[torch.LongTensor] = None,
         | 
| 358 | 
            +
                    past_key_value: Optional[Cache] = None,
         | 
| 359 | 
            +
                    output_attentions: bool = False,
         | 
| 360 | 
            +
                    use_cache: bool = False,
         | 
| 361 | 
            +
                    cache_position: Optional[torch.LongTensor] = None,
         | 
| 362 | 
            +
                    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
         | 
| 363 | 
            +
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         | 
| 364 | 
            +
                    if output_attentions:
         | 
| 365 | 
            +
                        # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
         | 
| 366 | 
            +
                        logger.warning_once(
         | 
| 367 | 
            +
                            "DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
         | 
| 368 | 
            +
                            'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
         | 
| 369 | 
            +
                        )
         | 
| 370 | 
            +
                        return super().forward(
         | 
| 371 | 
            +
                            hidden_states=hidden_states,
         | 
| 372 | 
            +
                            attention_mask=attention_mask,
         | 
| 373 | 
            +
                            position_ids=position_ids,
         | 
| 374 | 
            +
                            past_key_value=past_key_value,
         | 
| 375 | 
            +
                            output_attentions=output_attentions,
         | 
| 376 | 
            +
                            use_cache=use_cache,
         | 
| 377 | 
            +
                        )
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                    bsz, q_len, _ = hidden_states.size()
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                    query_states = self.q_proj(hidden_states)
         | 
| 382 | 
            +
                    key_states = self.k_proj(hidden_states)
         | 
| 383 | 
            +
                    value_states = self.v_proj(hidden_states)
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         | 
| 386 | 
            +
                    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         | 
| 387 | 
            +
                    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                    if position_embeddings is None:
         | 
| 390 | 
            +
                        logger.warning_once(
         | 
| 391 | 
            +
                            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
         | 
| 392 | 
            +
                            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
         | 
| 393 | 
            +
                            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
         | 
| 394 | 
            +
                            "removed and `position_embeddings` will be mandatory."
         | 
| 395 | 
            +
                        )
         | 
| 396 | 
            +
                        cos, sin = self.rotary_emb(value_states, position_ids)
         | 
| 397 | 
            +
                    else:
         | 
| 398 | 
            +
                        cos, sin = position_embeddings
         | 
| 399 | 
            +
                    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                    if past_key_value is not None:
         | 
| 402 | 
            +
                        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
         | 
| 403 | 
            +
                        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                    key_states = repeat_kv(key_states, self.num_key_value_groups)
         | 
| 406 | 
            +
                    value_states = repeat_kv(value_states, self.num_key_value_groups)
         | 
| 407 | 
            +
             | 
| 408 | 
            +
                    # causal_mask = attention_mask
         | 
| 409 | 
            +
                    # if attention_mask is not None:  # no matter the length, we just slice it
         | 
| 410 | 
            +
                    #     causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         | 
| 411 | 
            +
             | 
| 412 | 
            +
                    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         | 
| 413 | 
            +
                    # Reference: https://github.com/pytorch/pytorch/issues/112577.
         | 
| 414 | 
            +
                    if query_states.device.type == "cuda" and attention_mask is not None:
         | 
| 415 | 
            +
                        query_states = query_states.contiguous()
         | 
| 416 | 
            +
                        key_states = key_states.contiguous()
         | 
| 417 | 
            +
                        value_states = value_states.contiguous()
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
         | 
| 420 | 
            +
                    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         | 
| 421 | 
            +
                    # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
         | 
| 422 | 
            +
                    # is_causal = True if causal_mask is None and q_len > 1 else False
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                    attn_output = torch.nn.functional.scaled_dot_product_attention(
         | 
| 425 | 
            +
                        query_states,
         | 
| 426 | 
            +
                        key_states,
         | 
| 427 | 
            +
                        value_states,
         | 
| 428 | 
            +
                        attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
         | 
| 429 | 
            +
                        dropout_p=self.attention_dropout if self.training else 0.0,
         | 
| 430 | 
            +
                        is_causal=False, # hard coded
         | 
| 431 | 
            +
                    )
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                    attn_output = attn_output.transpose(1, 2).contiguous()
         | 
| 434 | 
            +
                    attn_output = attn_output.view(bsz, q_len, self.hidden_size)
         | 
| 435 | 
            +
             | 
| 436 | 
            +
                    attn_output = self.o_proj(attn_output)
         | 
| 437 | 
            +
             | 
| 438 | 
            +
                    return attn_output, None, past_key_value
         | 
| 439 | 
            +
             | 
| 440 | 
            +
             | 
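            # The SDPA variant computes the same attention through
            # torch.nn.functional.scaled_dot_product_attention, again with is_causal hard-coded to
            # False. The .contiguous() calls on CUDA work around the non-contiguous-input issue in
            # the memory-efficient backend referenced above, and output_attentions=True falls back to
            # the eager implementation because SDPA cannot return attention weights.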
| 441 | 
            +
            class DreamDecoderLayer(nn.Module):
         | 
| 442 | 
            +
                def __init__(self, config: DreamConfig, layer_idx: int):
         | 
| 443 | 
            +
                    super().__init__()
         | 
| 444 | 
            +
                    self.hidden_size = config.hidden_size
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                    if config.sliding_window and config._attn_implementation != "flash_attention_2":
         | 
| 447 | 
            +
                        logger.warning_once(
         | 
| 448 | 
            +
                            f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
         | 
| 449 | 
            +
                            "unexpected results may be encountered."
         | 
| 450 | 
            +
                        )
         | 
| 451 | 
            +
                    
         | 
| 452 | 
            +
                    # self.self_attn = Dream_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
         | 
| 453 | 
            +
                    self.self_attn = DreamSdpaAttention(config, layer_idx)
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                    self.mlp = DreamMLP(config)
         | 
| 456 | 
            +
                    self.input_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         | 
| 457 | 
            +
                    self.post_attention_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                def forward(
         | 
| 460 | 
            +
                    self,
         | 
| 461 | 
            +
                    hidden_states: torch.Tensor,
         | 
| 462 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 463 | 
            +
                    position_ids: Optional[torch.LongTensor] = None,
         | 
| 464 | 
            +
                    past_key_value: Optional[Tuple[torch.Tensor]] = None,
         | 
| 465 | 
            +
                    output_attentions: Optional[bool] = False,
         | 
| 466 | 
            +
                    use_cache: Optional[bool] = False,
         | 
| 467 | 
            +
                    cache_position: Optional[torch.LongTensor] = None,
         | 
| 468 | 
            +
                    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
         | 
| 469 | 
            +
                    **kwargs,
         | 
| 470 | 
            +
                ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         | 
| 471 | 
            +
                    """
         | 
| 472 | 
            +
                    Args:
         | 
| 473 | 
            +
                        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
         | 
| 474 | 
            +
                        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
         | 
| 475 | 
            +
                            `(batch, sequence_length)` where padding elements are indicated by 0.
         | 
| 476 | 
            +
                        output_attentions (`bool`, *optional*):
         | 
| 477 | 
            +
                            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
         | 
| 478 | 
            +
                            returned tensors for more detail.
         | 
| 479 | 
            +
                        use_cache (`bool`, *optional*):
         | 
| 480 | 
            +
                            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
         | 
| 481 | 
            +
                            (see `past_key_values`).
         | 
| 482 | 
            +
                        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
         | 
| 483 | 
            +
                        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
         | 
| 484 | 
            +
                            Indices depicting the position of the input sequence tokens in the sequence.
         | 
| 485 | 
            +
                        position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
         | 
| 486 | 
            +
                            Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
         | 
| 487 | 
            +
                            with `head_dim` being the embedding dimension of each attention head.
         | 
| 488 | 
            +
                        kwargs (`dict`, *optional*):
         | 
| 489 | 
            +
                            Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
         | 
| 490 | 
            +
                            into the model
         | 
| 491 | 
            +
                    """
         | 
| 492 | 
            +
             | 
| 493 | 
            +
                    residual = hidden_states
         | 
| 494 | 
            +
             | 
| 495 | 
            +
                    hidden_states = self.input_layernorm(hidden_states)
         | 
| 496 | 
            +
             | 
| 497 | 
            +
                    # Self Attention
         | 
| 498 | 
            +
                    hidden_states, self_attn_weights, present_key_value = self.self_attn(
         | 
| 499 | 
            +
                        hidden_states=hidden_states,
         | 
| 500 | 
            +
                        attention_mask=attention_mask,
         | 
| 501 | 
            +
                        position_ids=position_ids,
         | 
| 502 | 
            +
                        past_key_value=past_key_value,
         | 
| 503 | 
            +
                        output_attentions=output_attentions,
         | 
| 504 | 
            +
                        use_cache=use_cache,
         | 
| 505 | 
            +
                        cache_position=cache_position,
         | 
| 506 | 
            +
                        position_embeddings=position_embeddings,
         | 
| 507 | 
            +
                    )
         | 
| 508 | 
            +
                    hidden_states = residual + hidden_states
         | 
| 509 | 
            +
             | 
| 510 | 
            +
                    # Fully Connected
         | 
| 511 | 
            +
                    residual = hidden_states
         | 
| 512 | 
            +
                    hidden_states = self.post_attention_layernorm(hidden_states)
         | 
| 513 | 
            +
                    hidden_states = self.mlp(hidden_states)
         | 
| 514 | 
            +
                    hidden_states = residual + hidden_states
         | 
| 515 | 
            +
             | 
| 516 | 
            +
                    outputs = (hidden_states,)
         | 
| 517 | 
            +
             | 
| 518 | 
            +
                    if output_attentions:
         | 
| 519 | 
            +
                        outputs += (self_attn_weights,)
         | 
| 520 | 
            +
             | 
| 521 | 
            +
                    if use_cache:
         | 
| 522 | 
            +
                        outputs += (present_key_value,)
         | 
| 523 | 
            +
             | 
| 524 | 
            +
                    return outputs
         | 
| 525 | 
            +
             | 
| 526 | 
            +
            class DreamPreTrainedModel(PreTrainedModel):
         | 
| 527 | 
            +
                config_class = DreamConfig
         | 
| 528 | 
            +
                base_model_prefix = "model"
         | 
| 529 | 
            +
                supports_gradient_checkpointing = True
         | 
| 530 | 
            +
                _no_split_modules = ["DreamDecoderLayer"]
         | 
| 531 | 
            +
                _skip_keys_device_placement = "past_key_values"
         | 
| 532 | 
            +
                _supports_flash_attn_2 = True
         | 
| 533 | 
            +
                _supports_sdpa = True
         | 
| 534 | 
            +
                _supports_cache_class = True
         | 
| 535 | 
            +
                _supports_quantized_cache = True
         | 
| 536 | 
            +
                _supports_static_cache = True
         | 
| 537 | 
            +
             | 
| 538 | 
            +
                def _init_weights(self, module):
         | 
| 539 | 
            +
                    std = self.config.initializer_range
         | 
| 540 | 
            +
                    if isinstance(module, nn.Linear):
         | 
| 541 | 
            +
                        module.weight.data.normal_(mean=0.0, std=std)
         | 
| 542 | 
            +
                        if module.bias is not None:
         | 
| 543 | 
            +
                            module.bias.data.zero_()
         | 
| 544 | 
            +
                    elif isinstance(module, nn.Embedding):
         | 
| 545 | 
            +
                        module.weight.data.normal_(mean=0.0, std=std)
         | 
| 546 | 
            +
                        if module.padding_idx is not None:
         | 
| 547 | 
            +
                            module.weight.data[module.padding_idx].zero_()
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                @classmethod
         | 
| 550 | 
            +
                def from_pretrained(
         | 
| 551 | 
            +
                    cls,
         | 
| 552 | 
            +
                    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
         | 
| 553 | 
            +
                    *model_args,
         | 
| 554 | 
            +
                    config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
         | 
| 555 | 
            +
                    cache_dir: Optional[Union[str, os.PathLike]] = None,
         | 
| 556 | 
            +
                    ignore_mismatched_sizes: bool = False,
         | 
| 557 | 
            +
                    force_download: bool = False,
         | 
| 558 | 
            +
                    local_files_only: bool = False,
         | 
| 559 | 
            +
                    token: Optional[Union[str, bool]] = None,
         | 
| 560 | 
            +
                    revision: str = "main",
         | 
| 561 | 
            +
                    use_safetensors: Optional[bool] = None,
         | 
| 562 | 
            +
                    weights_only: bool = True,
         | 
| 563 | 
            +
                    **kwargs,
         | 
| 564 | 
            +
                ):
         | 
| 565 | 
            +
                    _model = super().from_pretrained(
         | 
| 566 | 
            +
                        pretrained_model_name_or_path,
         | 
| 567 | 
            +
                        *model_args,
         | 
| 568 | 
            +
                        config=config,
         | 
| 569 | 
            +
                        cache_dir=cache_dir,
         | 
| 570 | 
            +
                        ignore_mismatched_sizes=ignore_mismatched_sizes,
         | 
| 571 | 
            +
                        force_download=force_download,
         | 
| 572 | 
            +
                        local_files_only=local_files_only,
         | 
| 573 | 
            +
                        token=token,
         | 
| 574 | 
            +
                        revision=revision,
         | 
| 575 | 
            +
                        use_safetensors=use_safetensors,
         | 
| 576 | 
            +
                        weights_only=weights_only,
         | 
| 577 | 
            +
                        **kwargs,
         | 
| 578 | 
            +
                    )
         | 
| 579 | 
            +
                    # NOTE(Lin): we need to override the generation config
         | 
| 580 | 
            +
                    # because the generation config loaded in `from_pretrained` 
         | 
| 581 | 
            +
                    # does not include all the attributes of DreamGenerationConfig
         | 
| 582 | 
            +
                    resume_download = kwargs.get("resume_download", None)
         | 
| 583 | 
            +
                    proxies = kwargs.get("proxies", None)
         | 
| 584 | 
            +
                    subfolder = kwargs.get("subfolder", "")
         | 
| 585 | 
            +
                    from_auto_class = kwargs.get("_from_auto", False)
         | 
| 586 | 
            +
                    from_pipeline = kwargs.get("_from_pipeline", None)
         | 
| 587 | 
            +
                    _model.generation_config = DreamGenerationConfig.from_pretrained(
         | 
| 588 | 
            +
                        pretrained_model_name_or_path,
         | 
| 589 | 
            +
                        cache_dir=cache_dir,
         | 
| 590 | 
            +
                        force_download=force_download,
         | 
| 591 | 
            +
                        resume_download=resume_download,
         | 
| 592 | 
            +
                        proxies=proxies,
         | 
| 593 | 
            +
                        local_files_only=local_files_only,
         | 
| 594 | 
            +
                        token=token,
         | 
| 595 | 
            +
                        revision=revision,
         | 
| 596 | 
            +
                        subfolder=subfolder,
         | 
| 597 | 
            +
                        _from_auto=from_auto_class,
         | 
| 598 | 
            +
                        _from_pipeline=from_pipeline,
         | 
| 599 | 
            +
                    )
         | 
| 600 | 
            +
                    return _model
         | 
| 601 | 
            +
             | 
| 602 | 
            +
            class DreamBaseModel(DreamPreTrainedModel):
         | 
| 603 | 
            +
                """
         | 
| 604 | 
            +
                Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DreamDecoderLayer`]
         | 
| 605 | 
            +
             | 
| 606 | 
            +
                Args:
         | 
| 607 | 
            +
                    config: DreamConfig
         | 
| 608 | 
            +
                """
         | 
| 609 | 
            +
             | 
| 610 | 
            +
                def __init__(self, config: DreamConfig):
         | 
| 611 | 
            +
                    super().__init__(config)
         | 
| 612 | 
            +
                    self.padding_idx = config.pad_token_id
         | 
| 613 | 
            +
                    self.vocab_size = config.vocab_size
         | 
| 614 | 
            +
             | 
| 615 | 
            +
                    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         | 
| 616 | 
            +
                    self.layers = nn.ModuleList(
         | 
| 617 | 
            +
                        [DreamDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         | 
| 618 | 
            +
                    )
         | 
| 619 | 
            +
                    self._attn_implementation = config._attn_implementation
         | 
| 620 | 
            +
                    self.norm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         | 
| 621 | 
            +
                    self.rotary_emb = DreamRotaryEmbedding(config=config)
         | 
| 622 | 
            +
             | 
| 623 | 
            +
                    self.gradient_checkpointing = False
         | 
| 624 | 
            +
                    # Initialize weights and apply final processing
         | 
| 625 | 
            +
                    self.post_init()
         | 
| 626 | 
            +
             | 
| 627 | 
            +
                def get_input_embeddings(self):
         | 
| 628 | 
            +
                    return self.embed_tokens
         | 
| 629 | 
            +
             | 
| 630 | 
            +
                def set_input_embeddings(self, value):
         | 
| 631 | 
            +
                    self.embed_tokens = value
         | 
| 632 | 
            +
             | 
| 633 | 
            +
                def forward(
         | 
| 634 | 
            +
                    self,
         | 
| 635 | 
            +
                    input_ids: torch.LongTensor = None,
         | 
| 636 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 637 | 
            +
                    position_ids: Optional[torch.LongTensor] = None,
         | 
| 638 | 
            +
                    past_key_values: Optional[List[torch.FloatTensor]] = None,
         | 
| 639 | 
            +
                    inputs_embeds: Optional[torch.FloatTensor] = None,
         | 
| 640 | 
            +
                    use_cache: Optional[bool] = None,
         | 
| 641 | 
            +
                    output_attentions: Optional[bool] = None,
         | 
| 642 | 
            +
                    output_hidden_states: Optional[bool] = None,
         | 
| 643 | 
            +
                    return_dict: Optional[bool] = None,
         | 
| 644 | 
            +
                    cache_position: Optional[torch.LongTensor] = None,
         | 
| 645 | 
            +
                ) -> Union[Tuple, BaseModelOutput]:
         | 
| 646 | 
            +
                    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         | 
| 647 | 
            +
                    output_hidden_states = (
         | 
| 648 | 
            +
                        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         | 
| 649 | 
            +
                    )
         | 
| 650 | 
            +
                    use_cache = use_cache if use_cache is not None else self.config.use_cache
         | 
| 651 | 
            +
             | 
| 652 | 
            +
                    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         | 
| 653 | 
            +
             | 
| 654 | 
            +
                    if (input_ids is None) ^ (inputs_embeds is not None):
         | 
| 655 | 
            +
                        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
         | 
| 656 | 
            +
             | 
| 657 | 
            +
                    if self.gradient_checkpointing and self.training:
         | 
| 658 | 
            +
                        if use_cache:
         | 
| 659 | 
            +
                            logger.warning_once(
         | 
| 660 | 
            +
                                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
         | 
| 661 | 
            +
                            )
         | 
| 662 | 
            +
                            use_cache = False
         | 
| 663 | 
            +
             | 
| 664 | 
            +
                    if inputs_embeds is None:
         | 
| 665 | 
            +
                        inputs_embeds = self.embed_tokens(input_ids)
         | 
| 666 | 
            +
                    
         | 
| 667 | 
            +
                    if use_cache and past_key_values is None:
         | 
| 668 | 
            +
                        past_key_values = DynamicCache()
         | 
| 669 | 
            +
             | 
| 670 | 
            +
                    if cache_position is None:
         | 
| 671 | 
            +
                        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         | 
| 672 | 
            +
                        cache_position = torch.arange(
         | 
| 673 | 
            +
                            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
         | 
| 674 | 
            +
                        )
         | 
| 675 | 
            +
             | 
| 676 | 
            +
                    if position_ids is None:
         | 
| 677 | 
            +
                        position_ids = cache_position.unsqueeze(0)
         | 
| 678 | 
            +
             | 
| 679 | 
            +
                    hidden_states = inputs_embeds
         | 
| 680 | 
            +
             | 
| 681 | 
            +
                    # create position embeddings to be shared across the decoder layers
         | 
| 682 | 
            +
                    position_embeddings = self.rotary_emb(hidden_states, position_ids)
         | 
| 683 | 
            +
             | 
| 684 | 
            +
                    # decoder layers
         | 
| 685 | 
            +
                    all_hidden_states = () if output_hidden_states else None
         | 
| 686 | 
            +
                    all_self_attns = () if output_attentions else None
         | 
| 687 | 
            +
             | 
| 688 | 
            +
                    for decoder_layer in self.layers:
         | 
| 689 | 
            +
                        if output_hidden_states:
         | 
| 690 | 
            +
                            all_hidden_states += (hidden_states,)
         | 
| 691 | 
            +
             | 
| 692 | 
            +
                        if self.gradient_checkpointing and self.training:
         | 
| 693 | 
            +
                            layer_outputs = self._gradient_checkpointing_func(
         | 
| 694 | 
            +
                                decoder_layer.__call__,
         | 
| 695 | 
            +
                                hidden_states,
         | 
| 696 | 
            +
                                attention_mask,
         | 
| 697 | 
            +
                                position_ids,
         | 
| 698 | 
            +
                                past_key_values,
         | 
| 699 | 
            +
                                output_attentions,
         | 
| 700 | 
            +
                                use_cache,
         | 
| 701 | 
            +
                                cache_position,
         | 
| 702 | 
            +
                                position_embeddings,
         | 
| 703 | 
            +
                            )
         | 
| 704 | 
            +
                        else:
         | 
| 705 | 
            +
                            layer_outputs = decoder_layer(
         | 
| 706 | 
            +
                                hidden_states,
         | 
| 707 | 
            +
                                attention_mask=attention_mask,
         | 
| 708 | 
            +
                                position_ids=position_ids,
         | 
| 709 | 
            +
                                past_key_value=past_key_values,
         | 
| 710 | 
            +
                                output_attentions=output_attentions,
         | 
| 711 | 
            +
                                use_cache=use_cache,
         | 
| 712 | 
            +
                                cache_position=cache_position,
         | 
| 713 | 
            +
                                position_embeddings=position_embeddings,
         | 
| 714 | 
            +
                            )
         | 
| 715 | 
            +
             | 
| 716 | 
            +
                        hidden_states = layer_outputs[0]
         | 
| 717 | 
            +
             | 
| 718 | 
            +
                        if output_attentions:
         | 
| 719 | 
            +
                            all_self_attns += (layer_outputs[1],)
         | 
| 720 | 
            +
             | 
| 721 | 
            +
                    hidden_states = self.norm(hidden_states)
         | 
| 722 | 
            +
             | 
| 723 | 
            +
                    # add hidden states from the last decoder layer
         | 
| 724 | 
            +
                    if output_hidden_states:
         | 
| 725 | 
            +
                        all_hidden_states += (hidden_states,)
         | 
| 726 | 
            +
             | 
| 727 | 
            +
                    if not return_dict:
         | 
| 728 | 
            +
                        return tuple(v for v in [hidden_states, all_hidden_states, all_self_attns] if v is not None)
         | 
| 729 | 
            +
                    return BaseModelOutput(
         | 
| 730 | 
            +
                        last_hidden_state=hidden_states,
         | 
| 731 | 
            +
                        hidden_states=all_hidden_states,
         | 
| 732 | 
            +
                        attentions=all_self_attns,
         | 
| 733 | 
            +
                    )
         | 
| 734 | 
            +
             | 
| 735 | 
            +
             | 
| 736 | 
            +
            class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
         | 
| 737 | 
            +
                _tied_weights_keys = ["lm_head.weight"]
         | 
| 738 | 
            +
             | 
| 739 | 
            +
                def __init__(self, config):
         | 
| 740 | 
            +
                    super().__init__(config)
         | 
| 741 | 
            +
                    self.model = DreamBaseModel(config)
         | 
| 742 | 
            +
                    self.vocab_size = config.vocab_size
         | 
| 743 | 
            +
                    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         | 
| 744 | 
            +
             | 
| 745 | 
            +
                    # Initialize weights and apply final processing
         | 
| 746 | 
            +
                    self.post_init()
         | 
| 747 | 
            +
             | 
| 748 | 
            +
                def reset_rope_parameters(self):
         | 
| 749 | 
            +
                    self.model.rotary_emb.reset_parameters()
         | 
| 750 | 
            +
                    for layer in self.model.layers:
         | 
| 751 | 
            +
                        layer.self_attn.rotary_emb.reset_parameters()
         | 
| 752 | 
            +
             | 
| 753 | 
            +
                def get_input_embeddings(self):
         | 
| 754 | 
            +
                    return self.model.embed_tokens
         | 
| 755 | 
            +
             | 
| 756 | 
            +
                def set_input_embeddings(self, value):
         | 
| 757 | 
            +
                    self.model.embed_tokens = value
         | 
| 758 | 
            +
             | 
| 759 | 
            +
                def get_output_embeddings(self):
         | 
| 760 | 
            +
                    return self.lm_head
         | 
| 761 | 
            +
             | 
| 762 | 
            +
                def set_output_embeddings(self, new_embeddings):
         | 
| 763 | 
            +
                    self.lm_head = new_embeddings
         | 
| 764 | 
            +
             | 
| 765 | 
            +
                def set_decoder(self, decoder):
         | 
| 766 | 
            +
                    self.model = decoder
         | 
| 767 | 
            +
             | 
| 768 | 
            +
                def get_decoder(self):
         | 
| 769 | 
            +
                    return self.model
         | 
| 770 | 
            +
             | 
| 771 | 
            +
                def forward(
         | 
| 772 | 
            +
                    self,
         | 
| 773 | 
            +
                    input_ids: torch.LongTensor = None,
         | 
| 774 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 775 | 
            +
                    position_ids: Optional[torch.LongTensor] = None,
         | 
| 776 | 
            +
                    past_key_values: Optional[List[torch.FloatTensor]] = None,
         | 
| 777 | 
            +
                    inputs_embeds: Optional[torch.FloatTensor] = None,
         | 
| 778 | 
            +
                    labels: Optional[torch.LongTensor] = None,
         | 
| 779 | 
            +
                    use_cache: Optional[bool] = None,
         | 
| 780 | 
            +
                    output_attentions: Optional[bool] = None,
         | 
| 781 | 
            +
                    output_hidden_states: Optional[bool] = None,
         | 
| 782 | 
            +
                    return_dict: Optional[bool] = None,
         | 
| 783 | 
            +
                    cache_position: Optional[torch.LongTensor] = None,
         | 
| 784 | 
            +
                    num_logits_to_keep: int = 0,
         | 
| 785 | 
            +
                    **loss_kwargs,
         | 
| 786 | 
            +
                ) -> Union[Tuple, MaskedLMOutput]:
         | 
| 787 | 
            +
                    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         | 
| 788 | 
            +
                    output_hidden_states = (
         | 
| 789 | 
            +
                        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         | 
| 790 | 
            +
                    )
         | 
| 791 | 
            +
                    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         | 
| 792 | 
            +
             | 
| 793 | 
            +
                    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         | 
| 794 | 
            +
                    outputs = self.model(
         | 
| 795 | 
            +
                        input_ids=input_ids,
         | 
| 796 | 
            +
                        attention_mask=attention_mask,
         | 
| 797 | 
            +
                        position_ids=position_ids,
         | 
| 798 | 
            +
                        past_key_values=past_key_values,
         | 
| 799 | 
            +
                        inputs_embeds=inputs_embeds,
         | 
| 800 | 
            +
                        use_cache=use_cache,
         | 
| 801 | 
            +
                        output_attentions=output_attentions,
         | 
| 802 | 
            +
                        output_hidden_states=output_hidden_states,
         | 
| 803 | 
            +
                        return_dict=return_dict,
         | 
| 804 | 
            +
                        cache_position=cache_position,
         | 
| 805 | 
            +
                    )
         | 
| 806 | 
            +
             | 
| 807 | 
            +
                    hidden_states = outputs[0]
         | 
| 808 | 
            +
                    # Only compute the necessary logits (num_logits_to_keep=0 slices [-0:], i.e. the full sequence), and do not upcast them to float if we are not computing the loss
         | 
| 809 | 
            +
                    logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
         | 
| 810 | 
            +
             | 
| 811 | 
            +
                    loss = None
         | 
| 812 | 
            +
                    if labels is not None:
         | 
| 813 | 
            +
                        loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
         | 
| 814 | 
            +
             | 
| 815 | 
            +
                    if not return_dict:
         | 
| 816 | 
            +
                        output = (logits,) + outputs[1:]
         | 
| 817 | 
            +
                        return (loss,) + output if loss is not None else output
         | 
| 818 | 
            +
             | 
| 819 | 
            +
                    return MaskedLMOutput(
         | 
| 820 | 
            +
                        loss=loss,
         | 
| 821 | 
            +
                        logits=logits,
         | 
| 822 | 
            +
                        hidden_states=outputs.hidden_states,
         | 
| 823 | 
            +
                        attentions=outputs.attentions,
         | 
| 824 | 
            +
                    )
         | 
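
DreamModel returns a `MaskedLMOutput` rather than a causal-LM output, and attention is fully bidirectional, which is what the diffusion-style decoding in generation_utils.py relies on. Below is a minimal sketch of a single forward pass, assuming `model` and `tokenizer` have already been loaded with `trust_remote_code=True`; the prompt string is illustrative only.

```
import torch

# Single forward pass through DreamModel; with num_logits_to_keep=0 the lm_head
# is applied at every position, so the logits cover the whole sequence.
inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"])

print(type(outputs).__name__)   # MaskedLMOutput
print(outputs.logits.shape)     # (batch, seq_len, vocab_size)
```

For actual code generation, the diffusion sampling loop provided by DreamGenerationMixin (see generation_utils.py) should be used rather than decoding these logits token by token.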
    	
        special_tokens_map.json
    ADDED
    
    | @@ -0,0 +1,35 @@ | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "additional_special_tokens": [
         | 
| 3 | 
            +
                "<|beginoftext|>",
         | 
| 4 | 
            +
                "<|mask|>",
         | 
| 5 | 
            +
                "<|im_end|>"
         | 
| 6 | 
            +
              ],
         | 
| 7 | 
            +
              "bos_token": {
         | 
| 8 | 
            +
                "content": "<|beginoftext|>",
         | 
| 9 | 
            +
                "lstrip": false,
         | 
| 10 | 
            +
                "normalized": false,
         | 
| 11 | 
            +
                "rstrip": false,
         | 
| 12 | 
            +
                "single_word": false
         | 
| 13 | 
            +
              },
         | 
| 14 | 
            +
              "eos_token": {
         | 
| 15 | 
            +
                "content": "<|endoftext|>",
         | 
| 16 | 
            +
                "lstrip": false,
         | 
| 17 | 
            +
                "normalized": false,
         | 
| 18 | 
            +
                "rstrip": false,
         | 
| 19 | 
            +
                "single_word": false
         | 
| 20 | 
            +
              },
         | 
| 21 | 
            +
              "mask_token": {
         | 
| 22 | 
            +
                "content": "<|mask|>",
         | 
| 23 | 
            +
                "lstrip": false,
         | 
| 24 | 
            +
                "normalized": false,
         | 
| 25 | 
            +
                "rstrip": false,
         | 
| 26 | 
            +
                "single_word": false
         | 
| 27 | 
            +
              },
         | 
| 28 | 
            +
              "pad_token": {
         | 
| 29 | 
            +
                "content": "<|dlm_pad|>",
         | 
| 30 | 
            +
                "lstrip": false,
         | 
| 31 | 
            +
                "normalized": false,
         | 
| 32 | 
            +
                "rstrip": false,
         | 
| 33 | 
            +
                "single_word": false
         | 
| 34 | 
            +
              }
         | 
| 35 | 
            +
            }
         | 
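
These special tokens are what masked-diffusion decoding works with: `<|mask|>` marks the positions still to be filled in, `<|dlm_pad|>` is the padding token, and `<|endoftext|>` terminates a sequence. A small sketch, assuming the tokenizer has been loaded with `trust_remote_code=True`; the numeric ids come from added_tokens.json and are not shown here.

```
# Standard PreTrainedTokenizer attributes resolve against this map.
print(tokenizer.mask_token)      # "<|mask|>"  -> placeholder filled in during denoising
print(tokenizer.mask_token_id)   # its integer id, taken from added_tokens.json
print(tokenizer.pad_token)       # "<|dlm_pad|>"
print(tokenizer.eos_token)       # "<|endoftext|>"
print(tokenizer.bos_token)       # "<|beginoftext|>"
```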
    	
        tokenization_dream.py
    ADDED
    
    | @@ -0,0 +1,340 @@ | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2024 The Dream team, HKUNLP Group and The HuggingFace Inc. team. All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This code is based on Qwen's implementations in this library.
         | 
| 5 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 6 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 7 | 
            +
            # You may obtain a copy of the License at
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 12 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 13 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 14 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 15 | 
            +
            # limitations under the License.
         | 
| 16 | 
            +
            """Tokenization classes for Dream."""
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            import json
         | 
| 19 | 
            +
            import os
         | 
| 20 | 
            +
            import unicodedata
         | 
| 21 | 
            +
            from functools import lru_cache
         | 
| 22 | 
            +
            from typing import Optional, Tuple
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            import regex as re
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
         | 
| 27 | 
            +
            from transformers.utils import logging
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            logger = logging.get_logger(__name__)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            VOCAB_FILES_NAMES = {
         | 
| 33 | 
            +
                "vocab_file": "vocab.json",
         | 
| 34 | 
            +
                "merges_file": "merges.txt",
         | 
| 35 | 
            +
            }
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            MAX_MODEL_INPUT_SIZES = {"dream/dream-tokenizer": 32768}
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            @lru_cache()
         | 
| 44 | 
            +
            # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
         | 
| 45 | 
            +
            def bytes_to_unicode():
         | 
| 46 | 
            +
                """
         | 
| 47 | 
            +
                Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
         | 
| 48 | 
            +
                characters the bpe code barfs on.
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
         | 
| 51 | 
            +
                if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
         | 
| 52 | 
            +
                decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
         | 
| 53 | 
            +
                tables between utf-8 bytes and unicode strings.
         | 
| 54 | 
            +
                """
         | 
| 55 | 
            +
                bs = (
         | 
| 56 | 
            +
                    list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
         | 
| 57 | 
            +
                )
         | 
| 58 | 
            +
                cs = bs[:]
         | 
| 59 | 
            +
                n = 0
         | 
| 60 | 
            +
                for b in range(2**8):
         | 
| 61 | 
            +
                    if b not in bs:
         | 
| 62 | 
            +
                        bs.append(b)
         | 
| 63 | 
            +
                        cs.append(2**8 + n)
         | 
| 64 | 
            +
                        n += 1
         | 
| 65 | 
            +
                cs = [chr(n) for n in cs]
         | 
| 66 | 
            +
                return dict(zip(bs, cs))
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
         | 
| 70 | 
            +
            def get_pairs(word):
         | 
| 71 | 
            +
                """
         | 
| 72 | 
            +
                Return set of symbol pairs in a word.
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                Word is represented as tuple of symbols (symbols being variable-length strings).
         | 
| 75 | 
            +
                """
         | 
| 76 | 
            +
                pairs = set()
         | 
| 77 | 
            +
                prev_char = word[0]
         | 
| 78 | 
            +
                for char in word[1:]:
         | 
| 79 | 
            +
                    pairs.add((prev_char, char))
         | 
| 80 | 
            +
                    prev_char = char
         | 
| 81 | 
            +
                return pairs
         | 
| 82 | 
            +
             | 
| 83 | 
            +
             | 
| 84 | 
            +
            class DreamTokenizer(PreTrainedTokenizer):
         | 
| 85 | 
            +
                """
         | 
| 86 | 
            +
                Construct a Dream tokenizer. Based on byte-level Byte-Pair-Encoding.
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                As with GPT2Tokenizer, this tokenizer has been trained to treat spaces as parts of the tokens, so a word will
         | 
| 89 | 
            +
                be encoded differently depending on whether it is at the beginning of the sentence (without a space) or not:
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                ```python
         | 
| 92 | 
            +
                >>> from transformers import AutoTokenizer
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                >>> tokenizer = AutoTokenizer.from_pretrained("Dream-org/Dream-v0-Base-7B", trust_remote_code=True)
         | 
| 95 | 
            +
                >>> tokenizer("Hello world")["input_ids"]
         | 
| 96 | 
            +
                [9707, 1879]
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                >>> tokenizer(" Hello world")["input_ids"]
         | 
| 99 | 
            +
                [21927, 1879]
         | 
| 100 | 
            +
                ```
         | 
| 101 | 
            +
                This is expected.
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
         | 
| 106 | 
            +
                this superclass for more information regarding those methods.
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                Args:
         | 
| 109 | 
            +
                    vocab_file (`str`):
         | 
| 110 | 
            +
                        Path to the vocabulary file.
         | 
| 111 | 
            +
                    merges_file (`str`):
         | 
| 112 | 
            +
                        Path to the merges file.
         | 
| 113 | 
            +
                    errors (`str`, *optional*, defaults to `"replace"`):
         | 
| 114 | 
            +
                        Paradigm to follow when decoding bytes to UTF-8. See
         | 
| 115 | 
            +
                        [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
         | 
| 116 | 
            +
                    unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
         | 
| 117 | 
            +
                        The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
         | 
| 118 | 
            +
                        token instead.
         | 
| 119 | 
            +
                    bos_token (`str`, *optional*):
         | 
| 120 | 
            +
                        The beginning of sequence token. Not applicable for this tokenizer.
         | 
| 121 | 
            +
                    eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
         | 
| 122 | 
            +
                        The end of sequence token.
         | 
| 123 | 
            +
                    pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
         | 
| 124 | 
            +
                        The token used for padding, for example when batching sequences of different lengths.
         | 
| 125 | 
            +
                    clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
         | 
| 126 | 
            +
                        Whether or not the model should clean up the spaces that were added when splitting the input text during the
         | 
| 127 | 
            +
                        tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
         | 
| 128 | 
            +
                    split_special_tokens (`bool`, *optional*, defaults to `False`):
         | 
| 129 | 
            +
                        Whether or not the special tokens should be split during the tokenization process. The default behavior is
         | 
| 130 | 
            +
                        to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") =
         | 
| 131 | 
            +
                        ['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
         | 
| 132 | 
            +
                        '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
         | 
| 133 | 
            +
                """
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                vocab_files_names = VOCAB_FILES_NAMES
         | 
| 136 | 
            +
                model_input_names = ["input_ids", "attention_mask"]
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                def __init__(
         | 
| 139 | 
            +
                    self,
         | 
| 140 | 
            +
                    vocab_file,
         | 
| 141 | 
            +
                    merges_file,
         | 
| 142 | 
            +
                    errors="replace",
         | 
| 143 | 
            +
                    unk_token="<|endoftext|>",
         | 
| 144 | 
            +
                    bos_token=None,
         | 
| 145 | 
            +
                    eos_token="<|endoftext|>",
         | 
| 146 | 
            +
                    pad_token="<|endoftext|>",
         | 
| 147 | 
            +
                    clean_up_tokenization_spaces=False,
         | 
| 148 | 
            +
                    split_special_tokens=False,
         | 
| 149 | 
            +
                    **kwargs,
         | 
| 150 | 
            +
                ):
         | 
| 151 | 
            +
                    # Dream vocab does not contain control tokens; added tokens need to be special
         | 
| 152 | 
            +
                    bos_token = (
         | 
| 153 | 
            +
                        AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
         | 
| 154 | 
            +
                        if isinstance(bos_token, str)
         | 
| 155 | 
            +
                        else bos_token
         | 
| 156 | 
            +
                    )
         | 
| 157 | 
            +
                    eos_token = (
         | 
| 158 | 
            +
                        AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
         | 
| 159 | 
            +
                        if isinstance(eos_token, str)
         | 
| 160 | 
            +
                        else eos_token
         | 
| 161 | 
            +
                    )
         | 
| 162 | 
            +
                    unk_token = (
         | 
| 163 | 
            +
                        AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
         | 
| 164 | 
            +
                        if isinstance(unk_token, str)
         | 
| 165 | 
            +
                        else unk_token
         | 
| 166 | 
            +
                    )
         | 
| 167 | 
            +
                    pad_token = (
         | 
| 168 | 
            +
                        AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
         | 
| 169 | 
            +
                        if isinstance(pad_token, str)
         | 
| 170 | 
            +
                        else pad_token
         | 
| 171 | 
            +
                    )
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                    with open(vocab_file, encoding="utf-8") as vocab_handle:
         | 
| 174 | 
            +
                        self.encoder = json.load(vocab_handle)
         | 
| 175 | 
            +
                    self.decoder = {v: k for k, v in self.encoder.items()}
         | 
| 176 | 
            +
                    self.errors = errors  # how to handle errors in decoding
         | 
| 177 | 
            +
                    self.byte_encoder = bytes_to_unicode()
         | 
| 178 | 
            +
                    self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
         | 
| 179 | 
            +
                    bpe_merges = []
         | 
| 180 | 
            +
                    with open(merges_file, encoding="utf-8") as merges_handle:
         | 
| 181 | 
            +
                        for i, line in enumerate(merges_handle):
         | 
| 182 | 
            +
                            line = line.strip()
         | 
| 183 | 
            +
                            if (i == 0 and line.startswith("#version:")) or not line:
         | 
| 184 | 
            +
                                continue
         | 
| 185 | 
            +
                            bpe_merges.append(tuple(line.split()))
         | 
| 186 | 
            +
                    self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
         | 
| 187 | 
            +
                    # NOTE: the cache can grow without bound and will get really large for long running processes
         | 
| 188 | 
            +
                    # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
         | 
| 189 | 
            +
                    # not a memory leak but appears as one.
         | 
| 190 | 
            +
                    # GPT2Tokenizer has the same problem, so let's be consistent.
         | 
| 191 | 
            +
                    self.cache = {}
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    self.pat = re.compile(PRETOKENIZE_REGEX)
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    if kwargs.get("add_prefix_space", False):
         | 
| 196 | 
            +
                        logger.warning_once(
         | 
| 197 | 
            +
                            f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
         | 
| 198 | 
            +
                        )
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    super().__init__(
         | 
| 201 | 
            +
                        errors=errors,
         | 
| 202 | 
            +
                        bos_token=bos_token,
         | 
| 203 | 
            +
                        eos_token=eos_token,
         | 
| 204 | 
            +
                        pad_token=pad_token,
         | 
| 205 | 
            +
                        unk_token=unk_token,
         | 
| 206 | 
            +
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
         | 
| 207 | 
            +
                        split_special_tokens=split_special_tokens,
         | 
| 208 | 
            +
                        **kwargs,
         | 
| 209 | 
            +
                    )
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                @property
         | 
| 212 | 
            +
                def vocab_size(self) -> int:
         | 
| 213 | 
            +
                    return len(self.encoder)
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
         | 
| 216 | 
            +
                def get_vocab(self):
         | 
| 217 | 
            +
                    return dict(self.encoder, **self.added_tokens_encoder)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
         | 
| 220 | 
            +
                def bpe(self, token):
         | 
| 221 | 
            +
                    if token in self.cache:
         | 
| 222 | 
            +
                        return self.cache[token]
         | 
| 223 | 
            +
                    word = tuple(token)
         | 
| 224 | 
            +
                    pairs = get_pairs(word)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                    if not pairs:
         | 
| 227 | 
            +
                        return token
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    while True:
         | 
| 230 | 
            +
                        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
         | 
| 231 | 
            +
                        if bigram not in self.bpe_ranks:
         | 
| 232 | 
            +
                            break
         | 
| 233 | 
            +
                        first, second = bigram
         | 
| 234 | 
            +
                        new_word = []
         | 
| 235 | 
            +
                        i = 0
         | 
| 236 | 
            +
                        while i < len(word):
         | 
| 237 | 
            +
                            try:
         | 
| 238 | 
            +
                                j = word.index(first, i)
         | 
| 239 | 
            +
                            except ValueError:
         | 
| 240 | 
            +
                                new_word.extend(word[i:])
         | 
| 241 | 
            +
                                break
         | 
| 242 | 
            +
                            else:
         | 
| 243 | 
            +
                                new_word.extend(word[i:j])
         | 
| 244 | 
            +
                                i = j
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
         | 
| 247 | 
            +
                                new_word.append(first + second)
         | 
| 248 | 
            +
                                i += 2
         | 
| 249 | 
            +
                            else:
         | 
| 250 | 
            +
                                new_word.append(word[i])
         | 
| 251 | 
            +
                                i += 1
         | 
| 252 | 
            +
                        new_word = tuple(new_word)
         | 
| 253 | 
            +
                        word = new_word
         | 
| 254 | 
            +
                        if len(word) == 1:
         | 
| 255 | 
            +
                            break
         | 
| 256 | 
            +
                        else:
         | 
| 257 | 
            +
                            pairs = get_pairs(word)
         | 
| 258 | 
            +
                    word = " ".join(word)
         | 
| 259 | 
            +
                    self.cache[token] = word
         | 
| 260 | 
            +
                    return word
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
         | 
| 263 | 
            +
                def _tokenize(self, text):
         | 
| 264 | 
            +
                    """Tokenize a string."""
         | 
| 265 | 
            +
                    bpe_tokens = []
         | 
| 266 | 
            +
                    for token in re.findall(self.pat, text):
         | 
| 267 | 
            +
                        token = "".join(
         | 
| 268 | 
            +
                            self.byte_encoder[b] for b in token.encode("utf-8")
         | 
| 269 | 
            +
                        )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
         | 
| 270 | 
            +
                        bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
         | 
| 271 | 
            +
                    return bpe_tokens
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
         | 
| 274 | 
            +
                def _convert_token_to_id(self, token):
         | 
| 275 | 
            +
                    """Converts a token (str) in an id using the vocab."""
         | 
| 276 | 
            +
                    return self.encoder.get(token, self.encoder.get(self.unk_token))
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
         | 
| 279 | 
            +
                def _convert_id_to_token(self, index):
         | 
| 280 | 
            +
                    """Converts an index (integer) in a token (str) using the vocab."""
         | 
| 281 | 
            +
                    return self.decoder.get(index)
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
         | 
| 284 | 
            +
                def convert_tokens_to_string(self, tokens):
         | 
| 285 | 
            +
                    """Converts a sequence of tokens (string) in a single string."""
         | 
| 286 | 
            +
                    text = "".join(tokens)
         | 
| 287 | 
            +
                    text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
         | 
| 288 | 
            +
                    return text
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                def decode(
         | 
| 291 | 
            +
                    self,
         | 
| 292 | 
            +
                    token_ids,
         | 
| 293 | 
            +
                    skip_special_tokens: bool = False,
         | 
| 294 | 
            +
                    clean_up_tokenization_spaces: Optional[bool] = False,
         | 
| 295 | 
            +
                    spaces_between_special_tokens: bool = False,
         | 
| 296 | 
            +
                    **kwargs,
         | 
| 297 | 
            +
                ) -> str:
         | 
| 298 | 
            +
                    # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
         | 
| 299 | 
            +
                    # and cannot be configured elsewhere, but it should default to False for DreamTokenizer
         | 
| 300 | 
            +
                    return super().decode(
         | 
| 301 | 
            +
                        token_ids,
         | 
| 302 | 
            +
                        skip_special_tokens=skip_special_tokens,
         | 
| 303 | 
            +
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
         | 
| 304 | 
            +
                        spaces_between_special_tokens=spaces_between_special_tokens,
         | 
| 305 | 
            +
                        **kwargs,
         | 
| 306 | 
            +
                    )
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
         | 
| 309 | 
            +
                def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         | 
| 310 | 
            +
                    if not os.path.isdir(save_directory):
         | 
| 311 | 
            +
                        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
         | 
| 312 | 
            +
                        return
         | 
| 313 | 
            +
                    vocab_file = os.path.join(
         | 
| 314 | 
            +
                        save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         | 
| 315 | 
            +
                    )
         | 
| 316 | 
            +
                    merge_file = os.path.join(
         | 
| 317 | 
            +
                        save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
         | 
| 318 | 
            +
                    )
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                    with open(vocab_file, "w", encoding="utf-8") as f:
         | 
| 321 | 
            +
                        f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                    index = 0
         | 
| 324 | 
            +
                    with open(merge_file, "w", encoding="utf-8") as writer:
         | 
| 325 | 
            +
                        writer.write("#version: 0.2\n")
         | 
| 326 | 
            +
                        for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
         | 
| 327 | 
            +
                            if index != token_index:
         | 
| 328 | 
            +
                                logger.warning(
         | 
| 329 | 
            +
                                    f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
         | 
| 330 | 
            +
                                    " Please check that the tokenizer is not corrupted!"
         | 
| 331 | 
            +
                                )
         | 
| 332 | 
            +
                                index = token_index
         | 
| 333 | 
            +
                            writer.write(" ".join(bpe_tokens) + "\n")
         | 
| 334 | 
            +
                            index += 1
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                    return vocab_file, merge_file
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                def prepare_for_tokenization(self, text, **kwargs):
         | 
| 339 | 
            +
                    text = unicodedata.normalize("NFC", text)
         | 
| 340 | 
            +
                    return (text, kwargs)
         | 
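
The tokenization_dream.py file above defines `DreamTokenizer`, a GPT-2 style byte-level BPE slow tokenizer: `prepare_for_tokenization` applies NFC normalization, and the `decode` override keeps `spaces_between_special_tokens=False` so special tokens stay flush with surrounding text. A minimal usage sketch, assuming the files in this upload are loaded with `trust_remote_code=True` (the example text is illustrative):

```
from transformers import AutoTokenizer

# Load the slow DreamTokenizer registered via the auto_map entry in tokenizer_config.json.
tokenizer = AutoTokenizer.from_pretrained(
    "apple/DiffuCoder-7B-cpGRPO",
    trust_remote_code=True,
    use_fast=False,  # the auto_map registers only the slow tokenizer class
)

text = "def add(a, b):\n    return a + b"
ids = tokenizer(text).input_ids       # byte-level BPE after NFC normalization
print(tokenizer.decode(ids) == text)  # should round-trip; no extra spaces are inserted
```
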
    	
        tokenizer_config.json
    ADDED
    
    | @@ -0,0 +1,228 @@ | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "add_bos_token": false,
         | 
| 3 | 
            +
              "add_prefix_space": false,
         | 
| 4 | 
            +
              "added_tokens_decoder": {
         | 
| 5 | 
            +
                "151643": {
         | 
| 6 | 
            +
                  "content": "<|endoftext|>",
         | 
| 7 | 
            +
                  "lstrip": false,
         | 
| 8 | 
            +
                  "normalized": false,
         | 
| 9 | 
            +
                  "rstrip": false,
         | 
| 10 | 
            +
                  "single_word": false,
         | 
| 11 | 
            +
                  "special": true
         | 
| 12 | 
            +
                },
         | 
| 13 | 
            +
                "151644": {
         | 
| 14 | 
            +
                  "content": "<|im_start|>",
         | 
| 15 | 
            +
                  "lstrip": false,
         | 
| 16 | 
            +
                  "normalized": false,
         | 
| 17 | 
            +
                  "rstrip": false,
         | 
| 18 | 
            +
                  "single_word": false,
         | 
| 19 | 
            +
                  "special": true
         | 
| 20 | 
            +
                },
         | 
| 21 | 
            +
                "151645": {
         | 
| 22 | 
            +
                  "content": "<|im_end|>",
         | 
| 23 | 
            +
                  "lstrip": false,
         | 
| 24 | 
            +
                  "normalized": false,
         | 
| 25 | 
            +
                  "rstrip": false,
         | 
| 26 | 
            +
                  "single_word": false,
         | 
| 27 | 
            +
                  "special": true
         | 
| 28 | 
            +
                },
         | 
| 29 | 
            +
                "151646": {
         | 
| 30 | 
            +
                  "content": "<|object_ref_start|>",
         | 
| 31 | 
            +
                  "lstrip": false,
         | 
| 32 | 
            +
                  "normalized": false,
         | 
| 33 | 
            +
                  "rstrip": false,
         | 
| 34 | 
            +
                  "single_word": false,
         | 
| 35 | 
            +
                  "special": true
         | 
| 36 | 
            +
                },
         | 
| 37 | 
            +
                "151647": {
         | 
| 38 | 
            +
                  "content": "<|object_ref_end|>",
         | 
| 39 | 
            +
                  "lstrip": false,
         | 
| 40 | 
            +
                  "normalized": false,
         | 
| 41 | 
            +
                  "rstrip": false,
         | 
| 42 | 
            +
                  "single_word": false,
         | 
| 43 | 
            +
                  "special": true
         | 
| 44 | 
            +
                },
         | 
| 45 | 
            +
                "151648": {
         | 
| 46 | 
            +
                  "content": "<|box_start|>",
         | 
| 47 | 
            +
                  "lstrip": false,
         | 
| 48 | 
            +
                  "normalized": false,
         | 
| 49 | 
            +
                  "rstrip": false,
         | 
| 50 | 
            +
                  "single_word": false,
         | 
| 51 | 
            +
                  "special": true
         | 
| 52 | 
            +
                },
         | 
| 53 | 
            +
                "151649": {
         | 
| 54 | 
            +
                  "content": "<|box_end|>",
         | 
| 55 | 
            +
                  "lstrip": false,
         | 
| 56 | 
            +
                  "normalized": false,
         | 
| 57 | 
            +
                  "rstrip": false,
         | 
| 58 | 
            +
                  "single_word": false,
         | 
| 59 | 
            +
                  "special": true
         | 
| 60 | 
            +
                },
         | 
| 61 | 
            +
                "151650": {
         | 
| 62 | 
            +
                  "content": "<|quad_start|>",
         | 
| 63 | 
            +
                  "lstrip": false,
         | 
| 64 | 
            +
                  "normalized": false,
         | 
| 65 | 
            +
                  "rstrip": false,
         | 
| 66 | 
            +
                  "single_word": false,
         | 
| 67 | 
            +
                  "special": true
         | 
| 68 | 
            +
                },
         | 
| 69 | 
            +
                "151651": {
         | 
| 70 | 
            +
                  "content": "<|quad_end|>",
         | 
| 71 | 
            +
                  "lstrip": false,
         | 
| 72 | 
            +
                  "normalized": false,
         | 
| 73 | 
            +
                  "rstrip": false,
         | 
| 74 | 
            +
                  "single_word": false,
         | 
| 75 | 
            +
                  "special": true
         | 
| 76 | 
            +
                },
         | 
| 77 | 
            +
                "151652": {
         | 
| 78 | 
            +
                  "content": "<|vision_start|>",
         | 
| 79 | 
            +
                  "lstrip": false,
         | 
| 80 | 
            +
                  "normalized": false,
         | 
| 81 | 
            +
                  "rstrip": false,
         | 
| 82 | 
            +
                  "single_word": false,
         | 
| 83 | 
            +
                  "special": true
         | 
| 84 | 
            +
                },
         | 
| 85 | 
            +
                "151653": {
         | 
| 86 | 
            +
                  "content": "<|vision_end|>",
         | 
| 87 | 
            +
                  "lstrip": false,
         | 
| 88 | 
            +
                  "normalized": false,
         | 
| 89 | 
            +
                  "rstrip": false,
         | 
| 90 | 
            +
                  "single_word": false,
         | 
| 91 | 
            +
                  "special": true
         | 
| 92 | 
            +
                },
         | 
| 93 | 
            +
                "151654": {
         | 
| 94 | 
            +
                  "content": "<|vision_pad|>",
         | 
| 95 | 
            +
                  "lstrip": false,
         | 
| 96 | 
            +
                  "normalized": false,
         | 
| 97 | 
            +
                  "rstrip": false,
         | 
| 98 | 
            +
                  "single_word": false,
         | 
| 99 | 
            +
                  "special": true
         | 
| 100 | 
            +
                },
         | 
| 101 | 
            +
                "151655": {
         | 
| 102 | 
            +
                  "content": "<|image_pad|>",
         | 
| 103 | 
            +
                  "lstrip": false,
         | 
| 104 | 
            +
                  "normalized": false,
         | 
| 105 | 
            +
                  "rstrip": false,
         | 
| 106 | 
            +
                  "single_word": false,
         | 
| 107 | 
            +
                  "special": true
         | 
| 108 | 
            +
                },
         | 
| 109 | 
            +
                "151656": {
         | 
| 110 | 
            +
                  "content": "<|video_pad|>",
         | 
| 111 | 
            +
                  "lstrip": false,
         | 
| 112 | 
            +
                  "normalized": false,
         | 
| 113 | 
            +
                  "rstrip": false,
         | 
| 114 | 
            +
                  "single_word": false,
         | 
| 115 | 
            +
                  "special": true
         | 
| 116 | 
            +
                },
         | 
| 117 | 
            +
                "151657": {
         | 
| 118 | 
            +
                  "content": "<tool_call>",
         | 
| 119 | 
            +
                  "lstrip": false,
         | 
| 120 | 
            +
                  "normalized": false,
         | 
| 121 | 
            +
                  "rstrip": false,
         | 
| 122 | 
            +
                  "single_word": false,
         | 
| 123 | 
            +
                  "special": false
         | 
| 124 | 
            +
                },
         | 
| 125 | 
            +
                "151658": {
         | 
| 126 | 
            +
                  "content": "</tool_call>",
         | 
| 127 | 
            +
                  "lstrip": false,
         | 
| 128 | 
            +
                  "normalized": false,
         | 
| 129 | 
            +
                  "rstrip": false,
         | 
| 130 | 
            +
                  "single_word": false,
         | 
| 131 | 
            +
                  "special": false
         | 
| 132 | 
            +
                },
         | 
| 133 | 
            +
                "151659": {
         | 
| 134 | 
            +
                  "content": "<|fim_prefix|>",
         | 
| 135 | 
            +
                  "lstrip": false,
         | 
| 136 | 
            +
                  "normalized": false,
         | 
| 137 | 
            +
                  "rstrip": false,
         | 
| 138 | 
            +
                  "single_word": false,
         | 
| 139 | 
            +
                  "special": false
         | 
| 140 | 
            +
                },
         | 
| 141 | 
            +
                "151660": {
         | 
| 142 | 
            +
                  "content": "<|fim_middle|>",
         | 
| 143 | 
            +
                  "lstrip": false,
         | 
| 144 | 
            +
                  "normalized": false,
         | 
| 145 | 
            +
                  "rstrip": false,
         | 
| 146 | 
            +
                  "single_word": false,
         | 
| 147 | 
            +
                  "special": false
         | 
| 148 | 
            +
                },
         | 
| 149 | 
            +
                "151661": {
         | 
| 150 | 
            +
                  "content": "<|fim_suffix|>",
         | 
| 151 | 
            +
                  "lstrip": false,
         | 
| 152 | 
            +
                  "normalized": false,
         | 
| 153 | 
            +
                  "rstrip": false,
         | 
| 154 | 
            +
                  "single_word": false,
         | 
| 155 | 
            +
                  "special": false
         | 
| 156 | 
            +
                },
         | 
| 157 | 
            +
                "151662": {
         | 
| 158 | 
            +
                  "content": "<|fim_pad|>",
         | 
| 159 | 
            +
                  "lstrip": false,
         | 
| 160 | 
            +
                  "normalized": false,
         | 
| 161 | 
            +
                  "rstrip": false,
         | 
| 162 | 
            +
                  "single_word": false,
         | 
| 163 | 
            +
                  "special": false
         | 
| 164 | 
            +
                },
         | 
| 165 | 
            +
                "151663": {
         | 
| 166 | 
            +
                  "content": "<|repo_name|>",
         | 
| 167 | 
            +
                  "lstrip": false,
         | 
| 168 | 
            +
                  "normalized": false,
         | 
| 169 | 
            +
                  "rstrip": false,
         | 
| 170 | 
            +
                  "single_word": false,
         | 
| 171 | 
            +
                  "special": false
         | 
| 172 | 
            +
                },
         | 
| 173 | 
            +
                "151664": {
         | 
| 174 | 
            +
                  "content": "<|file_sep|>",
         | 
| 175 | 
            +
                  "lstrip": false,
         | 
| 176 | 
            +
                  "normalized": false,
         | 
| 177 | 
            +
                  "rstrip": false,
         | 
| 178 | 
            +
                  "single_word": false,
         | 
| 179 | 
            +
                  "special": false
         | 
| 180 | 
            +
                },
         | 
| 181 | 
            +
                "151665": {
         | 
| 182 | 
            +
                  "content": "<|beginoftext|>",
         | 
| 183 | 
            +
                  "lstrip": false,
         | 
| 184 | 
            +
                  "normalized": false,
         | 
| 185 | 
            +
                  "rstrip": false,
         | 
| 186 | 
            +
                  "single_word": false,
         | 
| 187 | 
            +
                  "special": true
         | 
| 188 | 
            +
                },
         | 
| 189 | 
            +
                "151666": {
         | 
| 190 | 
            +
                  "content": "<|mask|>",
         | 
| 191 | 
            +
                  "lstrip": false,
         | 
| 192 | 
            +
                  "normalized": false,
         | 
| 193 | 
            +
                  "rstrip": false,
         | 
| 194 | 
            +
                  "single_word": false,
         | 
| 195 | 
            +
                  "special": true
         | 
| 196 | 
            +
                },
         | 
| 197 | 
            +
                "151667": {
         | 
| 198 | 
            +
                  "content": "<|dlm_pad|>",
         | 
| 199 | 
            +
                  "lstrip": false,
         | 
| 200 | 
            +
                  "normalized": false,
         | 
| 201 | 
            +
                  "rstrip": false,
         | 
| 202 | 
            +
                  "single_word": false,
         | 
| 203 | 
            +
                  "special": true
         | 
| 204 | 
            +
                }
         | 
| 205 | 
            +
              },
         | 
| 206 | 
            +
              "additional_special_tokens": [
         | 
| 207 | 
            +
                "<|beginoftext|>",
         | 
| 208 | 
            +
                "<|mask|>",
         | 
| 209 | 
            +
                "<|im_end|>"
         | 
| 210 | 
            +
              ],
         | 
| 211 | 
            +
              "auto_map": {
         | 
| 212 | 
            +
                "AutoTokenizer": [
         | 
| 213 | 
            +
                  "tokenization_dream.DreamTokenizer",
         | 
| 214 | 
            +
                  null
         | 
| 215 | 
            +
                ]
         | 
| 216 | 
            +
              },
         | 
| 217 | 
            +
              "bos_token": "<|beginoftext|>",
         | 
| 218 | 
            +
              "clean_up_tokenization_spaces": false,
         | 
| 219 | 
            +
              "eos_token": "<|endoftext|>",
         | 
| 220 | 
            +
              "errors": "replace",
         | 
| 221 | 
            +
              "extra_special_tokens": {},
         | 
| 222 | 
            +
              "mask_token": "<|mask|>",
         | 
| 223 | 
            +
              "model_max_length": 131072,
         | 
| 224 | 
            +
              "pad_token": "<|dlm_pad|>",
         | 
| 225 | 
            +
              "split_special_tokens": false,
         | 
| 226 | 
            +
              "tokenizer_class": "DreamTokenizer",
         | 
| 227 | 
            +
              "unk_token": null
         | 
| 228 | 
            +
            }
         | 
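
Beyond the Qwen2-style chat, tool, and FIM tokens, the config above registers the diffusion-specific special tokens: `<|beginoftext|>` (151665) as `bos_token`, `<|mask|>` (151666) as `mask_token`, and `<|dlm_pad|>` (151667) as `pad_token`, with `<|endoftext|>` kept as `eos_token`. A hedged sketch of how these ids could be used to set up a masked completion canvas for diffusion-style decoding; the actual sampling loop lives in generation_utils.py, and the prompt and `gen_len` below are illustrative:

```
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "apple/DiffuCoder-7B-cpGRPO", trust_remote_code=True
)

prompt_ids = tokenizer("# Write a binary search in Python\n").input_ids
gen_len = 64  # number of masked positions to denoise (illustrative)

# Prompt followed by a fully masked region that the diffusion sampler would fill in.
canvas = prompt_ids + [tokenizer.mask_token_id] * gen_len   # <|mask|> = 151666
input_ids = torch.tensor([canvas])

print(tokenizer.mask_token_id, tokenizer.pad_token_id)      # 151666 151667 per this config
```
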
    	
        train_results.json
    ADDED
    
    | @@ -0,0 +1,8 @@ | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
                "total_flos": 0.0,
         | 
| 3 | 
            +
                "train_loss": 1756184.5472393532,
         | 
| 4 | 
            +
                "train_runtime": 149594.4727,
         | 
| 5 | 
            +
                "train_samples": 20822,
         | 
| 6 | 
            +
                "train_samples_per_second": 0.139,
         | 
| 7 | 
            +
                "train_steps_per_second": 0.035
         | 
| 8 | 
            +
            }
         | 
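
The throughput figures above are internally consistent: 20822 samples over 149594.47 s is roughly 0.139 samples per second, and at 0.035 steps per second that implies about 4 samples per optimizer step (the per-device batch size, accumulation, and world size are not recorded here). A quick arithmetic check:

```
# Sanity-check the reported throughput in train_results.json.
train_samples = 20822
train_runtime = 149594.4727                     # seconds
samples_per_second = train_samples / train_runtime
print(round(samples_per_second, 3))             # ~0.139, matches train_samples_per_second

steps_per_second = 0.035
print(round(samples_per_second / steps_per_second, 1))  # ~4 samples per step (approximate effective batch)
```
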
    	
        trainer_state.json
    ADDED
    
    | The diff for this file is too large to render. See raw diff | 
|  | 
    	
        training_args.bin
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:3c79dc2e0f6cfbb5b3469591be012b4e3711cd1f4159a4e2b033180fa8b0fbf2
         | 
| 3 | 
            +
            size 8376
         | 
    	
        vocab.json
    ADDED
    
    | The diff for this file is too large to render. See raw diff | 
|  |