diff --git a/README.md b/README.md
index 6934af9..214abe2 100644
--- a/README.md
+++ b/README.md
@@ -321,7 +321,7 @@ openai.api_key = "none"
# create a request activating streaming response
for chunk in openai.ChatCompletion.create(
- model="Qwen-7B",
+ model="Qwen",
messages=[
{"role": "user", "content": "你好"}
],
@@ -333,7 +333,7 @@ for chunk in openai.ChatCompletion.create(
# create a request not activating streaming response
response = openai.ChatCompletion.create(
- model="Qwen-7B",
+ model="Qwen",
messages=[
{"role": "user", "content": "你好"}
],
@@ -349,6 +349,8 @@ print(response.choices[0].message.content)
+Function calling is also supported (though only when `stream=False` for the moment). See the [example usage](examples/function_call_examples.py) for details.
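+
+A minimal sketch of such a request (the weather function below is only an illustration, adapted from the example file):
+
+```python
+import openai
+
+openai.api_base = "http://localhost:8000/v1"
+openai.api_key = "none"
+
+functions = [
+    {
+        "name": "get_current_weather",
+        "description": "Get the current weather in a given location.",
+        "parameters": {
+            "type": "object",
+            "properties": {"location": {"type": "string"}},
+            "required": ["location"],
+        },
+    }
+]
+
+response = openai.ChatCompletion.create(
+    model="Qwen",
+    messages=[{"role": "user", "content": "波士顿天气如何?"}],
+    functions=functions,
+    stream=False,  # function calling currently requires stream=False
+)
+print(response.choices[0].message)
+```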
+
## Deployment
It is simple to run the model on CPU, which requires your specification of device:
@@ -371,22 +373,22 @@ Then you can run the 7B chat model on 2 GPUs using the above scripts.
Qwen-7B-Chat is specifically optimized for tool usage, including API, database, models, etc., so that users can build their own Qwen-7B-based LangChain, Agent, and Code Interpreter. In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool usage capabilities, we find that Qwen-7B reaches stable performance.
-| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
-| :------------ | :-----------------------: | :----------------------: | :----------------------: |
-| GPT-4 | 95% | **0.90** | 15% |
-| GPT-3.5 | 85% | 0.88 | 75% |
-| **Qwen-7B** | **99%** | 0.89 | **9.7%** |
+| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
+|:-----------------| :-----------------------: | :----------------------: | :----------------------: |
+| GPT-4 | 95% | **0.90** | 15% |
+| GPT-3.5 | 85% | 0.88 | 75% |
+| **Qwen-7B-Chat** | **99%** | 0.89 | **9.7%** |
For how to write and use prompts for ReAct Prompting, please refer to [the ReAct examples](examples/react_prompt.md). The use of tools can enable the model to better perform tasks.
Additionally, we provide experimental results to show its capabilities of playing as an agent. See [Hugging Face Agent](https://huggingface.co/docs/transformers/transformers_agents) for more information. Its performance on the run-mode benchmark provided by Hugging Face is as follows:
-| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
-| :---------------- | :----------------: | :-----------: | :---------: |
-| GPT-4 | **100** | **100** | **97.41** |
-| GPT-3.5 | 95.37 | 96.30 | 87.04 |
-| StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
-| **Qwen-7B** | 90.74 | 92.59 | 74.07 |
+| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
+|:-----------------| :----------------: | :-----------: | :---------: |
+| GPT-4 | **100** | **100** | **97.41** |
+| GPT-3.5 | 95.37 | 96.30 | 87.04 |
+| StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
+| **Qwen-7B-Chat** | 90.74 | 92.59 | 74.07 |
diff --git a/README_CN.md b/README_CN.md
index 7b8df0d..507b554 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -327,7 +327,7 @@ openai.api_key = "none"
# 使用流式回复的请求
for chunk in openai.ChatCompletion.create(
- model="Qwen-7B",
+ model="Qwen",
messages=[
{"role": "user", "content": "你好"}
],
@@ -339,7 +339,7 @@ for chunk in openai.ChatCompletion.create(
# 不使用流式回复的请求
response = openai.ChatCompletion.create(
- model="Qwen-7B",
+ model="Qwen",
messages=[
{"role": "user", "content": "你好"}
],
@@ -355,6 +355,8 @@ print(response.choices[0].message.content)
+该接口也支持函数调用(Function Calling),但暂时仅在 `stream=False` 时生效。用法见[函数调用示例](examples/function_call_examples.py)。
+
## 部署
在CPU上运行非常简单,使用方法如下所示:
@@ -377,11 +379,11 @@ model = load_model_on_gpus('Qwen/Qwen-7B-Chat', num_gpus=2)
Qwen-7B-Chat针对包括API、数据库、模型等工具在内的调用进行了优化。用户可以开发基于Qwen-7B的LangChain、Agent甚至Code Interpreter。在我们开源的[评测数据集](eval/EVALUATION.md)上测试模型的工具调用能力,并发现Qwen-7B-Chat能够取得稳定的表现。
-| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
-|:------------|:----------------------:|:----------------------:|:----------------------:|
-| GPT-4 | 95% | **0.90** | 15% |
-| GPT-3.5 | 85% | 0.88 | 75% |
-| **Qwen-7B** | **99%** | 0.89 | **9.7%** |
+| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
+|:-----------------|:----------------------:|:----------------------:|:----------------------:|
+| GPT-4 | 95% | **0.90** | 15% |
+| GPT-3.5 | 85% | 0.88 | 75% |
+| **Qwen-7B-Chat** | **99%** | 0.89 | **9.7%** |
我们提供了文档说明如何根据ReAct Prompting的原则写作你的prompt。
@@ -389,12 +391,12 @@ For how to write and use prompts for ReAct Prompting, please refer to [the ReAct
此外,我们还提供了实验结果表明我们的模型扮演Agent的能力。请阅读相关文档[链接](https://huggingface.co/docs/transformers/transformers_agents)了解更多信息。模型在Hugging Face提供的评测数据集上表现如下:
-| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
-|:---------------|:---------------:|:-----------:|:---------:|
-|GPT-4 | **100** | **100** | **97.41** |
-|GPT-3.5 | 95.37 | 96.30 | 87.04 |
-|StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
-| **Qwen-7B** | 90.74 | 92.59 | 74.07 |
+| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
+|:-----------------|:---------------:|:-----------:|:---------:|
+| GPT-4 | **100** | **100** | **97.41** |
+| GPT-3.5 | 95.37 | 96.30 | 87.04 |
+| StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
+| **Qwen-7B-Chat** | 90.74 | 92.59 | 74.07 |
diff --git a/README_JA.md b/README_JA.md
index e80ef62..155fec4 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -331,7 +331,7 @@ openai.api_key = "none"
# create a request activating streaming response
for chunk in openai.ChatCompletion.create(
- model="Qwen-7B",
+ model="Qwen",
messages=[
{"role": "user", "content": "你好"}
],
@@ -342,7 +342,7 @@ for chunk in openai.ChatCompletion.create(
# create a request not activating streaming response
response = openai.ChatCompletion.create(
- model="Qwen-7B",
+ model="Qwen",
messages=[
{"role": "user", "content": "你好"}
],
@@ -381,22 +381,22 @@ model = load_model_on_gpus('Qwen/Qwen-7B-Chat', num_gpus=2)
Qwen-7B-Chat は、API、データベース、モデルなど、ツールの利用に特化して最適化されており、ユーザは独自の Qwen-7B ベースの LangChain、エージェント、コードインタプリタを構築することができます。ツール利用能力を評価するための評価[ベンチマーク](eval/EVALUATION.md)では、Qwen-7B は安定した性能に達しています。
-| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
-|:------------|:----------------------:|:----------------------:|:----------------------:|
-| GPT-4 | 95% | **0.90** | 15% |
-| GPT-3.5 | 85% | 0.88 | 75% |
-| **Qwen-7B** | **99%** | 0.89 | **9.7%** |
+| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
+|:-----------------|:----------------------:|:----------------------:|:----------------------:|
+| GPT-4 | 95% | **0.90** | 15% |
+| GPT-3.5 | 85% | 0.88 | 75% |
+| **Qwen-7B-Chat** | **99%** | 0.89 | **9.7%** |
ReAct プロンプトの書き方や使い方については、[ReAct の例](examples/react_prompt.md)を参照してください。ツールを使用することで、モデルがよりよいタスクを実行できるようになります。
さらに、エージェントとしての能力を示す実験結果を提供する。詳細は [Hugging Face Agent](https://huggingface.co/docs/transformers/transformers_agents) を参照。Hugging Face が提供するランモードベンチマークでの性能は以下の通りです:
-| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
-|:---------------|:---------------:|:-----------:|:---------:|
-|GPT-4 | **100** | **100** | **97.41** |
-|GPT-3.5 | 95.37 | 96.30 | 87.04 |
-|StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
-| **Qwen-7B** | 90.74 | 92.59 | 74.07 |
+| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
+|:-----------------|:---------------:|:-----------:|:---------:|
+| GPT-4 | **100** | **100** | **97.41** |
+| GPT-3.5 | 95.37 | 96.30 | 87.04 |
+| StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
+| **Qwen-7B-Chat** | 90.74 | 92.59 | 74.07 |
diff --git a/examples/function_call_examples.py b/examples/function_call_examples.py
new file mode 100644
index 0000000..be65678
--- /dev/null
+++ b/examples/function_call_examples.py
@@ -0,0 +1,240 @@
+# Reference: https://openai.com/blog/function-calling-and-other-api-updates
+
+import openai
+
+# To start an OpenAI-like Qwen server, use the following commands:
+# git clone https://github.com/QwenLM/Qwen-7B;
+# cd Qwen-7B;
+# pip install fastapi uvicorn openai pydantic sse_starlette;
+# python openai_api.py;
+#
+# Then configure the api_base and api_key in your client:
+openai.api_base = "http://localhost:8000/v1"
+openai.api_key = "none"
+
+
+def call_qwen(messages, functions=None):
+ print(messages)
+ if functions:
+ response = openai.ChatCompletion.create(
+ model="Qwen", messages=messages, functions=functions
+ )
+ else:
+ response = openai.ChatCompletion.create(model="Qwen", messages=messages)
+ print(response)
+ print(response.choices[0].message.content)
+ return response
+
+
+def test_1():
+ messages = [{"role": "user", "content": "你好"}]
+ call_qwen(messages)
+ messages.append({"role": "assistant", "content": "你好!很高兴为你提供帮助。"})
+
+ messages.append({"role": "user", "content": "给我讲一个年轻人奋斗创业最终取得成功的故事。故事只能有一句话。"})
+ call_qwen(messages)
+ messages.append(
+ {
+ "role": "assistant",
+ "content": "故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。李明想要成为一名成功的企业家。……",
+ }
+ )
+
+ messages.append({"role": "user", "content": "给这个故事起一个标题"})
+ call_qwen(messages)
+
+
+def test_2():
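+    # Qwen-style tool descriptions: each tool carries name_for_human / name_for_model /
+    # description_for_model plus a list of parameters, matching the ReAct tool format
+    # that the OpenAI-compatible server renders into its prompt.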
+ functions = [
+ {
+ "name_for_human": "谷歌搜索",
+ "name_for_model": "google_search",
+ "description_for_model": "谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。"
+ + " Format the arguments as a JSON object.",
+ "parameters": [
+ {
+ "name": "search_query",
+ "description": "搜索关键词或短语",
+ "required": True,
+ "schema": {"type": "string"},
+ }
+ ],
+ },
+ {
+ "name_for_human": "文生图",
+ "name_for_model": "image_gen",
+ "description_for_model": "文生图是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL。"
+ + " Format the arguments as a JSON object.",
+ "parameters": [
+ {
+ "name": "prompt",
+ "description": "英文关键词,描述了希望图像具有什么内容",
+ "required": True,
+ "schema": {"type": "string"},
+ }
+ ],
+ },
+ ]
+
+ messages = [{"role": "user", "content": "你好"}]
+ call_qwen(messages, functions)
+ messages.append(
+ {"role": "assistant", "content": "你好!很高兴见到你。有什么我可以帮忙的吗?"},
+ )
+
+ messages.append({"role": "user", "content": "谁是周杰伦"})
+ call_qwen(messages, functions)
+ messages.append(
+ {
+ "role": "assistant",
+ "content": "Thought: 我应该使用Google搜索查找相关信息。",
+ "function_call": {
+ "name": "google_search",
+ "arguments": '{"search_query": "周杰伦"}',
+ },
+ }
+ )
+
+ messages.append(
+ {
+ "role": "function",
+ "name": "google_search",
+ "content": "Jay Chou is a Taiwanese singer.",
+ }
+ )
+ call_qwen(messages, functions)
+ messages.append(
+ {
+ "role": "assistant",
+ "content": "周杰伦(Jay Chou)是一位来自台湾的歌手。",
+ },
+ )
+
+ messages.append({"role": "user", "content": "他老婆是谁"})
+ call_qwen(messages, functions)
+ messages.append(
+ {
+ "role": "assistant",
+ "content": "Thought: 我应该使用Google搜索查找相关信息。",
+ "function_call": {
+ "name": "google_search",
+ "arguments": '{"search_query": "周杰伦 老婆"}',
+ },
+ }
+ )
+
+ messages.append(
+ {"role": "function", "name": "google_search", "content": "Hannah Quinlivan"}
+ )
+ call_qwen(messages, functions)
+ messages.append(
+ {
+ "role": "assistant",
+ "content": "周杰伦的老婆是Hannah Quinlivan。",
+ },
+ )
+
+ messages.append({"role": "user", "content": "给我画个可爱的小猫吧,最好是黑猫"})
+ call_qwen(messages, functions)
+ messages.append(
+ {
+ "role": "assistant",
+ "content": "Thought: 我应该使用文生图API来生成一张可爱的小猫图片。",
+ "function_call": {
+ "name": "image_gen",
+ "arguments": '{"prompt": "cute black cat"}',
+ },
+ }
+ )
+
+ messages.append(
+ {
+ "role": "function",
+ "name": "image_gen",
+ "content": '{"image_url": "https://image.pollinations.ai/prompt/cute%20black%20cat"}',
+ }
+ )
+ call_qwen(messages, functions)
+
+
+def test_3():
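+    # GPT-style function schema (name / description / JSON-Schema parameters),
+    # as used by the official OpenAI function-calling API.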
+ functions = [
+ {
+ "name": "get_current_weather",
+ "description": "Get the current weather in a given location.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+ },
+ "required": ["location"],
+ },
+ }
+ ]
+
+ messages = [
+ {
+ "role": "user",
+            # Note: The current version of Qwen-7B-Chat (as of 2023.08) performs okay with Chinese tool-use prompts,
+            # but performs terribly with English tool-use prompts, due to a mistake in data collection.
+ "content": "波士顿天气如何?",
+ }
+ ]
+ call_qwen(messages, functions)
+ messages.append(
+ {
+ "role": "assistant",
+ "content": None,
+ "function_call": {
+ "name": "get_current_weather",
+ "arguments": '{"location": "Boston, MA"}',
+ },
+ },
+ )
+
+ messages.append(
+ {
+ "role": "function",
+ "name": "get_current_weather",
+ "content": '{"temperature": "22", "unit": "celsius", "description": "Sunny"}',
+ }
+ )
+ call_qwen(messages, functions)
+
+
+def test_4():
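+    # Drive the OpenAI-compatible endpoint through LangChain: ChatOpenAI points at the
+    # local server, and a zero-shot ReAct agent is given the arxiv tool.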
+ from langchain.chat_models import ChatOpenAI
+ from langchain.agents import load_tools, initialize_agent, AgentType
+
+ llm = ChatOpenAI(
+ model_name="Qwen",
+ openai_api_base="http://localhost:8000/v1",
+ openai_api_key="EMPTY",
+ streaming=False,
+ )
+ tools = load_tools(
+ ["arxiv"],
+ )
+ agent_chain = initialize_agent(
+ tools,
+ llm,
+ agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+ verbose=True,
+ )
+    # TODO: Performance is okay with Chinese prompts, but not as good with English ones.
+ agent_chain.run("查一下论文 1605.08386 的信息")
+
+
+if __name__ == "__main__":
+ print("### Test Case 1 - No Function Calling (普通问答、无函数调用) ###")
+ test_1()
+ print("### Test Case 2 - Use Qwen-Style Functions (函数调用,千问格式) ###")
+ test_2()
+ print("### Test Case 3 - Use GPT-Style Functions (函数调用,GPT格式) ###")
+ test_3()
+ print("### Test Case 4 - Use LangChain (接入Langchain) ###")
+ test_4()
diff --git a/openai_api.py b/openai_api.py
index 42b0841..fd551bd 100644
--- a/openai_api.py
+++ b/openai_api.py
@@ -3,22 +3,26 @@
# Usage: python openai_api.py
# Visit http://localhost:8000/docs for documents.
-from argparse import ArgumentParser
+import re
+import copy
+import json
import time
+from argparse import ArgumentParser
+from contextlib import asynccontextmanager
+from typing import Dict, List, Literal, Optional, Union
+
import torch
import uvicorn
-from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
-from contextlib import asynccontextmanager
-from typing import Any, Dict, List, Literal, Optional, Union
-from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
+from pydantic import BaseModel, Field
+from sse_starlette.sse import EventSourceResponse
+from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig
-from sse_starlette.sse import ServerSentEvent, EventSourceResponse
@asynccontextmanager
-async def lifespan(app: FastAPI): # collects GPU memory
+async def lifespan(app: FastAPI): # collects GPU memory
yield
if torch.cuda.is_available():
torch.cuda.empty_cache()
@@ -52,8 +56,9 @@ class ModelList(BaseModel):
class ChatMessage(BaseModel):
- role: Literal["user", "assistant", "system"]
- content: str
+ role: Literal["user", "assistant", "system", "function"]
+ content: Optional[str]
+ function_call: Optional[Dict] = None
class DeltaMessage(BaseModel):
@@ -64,17 +69,18 @@ class DeltaMessage(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
+ functions: Optional[List[Dict]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
max_length: Optional[int] = None
stream: Optional[bool] = False
- stop: Optional[List[str]] = []
+ stop: Optional[List[str]] = None
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
- finish_reason: Literal["stop", "length"]
+ finish_reason: Literal["stop", "length", "function_call"]
class ChatCompletionResponseStreamChoice(BaseModel):
@@ -86,7 +92,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
class ChatCompletionResponse(BaseModel):
model: str
object: Literal["chat.completion", "chat.completion.chunk"]
- choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+ choices: List[
+ Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
+ ]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@@ -97,70 +105,319 @@ async def list_models():
return ModelList(data=[model_card])
+# Work around the leading-"\n" tokenization issue: also register each stop word with
+# its leading newlines stripped, since the model may emit it without them.
+def add_extra_stop_words(stop_words):
+ if stop_words:
+ _stop_words = []
+ _stop_words.extend(stop_words)
+ for x in stop_words:
+ s = x.lstrip("\n")
+ if s and (s not in _stop_words):
+ _stop_words.append(s)
+ return _stop_words
+ return stop_words
+
+
+def trim_stop_words(response, stop_words):
+ if stop_words:
+ for stop in stop_words:
+ idx = response.find(stop)
+ if idx != -1:
+ response = response[:idx]
+ return response
+
+
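+# ReAct prompting templates: each function passed by the client is rendered with
+# TOOL_DESC, and the resulting tool list is injected into REACT_INSTRUCTION, which
+# parse_messages() below appends to the system prompt.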
+TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
+
+REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:
+
+{tools_text}
+
+Use the following format:
+
+Question: the input question you must answer
+Thought: you should always think about what to do
+Action: the action to take, should be one of [{tools_name_text}]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin!"""
+
+_TEXT_COMPLETION_CMD = object()
+
+
+#
+# Temporarily, the system role does not work as expected.
+# We advise that you write the setups for role-play in your query,
+# i.e., use the user role instead of the system role.
+#
+# TODO: Use real system role when the model is ready.
+#
+def parse_messages(messages, functions):
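+    # Convert OpenAI-style `messages` (plus optional `functions`) into the
+    # (query, history) pair expected by the model's chat interface:
+    #   * a leading system message is extracted and, when functions are given,
+    #     extended with the ReAct instruction above;
+    #   * "function" results are folded into the preceding assistant turn as
+    #     "Observation: ..." lines;
+    #   * assistant `function_call`s are rendered as "Action:/Action Input:" text.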
+ if all(m.role != "user" for m in messages):
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid request: Expecting at least one user message.",
+ )
+
+ messages = copy.deepcopy(messages)
+ default_system = "You are a helpful assistant."
+ system = ""
+ if messages[0].role == "system":
+ system = messages.pop(0).content.lstrip("\n").rstrip()
+ if system == default_system:
+ system = ""
+
+ if functions:
+ tools_text = []
+ tools_name_text = []
+ for func_info in functions:
+ name = func_info.get("name", "")
+ name_m = func_info.get("name_for_model", name)
+ name_h = func_info.get("name_for_human", name)
+ desc = func_info.get("description", "")
+ desc_m = func_info.get("description_for_model", desc)
+ tool = TOOL_DESC.format(
+ name_for_model=name_m,
+ name_for_human=name_h,
+ # Hint: You can add the following format requirements in description:
+ # "Format the arguments as a JSON object."
+ # "Enclose the code within triple backticks (`) at the beginning and end of the code."
+ description_for_model=desc_m,
+ parameters=json.dumps(func_info["parameters"], ensure_ascii=False),
+ )
+ tools_text.append(tool)
+ tools_name_text.append(name_m)
+ tools_text = "\n\n".join(tools_text)
+ tools_name_text = ", ".join(tools_name_text)
+ system += "\n\n" + REACT_INSTRUCTION.format(
+ tools_text=tools_text,
+ tools_name_text=tools_name_text,
+ )
+ system = system.lstrip("\n").rstrip()
+
+ dummy_thought = {
+ "en": "\nThought: I now know the final answer.\nFinal answer: ",
+ "zh": "\nThought: 我会作答了。\nFinal answer: ",
+ }
+
+ _messages = messages
+ messages = []
+ for m_idx, m in enumerate(_messages):
+ role, content, func_call = m.role, m.content, m.function_call
+ if content:
+ content = content.lstrip("\n").rstrip()
+ if role == "function":
+ if (len(messages) == 0) or (messages[-1].role != "assistant"):
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid request: Expecting role assistant before role function.",
+ )
+ messages[-1].content += f"\nObservation: {content}"
+ if m_idx == len(_messages) - 1:
+ messages[-1].content += "\nThought:"
+ elif role == "assistant":
+ if len(messages) == 0:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid request: Expecting role user before role assistant.",
+ )
+ last_msg = messages[-1].content
+ last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0
+ if func_call is None:
+ if functions:
+ content = dummy_thought["zh" if last_msg_has_zh else "en"] + content
+ else:
+ f_name, f_args = func_call["name"], func_call["arguments"]
+ if not content:
+ if last_msg_has_zh:
+ content = f"Thought: 我可以使用 {f_name} API。"
+ else:
+ content = f"Thought: I can use {f_name}."
+ content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}"
+ if messages[-1].role == "user":
+ messages.append(
+ ChatMessage(role="assistant", content=content.lstrip("\n").rstrip())
+ )
+ else:
+ messages[-1].content += content
+ elif role == "user":
+ messages.append(
+ ChatMessage(role="user", content=content.lstrip("\n").rstrip())
+ )
+ else:
+ raise HTTPException(
+ status_code=400, detail=f"Invalid request: Incorrect role {role}."
+ )
+
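+    # A conversation that does not end with a user turn is handled as raw text
+    # completion rather than as a new chat query.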
+ query = _TEXT_COMPLETION_CMD
+ if messages[-1].role == "user":
+ query = messages[-1].content
+ messages = messages[:-1]
+
+ if len(messages) % 2 != 0:
+ raise HTTPException(status_code=400, detail="Invalid request")
+
+ history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
+ for i in range(0, len(messages), 2):
+ if messages[i].role == "user" and messages[i + 1].role == "assistant":
+ usr_msg = messages[i].content.lstrip("\n").rstrip()
+ bot_msg = messages[i + 1].content.lstrip("\n").rstrip()
+ if system and (i == len(messages) - 2):
+ usr_msg = f"{system}\n\nQuestion: {usr_msg}"
+ system = ""
+ for t in dummy_thought.values():
+ t = t.lstrip("\n")
+ if bot_msg.startswith(t) and ("\nAction: " in bot_msg):
+ bot_msg = bot_msg[len(t) :]
+ history.append([usr_msg, bot_msg])
+ else:
+ raise HTTPException(
+ status_code=400,
+ detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.",
+ )
+ if system:
+ assert query is not _TEXT_COMPLETION_CMD
+ query = f"{system}\n\nQuestion: {query}"
+ return query, history
+
+
+def parse_response(response):
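+    # Parse the ReAct-formatted completion: an "Action:"/"Action Input:" pair becomes
+    # a function_call choice; otherwise the text after "Final Answer:" (or the full
+    # text, if absent) is returned as an ordinary "stop" choice.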
+ func_name, func_args = "", ""
+ i = response.rfind("\nAction:")
+ j = response.rfind("\nAction Input:")
+ k = response.rfind("\nObservation:")
+ if 0 <= i < j: # If the text has `Action` and `Action input`,
+ if k < j: # but does not contain `Observation`,
+ # then it is likely that `Observation` is omitted by the LLM,
+ # because the output text may have discarded the stop word.
+ response = response.rstrip() + "\nObservation:" # Add it back.
+ k = response.rfind("\nObservation:")
+ func_name = response[i + len("\nAction:") : j].strip()
+ func_args = response[j + len("\nAction Input:") : k].strip()
+ if func_name:
+ choice_data = ChatCompletionResponseChoice(
+ index=0,
+ message=ChatMessage(
+ role="assistant",
+ content=response[:i],
+ function_call={"name": func_name, "arguments": func_args},
+ ),
+ finish_reason="function_call",
+ )
+ return choice_data
+ z = response.rfind("\nFinal Answer: ")
+ if z >= 0:
+ response = response[z + len("\nFinal Answer: ") :]
+ choice_data = ChatCompletionResponseChoice(
+ index=0,
+ message=ChatMessage(role="assistant", content=response),
+ finish_reason="stop",
+ )
+ return choice_data
+
+
+# completion mode, not chat mode
+def text_complete_last_message(history, stop_words_ids):
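+    # Rebuild the whole history as a raw ChatML prompt and drop the trailing
+    # <|im_end|> so the model simply continues the last assistant message
+    # (plain text completion rather than a new chat turn).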
+ im_start = "<|im_start|>"
+ im_end = "<|im_end|>"
+ prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
+ for i, (query, response) in enumerate(history):
+ query = query.lstrip("\n").rstrip()
+ response = response.lstrip("\n").rstrip()
+ prompt += f"\n{im_start}user\n{query}{im_end}"
+ prompt += f"\n{im_start}assistant\n{response}{im_end}"
+ prompt = prompt[: -len(im_end)]
+
+ _stop_words_ids = [tokenizer.encode(im_end)]
+ if stop_words_ids:
+ for s in stop_words_ids:
+ _stop_words_ids.append(s)
+ stop_words_ids = _stop_words_ids
+
+ input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
+ output = model.generate(input_ids, stop_words_ids=stop_words_ids).tolist()[0]
+ output = tokenizer.decode(output, errors="ignore")
+ assert output.startswith(prompt)
+ output = output[len(prompt) :]
+ output = trim_stop_words(output, ["<|endoftext|>", im_end])
+ print(f"