DeathReaper0965 committed
Commit 42b244c · 1 Parent(s): 125c431

Update the deprecated Flash Attention parameter in the from_pretrained() method

Files changed (1)
  1. README.md +1 -1
README.md CHANGED
@@ -118,7 +118,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-+ model = AutoModelForCausalLM.from_pretrained(model_id, use_flash_attention_2=True)
++ model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
 
 text = "Hello my name is"
 + inputs = tokenizer(text, return_tensors="pt").to(0)
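For context, `use_flash_attention_2=True` was deprecated in transformers (around v4.36) in favor of `attn_implementation="flash_attention_2"`, which is what this commit applies to the README snippet. Below is a minimal sketch of the snippet as it reads after the change, assuming flash-attn and accelerate are installed and a CUDA GPU is available; the `torch_dtype`, `device_map`, and generation lines are additions for runnability, not part of the diff:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# attn_implementation="flash_attention_2" replaces the deprecated
# use_flash_attention_2=True flag. Flash Attention 2 only supports
# half precision, hence torch_dtype=torch.float16 (an assumption
# beyond the diff, as is device_map="auto", which needs accelerate).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt").to(0)

# Illustrative generation call, not part of the committed diff.
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```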