Text Generation
Transformers
Safetensors
PyTorch
nvidia
conversational
suhara committed on
Commit
56ed81d
·
verified ·
1 Parent(s): 8550779

Upload nemotron_toolcall_parser_no_streaming.py

Browse files
nemotron_toolcall_parser_no_streaming.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import ast
4
+ import json
5
+ import re
6
+ from collections.abc import Sequence
7
+ from typing import Union
8
+
9
+ import partial_json_parser
10
+ from partial_json_parser.core.options import Allow
11
+
12
+ from vllm.entrypoints.openai.protocol import (
13
+ ChatCompletionRequest,
14
+ DeltaFunctionCall, DeltaMessage,
15
+ DeltaToolCall,
16
+ ExtractedToolCallInformation,
17
+ FunctionCall,
18
+ ToolCall,
19
+ )
20
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
21
+ ToolParser,
22
+ ToolParserManager,
23
+ )
24
+ from vllm.logger import init_logger
25
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
26
+ from vllm.utils import random_uuid
27
+
28
+ logger = init_logger(__name__)
29
+
30
+
31
@ToolParserManager.register_module("nemotron_json")
class NemotronJSONToolParser(ToolParser):
    """Tool-call parser for ``<TOOLCALL>...</TOOLCALL>`` blocks.

    Registered with ``ToolParserManager`` under the name ``nemotron_json``.
    The text between the tags is parsed as a JSON list whose entries are
    objects with a ``"name"`` key and an ``"arguments"`` key.

    Only non-streaming extraction is implemented; the streaming entry
    point raises ``NotImplementedError``.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Initialise parser state and the tool-call delimiters.

        Args:
            tokenizer: Forwarded unchanged to ``ToolParser.__init__``.
        """
        super().__init__(tokenizer)

        # Per-request bookkeeping attributes. They are only initialised
        # here; no method of this class reads or updates them.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"

        # Non-greedy + DOTALL: capture the (possibly multi-line) payload up
        # to the first closing tag.
        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>",
                                          re.DOTALL)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streamed) model output.

        Args:
            model_output: Full text generated by the model.
            request: The originating chat completion request (unused here).

        Returns:
            ``tools_called=False`` with the untouched ``model_output`` as
            content when the start token is absent or the block cannot be
            parsed. Otherwise ``tools_called=True`` with every entry that
            could be converted into a ``ToolCall``; ``content`` is the text
            preceding the last start token, or ``None`` if that is empty.
        """
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        try:
            # Only the first complete <TOOLCALL>...</TOOLCALL> block is used.
            match = self.tool_call_regex.search(model_output)
            if match is None:
                raise ValueError(
                    f"{self.tool_call_start_token} found without a closing "
                    f"{self.tool_call_end_token}")
            str_tool_calls = match.group(1).strip()

            # Tolerate a payload that omits the enclosing list brackets.
            if not str_tool_calls.startswith("["):
                str_tool_calls = "[" + str_tool_calls
            if not str_tool_calls.endswith("]"):
                # The closing bracket must be appended, not prepended.
                str_tool_calls = str_tool_calls + "]"
            json_tool_calls = json.loads(str_tool_calls)

            tool_calls = []
            for tool_call in json_tool_calls:
                try:
                    name = tool_call["name"]
                    arguments = tool_call["arguments"]
                    # Dict arguments are serialised to a JSON string; any
                    # other value is handed to FunctionCall as-is.
                    if isinstance(arguments, dict):
                        arguments = json.dumps(arguments, ensure_ascii=False)
                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=name,
                                arguments=arguments,
                            ),
                        ))
                except Exception:
                    # Best effort: drop the malformed entry, keep the rest.
                    logger.warning("Skipping malformed tool call entry: %r",
                                   tool_call)
                    continue

            # NOTE(review): content is cut at the *last* start token while
            # the payload comes from the *first* block; with several blocks
            # the earlier ones stay in content. Kept as in the original.
            content = model_output[:model_output.
                                   rfind(self.tool_call_start_token)]

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(
                "Error in extracting tool call from response. Response: %s",
                model_output)
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Streaming extraction is not implemented by this parser.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Tool calling is not supported in streaming mode!")