Rocketknight1 (HF Staff) committed
Commit 4524677 · verified · 1 Parent(s): 7c7a3ca

Clean up tool use snippet

Files changed (1)
  1. README.md +5 -5
README.md CHANGED
@@ -164,18 +164,18 @@ def get_current_weather(location: str, format: str):
 conversation = [{"role": "user", "content": "What's the weather like in Paris?"}]
 tools = [get_current_weather]
 
-# render the tool use prompt as a string:
-tool_use_prompt = tokenizer.apply_chat_template(
+# format and tokenize the tool use prompt
+inputs = tokenizer.apply_chat_template(
     conversation,
     tools=tools,
-    tokenize=False,
     add_generation_prompt=True,
+    return_dict=True,
+    return_tensors="pt",
 )
 
-inputs = tokenizer(tool_use_prompt, return_tensors="pt")
-
 model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
 
+inputs.to(model.device)
 outputs = model.generate(**inputs, max_new_tokens=1000)
 print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 ```
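For reference, a self-contained sketch of the snippet after this change is below. Only lines 164-181 of the README appear in the diff, so the imports, the model_id placeholder, the tokenizer setup, and the body and docstring of get_current_weather are assumptions added for illustration; the diff itself only shows the function's signature in the hunk header.

```python
# Sketch of the updated snippet. The imports, model_id value, tokenizer setup,
# and the get_current_weather docstring/body are assumptions; only the function
# signature and the changed lines appear in this diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your/model-id"  # placeholder; the actual repo id is not shown in this diff

def get_current_weather(location: str, format: str):
    """
    Get the current weather.

    Args:
        location: The city and state, e.g. San Francisco, CA
        format: The temperature unit to use. (choices: ["celsius", "fahrenheit"])
    """
    pass

tokenizer = AutoTokenizer.from_pretrained(model_id)

conversation = [{"role": "user", "content": "What's the weather like in Paris?"}]
tools = [get_current_weather]

# format and tokenize the tool use prompt in one call; return_dict=True returns
# a BatchEncoding (input_ids and attention_mask) instead of a plain string
inputs = tokenizer.apply_chat_template(
    conversation,
    tools=tools,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# move the tokenized inputs onto the model's device before generation
inputs.to(model.device)
outputs = model.generate(**inputs, max_new_tokens=1000)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

The change collapses the old render-then-tokenize flow (apply_chat_template with tokenize=False followed by a separate tokenizer call) into a single apply_chat_template call, and adds an explicit device move for the inputs before generate.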