Chengyue Wu committed

Commit 19930b4 · 1 Parent(s): 5040112

update with LFS support

.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,148 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ language:
+ - en
+ base_model:
+ - Qwen/Qwen2.5-7B-Instruct
+ ---
+
+ # Fast-dLLM v2 (7B) — Efficient Block-Diffusion LLM
+
+ ## 📖 Introduction
+
+ Autoregressive (AR) large language models (LLMs) have achieved remarkable performance across a wide range of natural language tasks, yet their **inherently sequential decoding limits inference efficiency**.
+
+ We present **Fast-dLLM v2** — a carefully designed **block diffusion language model (dLLM)** that efficiently adapts a pretrained AR model (**Qwen2.5-7B-Instruct**) into a diffusion-style decoder for **parallel text generation**.
+
+ ### ✨ Key Innovations
+ - **Block Diffusion Mechanism + Complementary Attention Mask**
+   Enables **blockwise bidirectional context modeling** without sacrificing AR objectives.
+ - **Hierarchical Caching**
+   - **Block-level cache**: stores historical context representations across completed blocks.
+   - **Sub-block cache**: enables parallel decoding within partially generated blocks.
+ - **Token Shift Mechanism**
+   Retains autoregressive characteristics while supporting bidirectional context within blocks.
+ - **Parallel Decoding Pipeline**
+   Achieves up to **2.5× speedup** over standard AR decoding **without compromising quality** (see the sketch below).
+
+ > 🚀 Fast-dLLM v2 uses **only ~1B tokens** for fine-tuning — a **500× reduction** vs. full-attention diffusion LLMs (Dream: 580B tokens) — while **matching or surpassing AR baselines** in accuracy.
+
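+ The loop below is a minimal, self-contained sketch of the thresholded parallel-decoding idea, not the shipped implementation (see `modeling.py` in this repository for that). Within one block, all masked positions are predicted in a single forward pass, and positions whose confidence clears a threshold are committed together, so several tokens can be accepted per step. The `forward_fn`, tensor shapes, and helper name here are illustrative assumptions.
+
+ ```python
+ import torch
+
+ def decode_block_sketch(forward_fn, block, mask_id, threshold=0.9):
+     """Illustrative confidence-thresholded parallel unmasking of one block.
+
+     forward_fn: callable returning per-position logits for the block, shape [block_len, vocab_size].
+     block:      1-D LongTensor for the current block, with undecided positions set to mask_id.
+     """
+     while (block == mask_id).any():
+         logits = forward_fn(block)              # one forward pass over the whole block
+         probs = torch.softmax(logits, dim=-1)
+         conf, pred = probs.max(dim=-1)          # per-position confidence and argmax token
+         masked = block == mask_id
+         accept = masked & (conf >= threshold)   # commit every confident masked position in parallel
+         if not accept.any():                    # guarantee progress: commit the single most confident one
+             best = torch.where(masked, conf, torch.full_like(conf, -1.0)).argmax()
+             accept[best] = True
+         block[accept] = pred[accept]
+     return block
+ ```
+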
+ ![Generation Process](assets/visualization_animation.gif)
+
+ ---
+
+ ## 🛠 Model Overview
+ - **Type**: Block Diffusion Language Model (dLLM)
+ - **Base Model**: `Qwen/Qwen2.5-7B-Instruct`
+ - **Architecture**: Transformer w/ RoPE, SwiGLU activation, RMSNorm, Attention QKV bias
+ - **Params**: ~7B
+ - **Layers**: 28
+ - **Attention Heads**: 28 (Q), 4 (KV, GQA)
+ - **Block Diffusion Size**: 32 tokens
+ - **Key Feature**: Parallel **block-wise decoding** + **hierarchical caching (block-level & sub-block)**
+
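+ A quick way to sanity-check these numbers is to read the standard Hugging Face config fields (a sketch; it assumes the custom config exposes the usual Qwen2-style attribute names, which may differ):
+
+ ```python
+ from transformers import AutoConfig
+
+ cfg = AutoConfig.from_pretrained("Efficient-Large-Model/Fast_dLLM_7B", trust_remote_code=True)
+ # Usual Qwen2-style fields; the attribute names are assumptions if the custom config renames them.
+ print("layers:", cfg.num_hidden_layers)          # expected 28
+ print("query heads:", cfg.num_attention_heads)   # expected 28
+ print("kv heads:", cfg.num_key_value_heads)      # expected 4 (GQA)
+ ```
+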
+ ---
+
+ ## 📦 Installation
+ You will need `transformers` and `torch`; the **custom generation function** ships with the model repository and is loaded via `trust_remote_code=True` (see the Quickstart below):
+
+ ```bash
+ pip install transformers torch numpy
+ ```
+
+ ---
+
+ ## 🚀 Quickstart
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model_name = "Efficient-Large-Model/Fast_dLLM_7B"
+
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype="auto",
+     device_map="auto",
+     trust_remote_code=True
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+ prompt = "Give me a short introduction to large language models."
+ messages = [
+     {"role": "system", "content": "You are a helpful assistant."},
+     {"role": "user", "content": prompt}
+ ]
+
+ text = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True
+ )
+ inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+ # Fast-dLLM v2 parallel decoding
+ gen_ids = model.generate(
+     inputs["input_ids"],
+     tokenizer=tokenizer,
+     max_new_tokens=512,
+     small_block_size=8,
+     threshold=0.9,
+ )
+
+ response = tokenizer.decode(
+     gen_ids[0][inputs["input_ids"].shape[1]:],
+     skip_special_tokens=True
+ )
+ print(response)
+ ```
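+
+ `small_block_size` and `threshold` are the main speed/quality knobs of the custom `generate` above: `threshold` controls the confidence needed to commit tokens in parallel (per the thresholded decoding idea sketched earlier), so lowering it accepts more tokens per step at some risk to quality. A small sweep, reusing `model`, `tokenizer`, and `inputs` from the snippet above (the specific values are only examples), makes the trade-off easy to see:
+
+ ```python
+ # Reuses `model`, `tokenizer`, and `inputs` from the Quickstart; the values below are illustrative.
+ for threshold in (0.8, 0.9, 0.95):
+     out = model.generate(
+         inputs["input_ids"],
+         tokenizer=tokenizer,
+         max_new_tokens=128,
+         small_block_size=8,
+         threshold=threshold,
+     )
+     answer = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+     print(f"threshold={threshold}: {answer[:80]!r}")
+ ```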
+
+ ---
+
+ ## 📊 Performance & Benchmarks
+
+ ### ▶ Real-time Throughput
+ Fast-dLLM v2 offers **up to 2.54× higher throughput** than Qwen2.5-7B-Instruct, **without loss in quality**.
+
+ ![Throughput Comparison](assets/throughput.png)
+
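+ To get a rough tokens-per-second number on your own hardware, a simple timing harness like the one below can be used (it reuses `model`, `tokenizer`, and `inputs` from the Quickstart; the CUDA synchronization only applies on GPU, and absolute numbers will vary by device):
+
+ ```python
+ import time
+ import torch
+
+ # Rough throughput check, reusing `model`, `tokenizer`, and `inputs` from the Quickstart above.
+ if torch.cuda.is_available():
+     torch.cuda.synchronize()
+ start = time.time()
+ gen_ids = model.generate(
+     inputs["input_ids"],
+     tokenizer=tokenizer,
+     max_new_tokens=512,
+     small_block_size=8,
+     threshold=0.9,
+ )
+ if torch.cuda.is_available():
+     torch.cuda.synchronize()
+ elapsed = time.time() - start
+ new_tokens = gen_ids.shape[1] - inputs["input_ids"].shape[1]
+ print(f"{new_tokens} new tokens in {elapsed:.2f}s -> {new_tokens / elapsed:.1f} tok/s")
+ ```
+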
+ ---
+
+ ### 🏆 Benchmark Results
+ We compare Fast-dLLM v2 against AR baselines and previous diffusion LLMs on diverse tasks:
+ HumanEval, MBPP (code), GSM8K, MATH (reasoning), IFEval (instruction following), MMLU, GPQA (knowledge QA).
+
+ - **1B group**: Fast-dLLM v2 achieves the **best average score: 45.0**.
+ - **7B group**: Fast-dLLM v2 (7B) achieves the **best average score: 60.3**, surpassing LLaDA and Dream models.
+
+ ![Benchmark Results](assets/benchmark_results.png)
+
+ ---
+
+ ## 📜 Citation
+
+ If you use Fast-dLLM v2 in your research or products, please cite:
+
+ ```bibtex
+ @misc{wu2025fastdllmv2efficientblockdiffusion,
+       title={Fast-dLLM v2: Efficient Block-Diffusion LLM},
+       author={Chengyue Wu and Hao Zhang and Shuchen Xue and Shizhe Diao and Yonggan Fu and Zhijian Liu and Pavlo Molchanov and Ping Luo and Song Han and Enze Xie},
+       year={2025},
+       eprint={2509.26328},
+       archivePrefix={arXiv},
+       primaryClass={cs.CL},
+       url={https://arxiv.org/abs/2509.26328},
+ }
+ ```
+
+ ---
+
+ ## 📄 License
+ Released under **Apache 2.0**, following the base Qwen2.5 license.
+
+ ---
+
+ ## 🔗 Resources
+ - 📄 [Paper](https://arxiv.org/abs/2509.26328)
+ - 💻 [Code](https://github.com/NVlabs/Fast-dLLM)
+ - 🤗 [HuggingFace Model](https://huggingface.co/Efficient-Large-Model/Fast_dLLM_7B)
assets/benchmark_results.png ADDED

Git LFS Details

  • SHA256: 9ef4dfb1d35ef1332f9dca4072c2e7727ed761f7f635f9ac891f9b81a54adee7
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
assets/throughput.png ADDED

Git LFS Details

  • SHA256: 4f208427d0eeda6fc5e65316aa50b9d5e43ecd38a38fdf3929dc6691bad02079
  • Pointer size: 131 Bytes
  • Size of remote file: 125 kB
assets/training_recipe.png ADDED

Git LFS Details

  • SHA256: b2267f5d41fa4264816e870afa1353f4811b38865c04acdc9f3e4f04f5e3eb0c
  • Pointer size: 131 Bytes
  • Size of remote file: 180 kB
assets/visualization_animation.gif ADDED

Git LFS Details

  • SHA256: 2c4c7fb54af204ea8cc03a8dadc9dde8dc8fb5ac514ead026a5d2833ee3aad37
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB
modeling.py CHANGED
@@ -555,7 +555,6 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
  top_p=0.95,
  temperature=0,
  use_block_cache=False,
- block_cache_refresh_interval=16,
  **kwargs
  ):
  num_blocks = max_new_tokens // block_size
@@ -581,7 +580,6 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
  x_init = torch.cat([input_ids, x_init], dim=1)

  x_t = x_init.clone()
- step = 0
  block_past_key_values = None
  while True:
  if stop_token in x_t[:, prompt_length:]:
@@ -612,7 +610,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
  break

  if use_block_cache:
- if step % block_cache_refresh_interval == 0 or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
+ if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
  output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
  logits, block_past_key_values = output.logits, output.block_past_key_values
  logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
@@ -638,7 +636,6 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):

  x_t[:, start:end][unmask_idx] = x_1[unmask_idx]

- step += 1
  input_ids = x_t
  # Truncate stop_token
  if stop_token in input_ids[:, original_input_length:]:
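
The change above removes the fixed `block_cache_refresh_interval` and instead recomputes the sub-block KV cache only when it does not exist yet or when the token at the current sub-block boundary is still masked. A minimal restatement of that condition, reusing the names visible in the diff (illustrative only, not the exact method):

```python
# Illustrative restatement of the refresh condition introduced in the diff above.
def should_refresh_block_cache(block_past_key_values, x_t, block_size, small_block_start_idx, mask_id):
    # Rebuild the sub-block cache when none has been computed yet, or when the token at the
    # start of the current sub-block is still the mask placeholder (its cached state is stale).
    return block_past_key_values is None or bool(
        (x_t[:, -block_size + small_block_start_idx] == mask_id).any()
    )
```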