Text Generation
Transformers
Safetensors
PyTorch
nvidia
conversational
8-bit precision

Fix parsing of nested JSON arguments in streamed tool calls

#4
Files changed (1) hide show
  1. nemotron_toolcall_parser_streaming.py +131 -137
nemotron_toolcall_parser_streaming.py CHANGED
@@ -2,7 +2,7 @@ import json
2
  from collections.abc import Sequence
3
  from random import choices
4
  from string import ascii_letters, digits
5
- from typing import Union
6
 
7
  import partial_json_parser
8
  import regex as re
@@ -16,8 +16,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
16
  FunctionCall, ToolCall)
17
  from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
18
  ToolParser, ToolParserManager)
19
- from vllm.entrypoints.openai.tool_parsers.utils import (
20
- extract_intermediate_diff)
21
  from vllm.logger import init_logger
22
  from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
23
 
@@ -61,6 +59,7 @@ class NemotronToolParser(ToolParser):
61
  self.current_tool_name_sent: bool = False
62
  self.streamed_args_for_tool: list[str] = [
63
  ] # map what has been streamed for each tool so far to a list
 
64
  self.bot_token = "<TOOLCALL>"
65
  self.bot_token_id = self.vocab.get(self.bot_token)
66
  logger.info(f"Nemotron Tool Parser: bot_token: {self.bot_token}, bot_token_id: {self.bot_token_id}")
@@ -75,6 +74,116 @@ class NemotronToolParser(ToolParser):
75
  # a forthcoming <TOOLCALL> or </TOOLCALL> tag in streaming.
76
  self._pending_tag_buffer: str = ""
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def adjust_request(
79
  self, request: ChatCompletionRequest) -> ChatCompletionRequest:
80
  if not isinstance(
@@ -175,64 +284,13 @@ class NemotronToolParser(ToolParser):
175
  # if candidates tool call tokens are in the tokens generated so far, that
176
  # means we're parsing as tool calls now. Suppress streaming if we are
177
  # currently generating any prefix of the start or end tag.
 
178
  try:
179
  start_token = self.bot_token
180
  end_token = f"</{self.bot_token[1:]}" if self.bot_token.startswith('<') else None
181
 
182
- # Handle potential start of tool call tags by buffering partial sequences
183
- if delta_text == '<' and not self._pending_tag_buffer:
184
- # Start buffering a potential tag
185
- self._pending_tag_buffer = '<'
186
- return None
187
-
188
- # If we have a pending tag buffer, accumulate and decide
189
- if self._pending_tag_buffer:
190
- # Accumulate the current token into the buffer
191
- self._pending_tag_buffer += delta_text
192
-
193
- # Extract just the alphabetic part after '<'
194
- alpha_part = ""
195
- for i in range(1, len(self._pending_tag_buffer)):
196
- if self._pending_tag_buffer[i].isalpha():
197
- alpha_part += self._pending_tag_buffer[i].upper()
198
- else:
199
- break
200
-
201
-
202
- # Check if we have a complete opening tag '<TOOLCALL>'
203
- if '<TOOLCALL>' in self._pending_tag_buffer:
204
- # We have the complete opening tag - stop buffering and let normal processing take over
205
- buffered_content = self._pending_tag_buffer
206
- self._pending_tag_buffer = ""
207
-
208
- # Update the text variables to include the buffered content
209
- updated_current_text = previous_text + buffered_content
210
- updated_delta_text = buffered_content # The entire buffered content is the delta
211
-
212
- # Continue processing with the complete tool call content
213
- current_text = updated_current_text
214
- delta_text = updated_delta_text
215
- # Fall through to normal processing
216
- elif self._pending_tag_buffer.startswith('</'):
217
- # End tag pattern - keep buffering until we see if it's a valid end tag
218
- return None
219
- elif alpha_part and "TOOLCALL".startswith(alpha_part) and len(alpha_part) < 8:
220
- # Could be building to TOOLCALL and haven't completed it yet - keep buffering
221
- return None
222
- elif len(alpha_part) > 0 and not "TOOLCALL".startswith(alpha_part):
223
- # Alphabetic content that definitely won't become TOOLCALL - flush as content
224
- content_to_flush = self._pending_tag_buffer
225
- self._pending_tag_buffer = ""
226
- return DeltaMessage(content=content_to_flush)
227
- else:
228
- # Keep buffering - not enough info yet
229
- return None
230
-
231
- # Suppress ANY partial prefix of the start/end tag to avoid leaking tag characters.
232
- if any(current_text.endswith(start_token[:k]) for k in range(1, len(start_token))):
233
- return None
234
- if end_token and any(current_text.endswith(end_token[:k]) for k in range(1, len(end_token))):
235
- return None
236
  except Exception:
237
  # Fallback to conservative checks in case of any issues
238
  if current_text.endswith('<') or current_text.endswith('<T') or current_text.endswith('<TO') or current_text.endswith('<TOOL') or current_text.endswith('<TOOLCALL'):
@@ -241,12 +299,10 @@ class NemotronToolParser(ToolParser):
241
  # if the tool call token is not in the tokens generated so far, append
242
  # output to contents since it's not a tool
243
  if self.bot_token not in current_text:
244
- # If we were buffering a partial tag and reached here, flush it first.
245
- if self._pending_tag_buffer:
246
- content_to_flush = self._pending_tag_buffer + delta_text
247
- self._pending_tag_buffer = ""
248
- return DeltaMessage(content=content_to_flush)
249
- return DeltaMessage(content=delta_text)
250
 
251
  # bit mask flags for partial JSON parsing. If the name hasn't been
252
  # sent yet, don't allow sending
@@ -272,7 +328,8 @@ class NemotronToolParser(ToolParser):
272
  try:
273
  tool_call_arr: list[dict] = partial_json_parser.loads(
274
  parsable_arr, flags)
275
- except partial_json_parser.core.exceptions.MalformedJSON:
 
276
  return None
277
 
278
  current_tool_call: dict = tool_call_arr[self.current_tool_id] \
@@ -315,6 +372,7 @@ class NemotronToolParser(ToolParser):
315
  self.current_tool_id = len(tool_call_arr) - 1
316
  self.current_tool_name_sent = False
317
  self.streamed_args_for_tool.append("")
 
318
  return delta
319
 
320
  # case: update an existing tool - this is handled below
@@ -345,10 +403,6 @@ class NemotronToolParser(ToolParser):
345
  self.current_tool_id].get("arguments")
346
  cur_arguments = current_tool_call.get("arguments")
347
 
348
- new_text = delta_text.replace("\'", "\"")
349
- if ('"}' in new_text):
350
- new_text = new_text[:new_text.rindex('"}')]
351
-
352
  if not cur_arguments and not prev_arguments:
353
 
354
  delta = None
@@ -357,47 +411,11 @@ class NemotronToolParser(ToolParser):
357
  "INVARIANT - impossible to have arguments reset "
358
  "mid-arguments")
359
  delta = None
360
- elif cur_arguments and not prev_arguments:
361
  cur_arguments_json = json.dumps(cur_arguments,
362
  ensure_ascii=False)
363
- streamed_prefix = self.streamed_args_for_tool[
364
- self.current_tool_id]
365
-
366
- # The issue: partial JSON parser auto-completes incomplete strings
367
- # e.g., {"location": " becomes {"location": ""} in parsed result
368
- # We need to handle this by detecting when the parsed result has auto-completed empty strings
369
-
370
- # Check if this looks like an auto-completed partial string
371
- if (cur_arguments_json.endswith('": ""}') and
372
- not streamed_prefix and
373
- '": ""' in cur_arguments_json):
374
- # This is likely auto-completed - remove the auto-completed empty string
375
- # e.g., {"location": ""} -> {"location": "
376
- closing_pos = cur_arguments_json.rfind('": ""}')
377
- if closing_pos != -1:
378
- arguments_delta = cur_arguments_json[:closing_pos + 4] # Keep up to ": "
379
- else:
380
- arguments_delta = cur_arguments_json
381
- else:
382
- # Normal case - use diff calculation
383
- if cur_arguments_json.startswith(streamed_prefix):
384
- arguments_delta = cur_arguments_json[len(streamed_prefix):]
385
- else:
386
- # Fallback: compute diff when prefix does not match.
387
- arguments_delta = extract_intermediate_diff(
388
- cur_arguments_json, streamed_prefix)
389
-
390
- # Do not include a trailing '}' in the very first
391
- # arguments chunk; defer it to the end-of-call flush to
392
- # avoid prematurely closing the JSON object.
393
- if (not self.streamed_args_for_tool[self.current_tool_id]
394
- and not end_of_call and arguments_delta
395
- and arguments_delta.endswith('}')):
396
- arguments_delta = arguments_delta[:-1]
397
- # if there is an auto-completed closing quote '"' before the }, strip it too
398
- # e.g., {"color_hex": "#"} -> {"color_hex": "#"} -> {"color_hex": "#"}
399
- if arguments_delta.endswith('"'):
400
- arguments_delta = arguments_delta[:-1]
401
  if arguments_delta:
402
  delta = DeltaMessage(tool_calls=[
403
  DeltaToolCall(index=self.current_tool_id,
@@ -407,26 +425,8 @@ class NemotronToolParser(ToolParser):
407
  ])
408
  self.streamed_args_for_tool[
409
  self.current_tool_id] += arguments_delta
410
- else:
411
- delta = None
412
-
413
- elif cur_arguments and prev_arguments:
414
- cur_args_json = json.dumps(cur_arguments,
415
- ensure_ascii=False)
416
- prev_args_json = json.dumps(prev_arguments,
417
- ensure_ascii=False)
418
-
419
- argument_diff = extract_intermediate_diff(
420
- cur_args_json, prev_args_json)
421
- if argument_diff:
422
- delta = DeltaMessage(tool_calls=[
423
- DeltaToolCall(index=self.current_tool_id,
424
- function=DeltaFunctionCall(
425
- arguments=argument_diff).model_dump(
426
- exclude_none=True))
427
- ])
428
- self.streamed_args_for_tool[
429
- self.current_tool_id] += argument_diff
430
  else:
431
  # Do not flush final JSON here; let the serving layer
432
  # compute a minimal remaining suffix on finish.
@@ -447,19 +447,12 @@ class NemotronToolParser(ToolParser):
447
  if cur_arguments is not None:
448
  cur_args_json = json.dumps(cur_arguments,
449
  ensure_ascii=False)
450
- streamed_prefix = self.streamed_args_for_tool[
451
- self.current_tool_id]
452
-
453
- if cur_args_json.startswith(streamed_prefix):
454
- remaining_suffix = cur_args_json[len(
455
- streamed_prefix):]
456
- else:
457
- remaining_suffix = extract_intermediate_diff(
458
- cur_args_json, streamed_prefix)
459
 
460
  # Only send remaining suffix if it's non-empty and contains meaningful content
461
  # (not just whitespace or single characters like closing braces)
462
- if remaining_suffix and remaining_suffix.strip() and len(remaining_suffix.strip()) > 0:
463
  extra = DeltaToolCall(
464
  index=self.current_tool_id,
465
  function=DeltaFunctionCall(
@@ -474,6 +467,7 @@ class NemotronToolParser(ToolParser):
474
  delta.tool_calls = [extra]
475
  self.streamed_args_for_tool[
476
  self.current_tool_id] += remaining_suffix
 
477
  else:
478
  pass
479
  except Exception:
 
2
  from collections.abc import Sequence
3
  from random import choices
4
  from string import ascii_letters, digits
5
+ from typing import Optional, Union
6
 
7
  import partial_json_parser
8
  import regex as re
 
16
  FunctionCall, ToolCall)
17
  from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
18
  ToolParser, ToolParserManager)
 
 
19
  from vllm.logger import init_logger
20
  from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
21
 
 
59
  self.current_tool_name_sent: bool = False
60
  self.streamed_args_for_tool: list[str] = [
61
  ] # map what has been streamed for each tool so far to a list
62
+ self.tool_args_emitted: list[bool] = []
63
  self.bot_token = "<TOOLCALL>"
64
  self.bot_token_id = self.vocab.get(self.bot_token)
65
  logger.info(f"Nemotron Tool Parser: bot_token: {self.bot_token}, bot_token_id: {self.bot_token_id}")
 
74
  # a forthcoming <TOOLCALL> or </TOOLCALL> tag in streaming.
75
  self._pending_tag_buffer: str = ""
76
 
77
@staticmethod
def _strip_trailing_auto_closers(chunk: str) -> str:
    """Trim trailing closers that may have been auto-completed by the parser.

    Trailing whitespace, ``}`` and ``]`` are removed first, then any run of
    trailing unescaped double quotes. The removed characters are meant to be
    flushed only once the tool call completes, so that they are not streamed
    twice.

    Args:
        chunk: Candidate argument fragment (serialized JSON text).

    Returns:
        ``chunk`` without the trailing closers; ``""`` if nothing is left.
    """
    # Whitespace and structural closers are dropped unconditionally.
    end = len(chunk.rstrip(" \t\r\n}]"))

    while end > 0 and chunk[end - 1] == '"':
        # A quote is escaped only when preceded by an ODD number of
        # backslashes; an even run (e.g. `\\"`) is escaped backslashes
        # followed by a genuine closing quote, which must still be stripped.
        run_start = end - 1
        while run_start > 0 and chunk[run_start - 1] == '\\':
            run_start -= 1
        if (end - 1 - run_start) % 2 == 1:
            break  # escaped quote belongs to the string content; keep it
        end -= 1

    return chunk[:end]
94
+
95
@staticmethod
def _common_prefix_len(left: str, right: str) -> int:
    """Return the number of leading characters shared by both strings.

    Args:
        left: First string.
        right: Second string.

    Returns:
        Length of the longest common prefix; ``0`` when either string is
        empty or the first characters differ.
    """
    for idx, (left_ch, right_ch) in enumerate(zip(left, right)):
        if left_ch != right_ch:
            return idx
    # No mismatch within the overlap: the shorter string is the prefix.
    return min(len(left), len(right))
105
+
106
def _compute_arguments_delta(self, cur_arguments_json: str,
                             end_of_call: bool) -> str:
    """Return the next argument fragment to stream for the current tool call.

    The tracked prefix in ``self.streamed_args_for_tool`` is first cut back
    to its longest common prefix with ``cur_arguments_json``; the fragment
    is whatever follows that prefix in the latest snapshot. This method does
    not append the returned fragment to the tracked prefix.

    Args:
        cur_arguments_json: Serialized JSON of the arguments parsed so far.
        end_of_call: True when the tool call is complete, in which case
            trailing closers are kept in the returned fragment.

    Returns:
        The fragment to emit, or ``""`` when there is nothing new or
        ``self.current_tool_id`` does not index a tracked tool.
    """
    tool_idx = self.current_tool_id
    if not 0 <= tool_idx < len(self.streamed_args_for_tool):
        return ""

    streamed_prefix = self.streamed_args_for_tool[tool_idx]
    # The emitted-flag list may be shorter than the prefix list; treat a
    # missing entry as "nothing emitted yet".
    had_any = (tool_idx < len(self.tool_args_emitted)
               and self.tool_args_emitted[tool_idx])

    lcp_len = self._common_prefix_len(cur_arguments_json, streamed_prefix)
    if lcp_len != len(streamed_prefix):
        # The snapshot diverged from what was recorded: keep only the part
        # both agree on.
        self.streamed_args_for_tool[tool_idx] = streamed_prefix[:lcp_len]

    first_chunk = not had_any and not end_of_call

    if (first_chunk and lcp_len == 0
            and cur_arguments_json.endswith('": ""}')):
        # Snapshot ends in an empty string value followed by the closing
        # brace: keep everything through the opening quote of that value
        # and drop the final `"}`.
        arguments_delta = cur_arguments_json[:-2]
    else:
        arguments_delta = cur_arguments_json[lcp_len:]

    if not arguments_delta:
        return ""

    if not end_of_call:
        arguments_delta = self._strip_trailing_auto_closers(arguments_delta)

    # NOTE(review): kept from the previous implementation. The first chunk
    # never ends with `}` (nor with a `"` directly before it). The nesting
    # of the quote check under the brace check is assumed - confirm against
    # the real file.
    if first_chunk and arguments_delta.endswith('}'):
        arguments_delta = arguments_delta[:-1]
        if arguments_delta.endswith('"'):
            arguments_delta = arguments_delta[:-1]

    return arguments_delta
152
+
153
def _visible_delta_outside_tool(self, delta_text: str,
                                start_token: Optional[str],
                                end_token: Optional[str]) -> str:
    """Filter tool-call tags out of a streamed text delta.

    Characters that could be the beginning of ``start_token`` or
    ``end_token`` are held in ``self._pending_tag_buffer`` (which persists
    across calls). A buffer that completes one of the tags is dropped; a
    buffer that stops matching both tags is released as ordinary content, so
    other markup such as ``</think>`` passes through unchanged.

    Args:
        delta_text: Newly generated text.
        start_token: Opening tool-call tag, or None to disable the check.
        end_token: Closing tool-call tag, or None to disable the check.

    Returns:
        The portion of the text that is safe to show. It may be empty while
        a potential tag is still being buffered.
    """
    if not delta_text:
        return delta_text

    tags = tuple(tag for tag in (start_token, end_token) if tag)
    pending = self._pending_tag_buffer
    visible: list[str] = []

    for ch in delta_text:
        if not pending and ch != '<':
            # Not inside a potential tag: plain content.
            visible.append(ch)
            continue

        pending += ch
        is_prefix = any(tag.startswith(pending) for tag in tags)

        if not is_prefix and len(pending) > 1 and ch == '<':
            # The buffered text is not a tag, but this '<' may start one
            # (e.g. "<<TOOLCALL>"). Release the old buffer and restart
            # matching from the new '<' so the following tag is not leaked.
            visible.append(pending[:-1])
            pending = ch
            is_prefix = any(tag.startswith(pending) for tag in tags)

        if is_prefix:
            if pending in tags:
                pending = ""  # complete tag consumed; emit nothing for it
            continue

        # Not a tool tag; flush buffered characters as normal content.
        visible.append(pending)
        pending = ""

    self._pending_tag_buffer = pending
    return "".join(visible)
186
+
187
  def adjust_request(
188
  self, request: ChatCompletionRequest) -> ChatCompletionRequest:
189
  if not isinstance(
 
284
  # if candidates tool call tokens are in the tokens generated so far, that
285
  # means we're parsing as tool calls now. Suppress streaming if we are
286
  # currently generating any prefix of the start or end tag.
287
+ visible_delta_text = delta_text
288
  try:
289
  start_token = self.bot_token
290
  end_token = f"</{self.bot_token[1:]}" if self.bot_token.startswith('<') else None
291
 
292
+ visible_delta_text = self._visible_delta_outside_tool(
293
+ delta_text, start_token, end_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  except Exception:
295
  # Fallback to conservative checks in case of any issues
296
  if current_text.endswith('<') or current_text.endswith('<T') or current_text.endswith('<TO') or current_text.endswith('<TOOL') or current_text.endswith('<TOOLCALL'):
 
299
  # if the tool call token is not in the tokens generated so far, append
300
  # output to contents since it's not a tool
301
  if self.bot_token not in current_text:
302
+ if visible_delta_text:
303
+ return DeltaMessage(content=visible_delta_text)
304
+ # still waiting on a potential tag, so emit nothing yet
305
+ return None
 
 
306
 
307
  # bit mask flags for partial JSON parsing. If the name hasn't been
308
  # sent yet, don't allow sending
 
328
  try:
329
  tool_call_arr: list[dict] = partial_json_parser.loads(
330
  parsable_arr, flags)
331
+ except (partial_json_parser.core.exceptions.MalformedJSON,
332
+ json.JSONDecodeError, ValueError):
333
  return None
334
 
335
  current_tool_call: dict = tool_call_arr[self.current_tool_id] \
 
372
  self.current_tool_id = len(tool_call_arr) - 1
373
  self.current_tool_name_sent = False
374
  self.streamed_args_for_tool.append("")
375
+ self.tool_args_emitted.append(False)
376
  return delta
377
 
378
  # case: update an existing tool - this is handled below
 
403
  self.current_tool_id].get("arguments")
404
  cur_arguments = current_tool_call.get("arguments")
405
 
 
 
 
 
406
  if not cur_arguments and not prev_arguments:
407
 
408
  delta = None
 
411
  "INVARIANT - impossible to have arguments reset "
412
  "mid-arguments")
413
  delta = None
414
+ elif cur_arguments:
415
  cur_arguments_json = json.dumps(cur_arguments,
416
  ensure_ascii=False)
417
+ arguments_delta = self._compute_arguments_delta(
418
+ cur_arguments_json, end_of_call)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  if arguments_delta:
420
  delta = DeltaMessage(tool_calls=[
421
  DeltaToolCall(index=self.current_tool_id,
 
425
  ])
426
  self.streamed_args_for_tool[
427
  self.current_tool_id] += arguments_delta
428
+ self.tool_args_emitted[
429
+ self.current_tool_id] = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  else:
431
  # Do not flush final JSON here; let the serving layer
432
  # compute a minimal remaining suffix on finish.
 
447
  if cur_arguments is not None:
448
  cur_args_json = json.dumps(cur_arguments,
449
  ensure_ascii=False)
450
+ remaining_suffix = self._compute_arguments_delta(
451
+ cur_args_json, end_of_call=True)
 
 
 
 
 
 
 
452
 
453
  # Only send remaining suffix if it's non-empty and contains meaningful content
454
  # (not just whitespace or single characters like closing braces)
455
+ if remaining_suffix and remaining_suffix.strip():
456
  extra = DeltaToolCall(
457
  index=self.current_tool_id,
458
  function=DeltaFunctionCall(
 
467
  delta.tool_calls = [extra]
468
  self.streamed_args_for_tool[
469
  self.current_tool_id] += remaining_suffix
470
+ self.tool_args_emitted[self.current_tool_id] = True
471
  else:
472
  pass
473
  except Exception: