Commit · 42b244c
1 Parent(s): 125c431
Update the deprecated Flash Attention call parameter in the from_pretrained() method
README.md
CHANGED
@@ -118,7 +118,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-model = AutoModelForCausalLM.from_pretrained(model_id, use_flash_attention_2=True)
+model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
 
 text = "Hello my name is"
 inputs = tokenizer(text, return_tensors="pt").to(0)
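For reference, a minimal end-to-end sketch of the README snippet after this change. The attn_implementation="flash_attention_2" value comes from the diff; the torch_dtype, the generate call, and loading onto GPU 0 are illustrative assumptions, not part of the commit, and a GPU with flash-attn installed is assumed.

# Minimal sketch of the updated usage, not the verbatim README code.
# Assumes a CUDA GPU and the flash-attn package; torch_dtype and
# max_new_tokens are illustrative choices, not taken from the commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,                 # illustrative: load weights in FP16
    attn_implementation="flash_attention_2",   # replaces the deprecated use_flash_attention_2=True
).to(0)

text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt").to(0)

outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))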