pandora-s Rocketknight1 HF Staff commited on
Commit
2036c14
·
verified ·
1 Parent(s): 16bd787

Add full tool calling support to chat template (#59)

Browse files

- Add full tool calling support to chat template (b3ff8777adb9ce554824709c2f8979963f60fd34)
- Remove template spacing to match Nemo tokenizer (c8db75491df1a70d505eae010c1d0e3afafe2c10)
- Update tokenizer_config.json (20b40c380523d1ebca5abca9c03039f09485c31f)
- Add example to README (38e450d9223d3d6911eb3f5da97943b0a4ed3d43)
- "[/TOOL_RESULTS] " -> "[/TOOL_RESULTS]" (94291a6fbba8d5ab0497321e9671bf7e999dd541)


Co-authored-by: Matthew Carrigan <[email protected]>

Files changed (2) hide show
  1. README.md +48 -0
  2. tokenizer_config.json +1 -1
README.md CHANGED
@@ -208,6 +208,54 @@ chatbot = pipeline("text-generation", model="mistralai/Mistral-Nemo-Instruct-240
208
  chatbot(messages)
209
  ```
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  > [!TIP]
212
 > Unlike previous Mistral models, Mistral Nemo requires smaller temperatures. We recommend using a temperature of 0.3.
213
 
 
208
  chatbot(messages)
209
  ```
210
 
211
+ ## Function calling with `transformers`
212
+
213
+ To use this example, you'll need `transformers` version 4.42.0 or higher. Please see the
214
+ [function calling guide](https://huggingface.co/docs/transformers/main/chat_templating#advanced-tool-use--function-calling)
215
+ in the `transformers` docs for more information.
216
+
217
+ ```python
218
+ from transformers import AutoModelForCausalLM, AutoTokenizer
219
+ import torch
220
+
221
+ model_id = "mistralai/Mistral-Nemo-Instruct-2407"
222
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
223
+
224
+ def get_current_weather(location: str, format: str):
225
+ """
226
+ Get the current weather
227
+
228
+ Args:
229
+ location: The city and state, e.g. San Francisco, CA
230
+ format: The temperature unit to use. Infer this from the users location. (choices: ["celsius", "fahrenheit"])
231
+ """
232
+ pass
233
+
234
+ conversation = [{"role": "user", "content": "What's the weather like in Paris?"}]
235
+ tools = [get_current_weather]
236
+
237
+ # render the tool use prompt as a string:
238
+ tool_use_prompt = tokenizer.apply_chat_template(
239
+ conversation,
240
+ tools=tools,
241
+ tokenize=False,
242
+ add_generation_prompt=True,
243
+ )
244
+
245
+ inputs = tokenizer(tool_use_prompt, return_tensors="pt")
246
+
247
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
248
+
249
+ outputs = model.generate(**inputs, max_new_tokens=1000)
250
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
251
+ ```
252
+
253
+ Note that, for reasons of space, this example does not show a complete cycle of calling a tool and adding the tool call and tool
254
+ results to the chat history so that the model can use them in its next generation. For a full tool calling example, please
255
+ see the [function calling guide](https://huggingface.co/docs/transformers/main/chat_templating#advanced-tool-use--function-calling),
256
+ and note that Mistral **does** use tool call IDs, so these must be included in your tool calls and tool results. They should be
257
+ exactly 9 alphanumeric characters.
258
+
259
  > [!TIP]
260
 > Unlike previous Mistral models, Mistral Nemo requires smaller temperatures. We recommend using a temperature of 0.3.
261
 
tokenizer_config.json CHANGED
@@ -8005,7 +8005,7 @@
8005
  }
8006
  },
8007
  "bos_token": "<s>",
8008
- "chat_template": "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.last and system_message is defined %}\n {{- '[INST]' + system_message + '\\n\\n' + message['content'] + '[/INST]' }}\n {%- else %}\n {{- '[INST]' + message['content'] + '[/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n",
8009
  "clean_up_tokenization_spaces": false,
8010
  "eos_token": "</s>",
8011
  "model_max_length": 1000000000000000019884624838656,
 
8005
  }
8006
  },
8007
  "bos_token": "<s>",
8008
+ "chat_template": "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{%- for message in loop_messages | rejectattr(\"role\", \"equalto\", \"tool\") | rejectattr(\"role\", \"equalto\", \"tool_results\") | selectattr(\"tool_calls\", \"undefined\") %}\n {%- if (message[\"role\"] == \"user\") != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message[\"role\"] == \"tool_calls\" or message.tool_calls is defined %}\n {%- if message.tool_calls is defined %}\n {%- set tool_calls = message.tool_calls %}\n {%- else %}\n {%- set tool_calls = 
message.content %}\n {%- endif %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
8009
  "clean_up_tokenization_spaces": false,
8010
  "eos_token": "</s>",
8011
  "model_max_length": 1000000000000000019884624838656,