From 9e80cc085ce3293fdc47ccab3048c689fe81620f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=BC=E6=AC=A3?= Date: Wed, 30 Aug 2023 15:04:13 +0800 Subject: [PATCH] add function calling support --- README.md | 28 +- README_CN.md | 28 +- README_JA.md | 26 +- examples/function_call_examples.py | 240 +++++++++++++++++ openai_api.py | 420 +++++++++++++++++++++++------ 5 files changed, 628 insertions(+), 114 deletions(-) create mode 100644 examples/function_call_examples.py diff --git a/README.md b/README.md index 6934af9..214abe2 100644 --- a/README.md +++ b/README.md @@ -321,7 +321,7 @@ openai.api_key = "none" # create a request activating streaming response for chunk in openai.ChatCompletion.create( - model="Qwen-7B", + model="Qwen", messages=[ {"role": "user", "content": "你好"} ], @@ -333,7 +333,7 @@ for chunk in openai.ChatCompletion.create( # create a request not activating streaming response response = openai.ChatCompletion.create( - model="Qwen-7B", + model="Qwen", messages=[ {"role": "user", "content": "你好"} ], @@ -349,6 +349,8 @@ print(response.choices[0].message.content)

+Function calling is also supported (but only when `stream=False` for the moment). See the [example usage](examples/function_call_examples.py) here. + ## Deployment It is simple to run the model on CPU, which requires your specification of device: @@ -371,22 +373,22 @@ Then you can run the 7B chat model on 2 GPUs using the above scripts. Qwen-7B-Chat is specifically optimized for tool usage, including API, database, models, etc., so that users can build their own Qwen-7B-based LangChain, Agent, and Code Interpreter. In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool usage capabilities, we find that Qwen-7B reaches stable performance. -| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ | -| :------------ | :-----------------------: | :----------------------: | :----------------------: | -| GPT-4 | 95% | **0.90** | 15% | -| GPT-3.5 | 85% | 0.88 | 75% | -| **Qwen-7B** | **99%** | 0.89 | **9.7%** | +| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ | +|:-----------------| :-----------------------: | :----------------------: | :----------------------: | +| GPT-4 | 95% | **0.90** | 15% | +| GPT-3.5 | 85% | 0.88 | 75% | +| **Qwen-7B-Chat** | **99%** | 0.89 | **9.7%** | For how to write and use prompts for ReAct Prompting, please refer to [the ReAct examples](examples/react_prompt.md). The use of tools can enable the model to better perform tasks. Additionally, we provide experimental results to show its capabilities of playing as an agent. See [Hugging Face Agent](https://huggingface.co/docs/transformers/transformers_agents) for more information. Its performance on the run-mode benchmark provided by Hugging Face is as follows: -| Model | Tool Selection↑ | Tool Used↑ | Code↑ | -| :---------------- | :----------------: | :-----------: | :---------: | -| GPT-4 | **100** | **100** | **97.41** | -| GPT-3.5 | 95.37 | 96.30 | 87.04 | -| StarCoder-15.5B | 87.04 | 87.96 | 68.89 | -| **Qwen-7B** | 90.74 | 92.59 | 74.07 | +| Model | Tool Selection↑ | Tool Used↑ | Code↑ | +|:-----------------| :----------------: | :-----------: | :---------: | +| GPT-4 | **100** | **100** | **97.41** | +| GPT-3.5 | 95.37 | 96.30 | 87.04 | +| StarCoder-15.5B | 87.04 | 87.96 | 68.89 | +| **Qwen-7B-Chat** | 90.74 | 92.59 | 74.07 |
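Note: below is a minimal client-side sketch of the round trip enabled by the new `functions` field, mirroring `examples/function_call_examples.py` added later in this patch. It assumes the `openai_api.py` server from this patch is running locally; the weather tool and its canned result are illustrative stand-ins.

```python
import openai

openai.api_base = "http://localhost:8000/v1"
openai.api_key = "none"

# GPT-style function schema; the tool itself is a stand-in for this sketch.
functions = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                }
            },
            "required": ["location"],
        },
    }
]

# A Chinese query is used here because, per the notes in this patch, the current
# checkpoint handles Chinese tool-use prompts better than English ones.
messages = [{"role": "user", "content": "波士顿天气如何?"}]
response = openai.ChatCompletion.create(model="Qwen", messages=messages, functions=functions)
choice = response.choices[0]

if choice.finish_reason == "function_call":
    func = choice.message.function_call  # {"name": ..., "arguments": "<JSON string>"}
    messages.append(
        {"role": "assistant", "content": choice.message.content, "function_call": func}
    )
    # Run the tool yourself, then return the observation with role "function".
    messages.append(
        {
            "role": "function",
            "name": func["name"],
            "content": '{"temperature": "22", "unit": "celsius", "description": "Sunny"}',
        }
    )
    final = openai.ChatCompletion.create(model="Qwen", messages=messages, functions=functions)
    print(final.choices[0].message.content)
```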
diff --git a/README_CN.md b/README_CN.md index 7b8df0d..507b554 100644 --- a/README_CN.md +++ b/README_CN.md @@ -327,7 +327,7 @@ openai.api_key = "none" # 使用流式回复的请求 for chunk in openai.ChatCompletion.create( - model="Qwen-7B", + model="Qwen", messages=[ {"role": "user", "content": "你好"} ], @@ -339,7 +339,7 @@ for chunk in openai.ChatCompletion.create( # 不使用流式回复的请求 response = openai.ChatCompletion.create( - model="Qwen-7B", + model="Qwen", messages=[ {"role": "user", "content": "你好"} ], @@ -355,6 +355,8 @@ print(response.choices[0].message.content)

+该接口也支持函数调用(Function Calling),但暂时仅限 `stream=False` 时能生效。用法见[函数调用示例](examples/function_call_examples.py)。 + ## 部署 在CPU上运行非常简单,使用方法如下所示: @@ -377,11 +379,11 @@ model = load_model_on_gpus('Qwen/Qwen-7B-Chat', num_gpus=2) Qwen-7B-Chat针对包括API、数据库、模型等工具在内的调用进行了优化。用户可以开发基于Qwen-7B的LangChain、Agent甚至Code Interpreter。在我们开源的[评测数据集](eval/EVALUATION.md)上测试模型的工具调用能力,并发现Qwen-7B-Chat能够取得稳定的表现。 -| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ | -|:------------|:----------------------:|:----------------------:|:----------------------:| -| GPT-4 | 95% | **0.90** | 15% | -| GPT-3.5 | 85% | 0.88 | 75% | -| **Qwen-7B** | **99%** | 0.89 | **9.7%** | +| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ | +|:-----------------|:----------------------:|:----------------------:|:----------------------:| +| GPT-4 | 95% | **0.90** | 15% | +| GPT-3.5 | 85% | 0.88 | 75% | +| **Qwen-7B-Chat** | **99%** | 0.89 | **9.7%** | 我们提供了文档说明如何根据ReAct Prompting的原则写作你的prompt。 @@ -389,12 +391,12 @@ For how to write and use prompts for ReAct Prompting, please refer to [the ReAct 此外,我们还提供了实验结果表明我们的模型扮演Agent的能力。请阅读相关文档[链接](https://huggingface.co/docs/transformers/transformers_agents)了解更多信息。模型在Hugging Face提供的评测数据集上表现如下: -| Model | Tool Selection↑ | Tool Used↑ | Code↑ | -|:---------------|:---------------:|:-----------:|:---------:| -|GPT-4 | **100** | **100** | **97.41** | -|GPT-3.5 | 95.37 | 96.30 | 87.04 | -|StarCoder-15.5B | 87.04 | 87.96 | 68.89 | -| **Qwen-7B** | 90.74 | 92.59 | 74.07 | +| Model | Tool Selection↑ | Tool Used↑ | Code↑ | +|:-----------------|:---------------:|:-----------:|:---------:| +| GPT-4 | **100** | **100** | **97.41** | +| GPT-3.5 | 95.37 | 96.30 | 87.04 | +| StarCoder-15.5B | 87.04 | 87.96 | 68.89 | +| **Qwen-7B-Chat** | 90.74 | 92.59 | 74.07 |
diff --git a/README_JA.md b/README_JA.md index e80ef62..155fec4 100644 --- a/README_JA.md +++ b/README_JA.md @@ -331,7 +331,7 @@ openai.api_key = "none" # create a request activating streaming response for chunk in openai.ChatCompletion.create( - model="Qwen-7B", + model="Qwen", messages=[ {"role": "user", "content": "你好"} ], @@ -342,7 +342,7 @@ for chunk in openai.ChatCompletion.create( # create a request not activating streaming response response = openai.ChatCompletion.create( - model="Qwen-7B", + model="Qwen", messages=[ {"role": "user", "content": "你好"} ], @@ -381,22 +381,22 @@ model = load_model_on_gpus('Qwen/Qwen-7B-Chat', num_gpus=2) Qwen-7B-Chat は、API、データベース、モデルなど、ツールの利用に特化して最適化されており、ユーザは独自の Qwen-7B ベースの LangChain、エージェント、コードインタプリタを構築することができます。ツール利用能力を評価するための評価[ベンチマーク](eval/EVALUATION.md)では、Qwen-7B は安定した性能に達しています。 [](https://) -| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ | -|:------------|:----------------------:|:----------------------:|:----------------------:| -| GPT-4 | 95% | **0.90** | 15% | -| GPT-3.5 | 85% | 0.88 | 75% | -| **Qwen-7B** | **99%** | 0.89 | **9.7%** | +| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ | +|:-----------------|:----------------------:|:----------------------:|:----------------------:| +| GPT-4 | 95% | **0.90** | 15% | +| GPT-3.5 | 85% | 0.88 | 75% | +| **Qwen-7B-Chat** | **99%** | 0.89 | **9.7%** | ReAct プロンプトの書き方や使い方については、[ReAct の例](examples/react_prompt.md)を参照してください。ツールを使用することで、モデルがよりよいタスクを実行できるようになります。 さらに、エージェントとしての能力を示す実験結果を提供する。詳細は [Hugging Face Agent](https://huggingface.co/docs/transformers/transformers_agents) を参照。Hugging Face が提供するランモードベンチマークでの性能は以下の通りです: -| Model | Tool Selection↑ | Tool Used↑ | Code↑ | -|:---------------|:---------------:|:-----------:|:---------:| -|GPT-4 | **100** | **100** | **97.41** | -|GPT-3.5 | 95.37 | 96.30 | 87.04 | -|StarCoder-15.5B | 87.04 | 87.96 | 68.89 | -| **Qwen-7B** | 90.74 | 92.59 | 74.07 | +| Model | Tool Selection↑ | Tool Used↑ | Code↑ | +|:-----------------|:---------------:|:-----------:|:---------:| +| GPT-4 | **100** | **100** | **97.41** | +| GPT-3.5 | 95.37 | 96.30 | 87.04 | +| StarCoder-15.5B | 87.04 | 87.96 | 68.89 | +| **Qwen-7B-Chat** | 90.74 | 92.59 | 74.07 |
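Note on tool schemas: `parse_messages` in the new `openai_api.py` accepts both the OpenAI/GPT-style spec and the Qwen/ReAct-style spec, falling back from `name_for_model` / `description_for_model` to `name` / `description`. A brief sketch of the two shapes, with illustrative field values taken from the example file added below:

```python
# OpenAI/GPT-style (the format used by test_3 in the example file; the search tool here is illustrative):
gpt_style = {
    "name": "google_search",
    "description": "A general-purpose search engine.",
    "parameters": {
        "type": "object",
        "properties": {"search_query": {"type": "string"}},
        "required": ["search_query"],
    },
}

# Qwen/ReAct-style (exercised in test_2; also see examples/react_prompt.md):
qwen_style = {
    "name_for_human": "谷歌搜索",
    "name_for_model": "google_search",
    "description_for_model": "谷歌搜索是一个通用搜索引擎。 Format the arguments as a JSON object.",
    "parameters": [
        {
            "name": "search_query",
            "description": "搜索关键词或短语",
            "required": True,
            "schema": {"type": "string"},
        }
    ],
}
```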
diff --git a/examples/function_call_examples.py b/examples/function_call_examples.py new file mode 100644 index 0000000..be65678 --- /dev/null +++ b/examples/function_call_examples.py @@ -0,0 +1,240 @@ +# Reference: https://openai.com/blog/function-calling-and-other-api-updates + +import openai + +# To start an OpenAI-like Qwen server, use the following commands: +# git clone https://github.com/QwenLM/Qwen-7B; +# cd Qwen-7B; +# pip install fastapi uvicorn openai pydantic sse_starlette; +# python openai_api.py; +# +# Then configure the api_base and api_key in your client: +openai.api_base = "http://localhost:8000/v1" +openai.api_key = "none" + + +def call_qwen(messages, functions=None): + print(messages) + if functions: + response = openai.ChatCompletion.create( + model="Qwen", messages=messages, functions=functions + ) + else: + response = openai.ChatCompletion.create(model="Qwen", messages=messages) + print(response) + print(response.choices[0].message.content) + return response + + +def test_1(): + messages = [{"role": "user", "content": "你好"}] + call_qwen(messages) + messages.append({"role": "assistant", "content": "你好!很高兴为你提供帮助。"}) + + messages.append({"role": "user", "content": "给我讲一个年轻人奋斗创业最终取得成功的故事。故事只能有一句话。"}) + call_qwen(messages) + messages.append( + { + "role": "assistant", + "content": "故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。李明想要成为一名成功的企业家。……", + } + ) + + messages.append({"role": "user", "content": "给这个故事起一个标题"}) + call_qwen(messages) + + +def test_2(): + functions = [ + { + "name_for_human": "谷歌搜索", + "name_for_model": "google_search", + "description_for_model": "谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。" + + " Format the arguments as a JSON object.", + "parameters": [ + { + "name": "search_query", + "description": "搜索关键词或短语", + "required": True, + "schema": {"type": "string"}, + } + ], + }, + { + "name_for_human": "文生图", + "name_for_model": "image_gen", + "description_for_model": "文生图是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL。" + + " Format the arguments as a JSON object.", + "parameters": [ + { + "name": "prompt", + "description": "英文关键词,描述了希望图像具有什么内容", + "required": True, + "schema": {"type": "string"}, + } + ], + }, + ] + + messages = [{"role": "user", "content": "你好"}] + call_qwen(messages, functions) + messages.append( + {"role": "assistant", "content": "你好!很高兴见到你。有什么我可以帮忙的吗?"}, + ) + + messages.append({"role": "user", "content": "谁是周杰伦"}) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "Thought: 我应该使用Google搜索查找相关信息。", + "function_call": { + "name": "google_search", + "arguments": '{"search_query": "周杰伦"}', + }, + } + ) + + messages.append( + { + "role": "function", + "name": "google_search", + "content": "Jay Chou is a Taiwanese singer.", + } + ) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "周杰伦(Jay Chou)是一位来自台湾的歌手。", + }, + ) + + messages.append({"role": "user", "content": "他老婆是谁"}) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "Thought: 我应该使用Google搜索查找相关信息。", + "function_call": { + "name": "google_search", + "arguments": '{"search_query": "周杰伦 老婆"}', + }, + } + ) + + messages.append( + {"role": "function", "name": "google_search", "content": "Hannah Quinlivan"} + ) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "周杰伦的老婆是Hannah Quinlivan。", + }, + ) + + messages.append({"role": "user", "content": "给我画个可爱的小猫吧,最好是黑猫"}) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + 
"content": "Thought: 我应该使用文生图API来生成一张可爱的小猫图片。", + "function_call": { + "name": "image_gen", + "arguments": '{"prompt": "cute black cat"}', + }, + } + ) + + messages.append( + { + "role": "function", + "name": "image_gen", + "content": '{"image_url": "https://image.pollinations.ai/prompt/cute%20black%20cat"}', + } + ) + call_qwen(messages, functions) + + +def test_3(): + functions = [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + + messages = [ + { + "role": "user", + # Note: The current version of Qwen-7B-Chat (as of 2023.08) performs okay with Chinese tool-use prompts, + # but performs terribly when it comes to English tool-use prompts, due to a mistake in data collecting. + "content": "波士顿天气如何?", + } + ] + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": None, + "function_call": { + "name": "get_current_weather", + "arguments": '{"location": "Boston, MA"}', + }, + }, + ) + + messages.append( + { + "role": "function", + "name": "get_current_weather", + "content": '{"temperature": "22", "unit": "celsius", "description": "Sunny"}', + } + ) + call_qwen(messages, functions) + + +def test_4(): + from langchain.chat_models import ChatOpenAI + from langchain.agents import load_tools, initialize_agent, AgentType + + llm = ChatOpenAI( + model_name="Qwen", + openai_api_base="http://localhost:8000/v1", + openai_api_key="EMPTY", + streaming=False, + ) + tools = load_tools( + ["arxiv"], + ) + agent_chain = initialize_agent( + tools, + llm, + agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, + verbose=True, + ) + # TODO: The performance is okay with Chinese prompts, but not so good when it comes to English. + agent_chain.run("查一下论文 1605.08386 的信息") + + +if __name__ == "__main__": + print("### Test Case 1 - No Function Calling (普通问答、无函数调用) ###") + test_1() + print("### Test Case 2 - Use Qwen-Style Functions (函数调用,千问格式) ###") + test_2() + print("### Test Case 3 - Use GPT-Style Functions (函数调用,GPT格式) ###") + test_3() + print("### Test Case 4 - Use LangChain (接入Langchain) ###") + test_4() diff --git a/openai_api.py b/openai_api.py index 42b0841..fd551bd 100644 --- a/openai_api.py +++ b/openai_api.py @@ -3,22 +3,26 @@ # Usage: python openai_api.py # Visit http://localhost:8000/docs for documents. 
-from argparse import ArgumentParser +import re +import copy +import json import time +from argparse import ArgumentParser +from contextlib import asynccontextmanager +from typing import Dict, List, Literal, Optional, Union + import torch import uvicorn -from pydantic import BaseModel, Field from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware -from contextlib import asynccontextmanager -from typing import Any, Dict, List, Literal, Optional, Union -from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM +from pydantic import BaseModel, Field +from sse_starlette.sse import EventSourceResponse +from transformers import AutoTokenizer, AutoModelForCausalLM from transformers.generation import GenerationConfig -from sse_starlette.sse import ServerSentEvent, EventSourceResponse @asynccontextmanager -async def lifespan(app: FastAPI): # collects GPU memory +async def lifespan(app: FastAPI): # collects GPU memory yield if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -52,8 +56,9 @@ class ModelList(BaseModel): class ChatMessage(BaseModel): - role: Literal["user", "assistant", "system"] - content: str + role: Literal["user", "assistant", "system", "function"] + content: Optional[str] + function_call: Optional[Dict] = None class DeltaMessage(BaseModel): @@ -64,17 +69,18 @@ class DeltaMessage(BaseModel): class ChatCompletionRequest(BaseModel): model: str messages: List[ChatMessage] + functions: Optional[List[Dict]] = None temperature: Optional[float] = None top_p: Optional[float] = None max_length: Optional[int] = None stream: Optional[bool] = False - stop: Optional[List[str]] = [] + stop: Optional[List[str]] = None class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage - finish_reason: Literal["stop", "length"] + finish_reason: Literal["stop", "length", "function_call"] class ChatCompletionResponseStreamChoice(BaseModel): @@ -86,7 +92,9 @@ class ChatCompletionResponseStreamChoice(BaseModel): class ChatCompletionResponse(BaseModel): model: str object: Literal["chat.completion", "chat.completion.chunk"] - choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] + choices: List[ + Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice] + ] created: Optional[int] = Field(default_factory=lambda: int(time.time())) @@ -97,70 +105,319 @@ async def list_models(): return ModelList(data=[model_card]) +# To work around that unpleasant leading-\n tokenization issue! +def add_extra_stop_words(stop_words): + if stop_words: + _stop_words = [] + _stop_words.extend(stop_words) + for x in stop_words: + s = x.lstrip("\n") + if s and (s not in _stop_words): + _stop_words.append(s) + return _stop_words + return stop_words + + +def trim_stop_words(response, stop_words): + if stop_words: + for stop in stop_words: + idx = response.find(stop) + if idx != -1: + response = response[:idx] + return response + + +TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}""" + +REACT_INSTRUCTION = """Answer the following questions as best you can. You have acesss to the following APIs: + +{tools_text} + +Use the following format: + +Question: the input question you must answer +Thought: you should always think about what to do +Action: the action to take, should be one of [{tools_name_text}] +Action Input: the input to the action +Observation: the result of the action +... 
(this Thought/Action/Action Input/Observation can be repeated zero or more times) +Thought: I now know the final answer +Final Answer: the final answer to the original input question + +Begin!""" + +_TEXT_COMPLETION_CMD = object() + + +# +# Temporarily, the system role does not work as expected. +# We advise that you write the setups for role-play in your query, +# i.e., use the user role instead of the system role. +# +# TODO: Use real system role when the model is ready. +# +def parse_messages(messages, functions): + if all(m.role != "user" for m in messages): + raise HTTPException( + status_code=400, + detail=f"Invalid request: Expecting at least one user message.", + ) + + messages = copy.deepcopy(messages) + default_system = "You are a helpful assistant." + system = "" + if messages[0].role == "system": + system = messages.pop(0).content.lstrip("\n").rstrip() + if system == default_system: + system = "" + + if functions: + tools_text = [] + tools_name_text = [] + for func_info in functions: + name = func_info.get("name", "") + name_m = func_info.get("name_for_model", name) + name_h = func_info.get("name_for_human", name) + desc = func_info.get("description", "") + desc_m = func_info.get("description_for_model", desc) + tool = TOOL_DESC.format( + name_for_model=name_m, + name_for_human=name_h, + # Hint: You can add the following format requirements in description: + # "Format the arguments as a JSON object." + # "Enclose the code within triple backticks (`) at the beginning and end of the code." + description_for_model=desc_m, + parameters=json.dumps(func_info["parameters"], ensure_ascii=False), + ) + tools_text.append(tool) + tools_name_text.append(name_m) + tools_text = "\n\n".join(tools_text) + tools_name_text = ", ".join(tools_name_text) + system += "\n\n" + REACT_INSTRUCTION.format( + tools_text=tools_text, + tools_name_text=tools_name_text, + ) + system = system.lstrip("\n").rstrip() + + dummy_thought = { + "en": "\nThought: I now know the final answer.\nFinal answer: ", + "zh": "\nThought: 我会作答了。\nFinal answer: ", + } + + _messages = messages + messages = [] + for m_idx, m in enumerate(_messages): + role, content, func_call = m.role, m.content, m.function_call + if content: + content = content.lstrip("\n").rstrip() + if role == "function": + if (len(messages) == 0) or (messages[-1].role != "assistant"): + raise HTTPException( + status_code=400, + detail=f"Invalid request: Expecting role assistant before role function.", + ) + messages[-1].content += f"\nObservation: {content}" + if m_idx == len(_messages) - 1: + messages[-1].content += "\nThought:" + elif role == "assistant": + if len(messages) == 0: + raise HTTPException( + status_code=400, + detail=f"Invalid request: Expecting role user before role assistant.", + ) + last_msg = messages[-1].content + last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0 + if func_call is None: + if functions: + content = dummy_thought["zh" if last_msg_has_zh else "en"] + content + else: + f_name, f_args = func_call["name"], func_call["arguments"] + if not content: + if last_msg_has_zh: + content = f"Thought: 我可以使用 {f_name} API。" + else: + content = f"Thought: I can use {f_name}." 
+ content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}" + if messages[-1].role == "user": + messages.append( + ChatMessage(role="assistant", content=content.lstrip("\n").rstrip()) + ) + else: + messages[-1].content += content + elif role == "user": + messages.append( + ChatMessage(role="user", content=content.lstrip("\n").rstrip()) + ) + else: + raise HTTPException( + status_code=400, detail=f"Invalid request: Incorrect role {role}." + ) + + query = _TEXT_COMPLETION_CMD + if messages[-1].role == "user": + query = messages[-1].content + messages = messages[:-1] + + if len(messages) % 2 != 0: + raise HTTPException(status_code=400, detail="Invalid request") + + history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)] + for i in range(0, len(messages), 2): + if messages[i].role == "user" and messages[i + 1].role == "assistant": + usr_msg = messages[i].content.lstrip("\n").rstrip() + bot_msg = messages[i + 1].content.lstrip("\n").rstrip() + if system and (i == len(messages) - 2): + usr_msg = f"{system}\n\nQuestion: {usr_msg}" + system = "" + for t in dummy_thought.values(): + t = t.lstrip("\n") + if bot_msg.startswith(t) and ("\nAction: " in bot_msg): + bot_msg = bot_msg[len(t) :] + history.append([usr_msg, bot_msg]) + else: + raise HTTPException( + status_code=400, + detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.", + ) + if system: + assert query is not _TEXT_COMPLETION_CMD + query = f"{system}\n\nQuestion: {query}" + return query, history + + +def parse_response(response): + func_name, func_args = "", "" + i = response.rfind("\nAction:") + j = response.rfind("\nAction Input:") + k = response.rfind("\nObservation:") + if 0 <= i < j: # If the text has `Action` and `Action input`, + if k < j: # but does not contain `Observation`, + # then it is likely that `Observation` is omitted by the LLM, + # because the output text may have discarded the stop word. + response = response.rstrip() + "\nObservation:" # Add it back. 
+ k = response.rfind("\nObservation:") + func_name = response[i + len("\nAction:") : j].strip() + func_args = response[j + len("\nAction Input:") : k].strip() + if func_name: + choice_data = ChatCompletionResponseChoice( + index=0, + message=ChatMessage( + role="assistant", + content=response[:i], + function_call={"name": func_name, "arguments": func_args}, + ), + finish_reason="function_call", + ) + return choice_data + z = response.rfind("\nFinal Answer: ") + if z >= 0: + response = response[z + len("\nFinal Answer: ") :] + choice_data = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + return choice_data + + +# completion mode, not chat mode +def text_complete_last_message(history, stop_words_ids): + im_start = "<|im_start|>" + im_end = "<|im_end|>" + prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}" + for i, (query, response) in enumerate(history): + query = query.lstrip("\n").rstrip() + response = response.lstrip("\n").rstrip() + prompt += f"\n{im_start}user\n{query}{im_end}" + prompt += f"\n{im_start}assistant\n{response}{im_end}" + prompt = prompt[: -len(im_end)] + + _stop_words_ids = [tokenizer.encode(im_end)] + if stop_words_ids: + for s in stop_words_ids: + _stop_words_ids.append(s) + stop_words_ids = _stop_words_ids + + input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device) + output = model.generate(input_ids, stop_words_ids=stop_words_ids).tolist()[0] + output = tokenizer.decode(output, errors="ignore") + assert output.startswith(prompt) + output = output[len(prompt) :] + output = trim_stop_words(output, ["<|endoftext|>", im_end]) + print(f"\n{prompt}\n\n{output}\n") + return output + + @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def create_chat_completion(request: ChatCompletionRequest): global model, tokenizer - if request.messages[-1].role != "user": - raise HTTPException(status_code=400, detail="Invalid request") - query = request.messages[-1].content - stop_words = request.stop - stop_words.extend(list(map(lambda x: x[1:], filter(lambda x: x.startswith("\n"), stop_words)))) - prev_messages = request.messages[:-1] - # Temporarily, the system role does not work as expected. We advise that you write the setups for role-play in your query. 
- # if len(prev_messages) > 0 and prev_messages[0].role == "system": - # query = prev_messages.pop(0).content + query - - history = [] - if len(prev_messages) % 2 == 0: - for i in range(0, len(prev_messages), 2): - if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant": - history.append([prev_messages[i].content, prev_messages[i+1].content]) - else: - raise HTTPException(status_code=400, detail="Invalid request.") - else: - raise HTTPException(status_code=400, detail="Invalid request.") + stop_words = add_extra_stop_words(request.stop) + if request.functions: + stop_words = stop_words or [] + if "Observation:" not in stop_words: + stop_words.append("Observation:") + + query, history = parse_messages(request.messages, request.functions) if request.stream: + if request.functions: + raise HTTPException( + status_code=400, + detail="Invalid request: Function calling is not yet implemented for stream mode.", + ) generate = predict(query, history, request.model, stop_words) return EventSourceResponse(generate, media_type="text/event-stream") - if stop_words: - react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words] - response, _ = model.chat(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens) - for stop_ in stop_words: - if response.endswith(stop_): - response = response[:response.find(stop_)] + stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None + if query is _TEXT_COMPLETION_CMD: + response = text_complete_last_message(history, stop_words_ids=stop_words_ids) else: - response, _ = model.chat(tokenizer, query, history=history) - - choice_data = ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop" + response, _ = model.chat( + tokenizer, + query, + history=history, + stop_words_ids=stop_words_ids, + append_history=False, + ) + print(f"\n{history}\n{query}\n\n{response}\n") + response = trim_stop_words(response, stop_words) + if request.functions: + choice_data = parse_response(response) + else: + choice_data = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + return ChatCompletionResponse( + model=request.model, choices=[choice_data], object="chat.completion" ) - return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion") - -async def predict(query: str, history: List[List[str]], model_id: str, stop_words: List[str]): +async def predict( + query: str, history: List[List[str]], model_id: str, stop_words: List[str] +): global model, tokenizer - assert stop_words == [], "in stream format, stop word is output" choice_data = ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(role="assistant"), - finish_reason=None + index=0, delta=DeltaMessage(role="assistant"), finish_reason=None + ) + chunk = ChatCompletionResponse( + model=model_id, choices=[choice_data], object="chat.completion.chunk" ) - chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") yield "{}".format(chunk.model_dump_json(exclude_unset=True)) current_length = 0 + stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None if stop_words: - react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words] - response_generator = model.chat_stream(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens) - else: - response_generator = 
model.chat_stream(tokenizer, query, history=history) - + # TODO: It's a little bit tricky to trim stop words in the stream mode. + raise HTTPException( + status_code=400, + detail="Invalid request: custom stop words are not yet supported for stream mode.", + ) + response_generator = model.chat_stream( + tokenizer, query, history=history, stop_words_ids=stop_words_ids + ) for new_response in response_generator: if len(new_response) == current_length: continue @@ -169,32 +426,41 @@ async def predict(query: str, history: List[List[str]], model_id: str, stop_word current_length = len(new_response) choice_data = ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(content=new_text), - finish_reason=None + index=0, delta=DeltaMessage(content=new_text), finish_reason=None + ) + chunk = ChatCompletionResponse( + model=model_id, choices=[choice_data], object="chat.completion.chunk" ) - chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") yield "{}".format(chunk.model_dump_json(exclude_unset=True)) - choice_data = ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(), - finish_reason="stop" + index=0, delta=DeltaMessage(), finish_reason="stop" + ) + chunk = ChatCompletionResponse( + model=model_id, choices=[choice_data], object="chat.completion.chunk" ) - chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") yield "{}".format(chunk.model_dump_json(exclude_unset=True)) - yield '[DONE]' + yield "[DONE]" + def _get_args(): parser = ArgumentParser() - parser.add_argument("-c", "--checkpoint-path", type=str, default='QWen/QWen-7B-Chat', - help="Checkpoint name or path, default to %(default)r") - parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only") - parser.add_argument("--server-port", type=int, default=8000, - help="Demo server port.") - parser.add_argument("--server-name", type=str, default="127.0.0.1", - help="Demo server name.") + parser.add_argument( + "-c", + "--checkpoint-path", + type=str, + default="QWen/QWen-7B-Chat", + help="Checkpoint name or path, default to %(default)r", + ) + parser.add_argument( + "--cpu-only", action="store_true", help="Run demo with CPU only" + ) + parser.add_argument( + "--server-port", type=int, default=8000, help="Demo server port." + ) + parser.add_argument( + "--server-name", type=str, default="127.0.0.1", help="Demo server name." + ) args = parser.parse_args() return args @@ -204,7 +470,9 @@ if __name__ == "__main__": args = _get_args() tokenizer = AutoTokenizer.from_pretrained( - args.checkpoint_path, trust_remote_code=True, resume_download=True, + args.checkpoint_path, + trust_remote_code=True, + resume_download=True, ) if args.cpu_only: @@ -220,7 +488,9 @@ if __name__ == "__main__": ).eval() model.generation_config = GenerationConfig.from_pretrained( - args.checkpoint_path, trust_remote_code=True, resume_download=True, + args.checkpoint_path, + trust_remote_code=True, + resume_download=True, ) uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)
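
Note: to make the new server-side flow easier to follow, here is a condensed sketch of what `parse_messages` and `parse_response` above do with a function spec and a model completion. The tool definition and the model output are illustrative stand-ins; the real code additionally wraps the tool text in `REACT_INSTRUCTION`, supports Qwen-style specs, and handles stop words, the trailing `Observation:` case, and `Final Answer:` extraction.

```python
import json

# How a function spec becomes tool text inside the ReAct prompt (TOOL_DESC above).
TOOL_DESC = (
    "{name_for_model}: Call this tool to interact with the {name_for_human} API. "
    "What is the {name_for_human} API useful for? {description_for_model} "
    "Parameters: {parameters}"
)

func = {
    "name": "get_current_weather",
    "description": "Get the current weather in a given location.",
    "parameters": {
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],
    },
}
tool_text = TOOL_DESC.format(
    name_for_model=func["name"],  # falls back to "name" when "name_for_model" is absent
    name_for_human=func["name"],
    description_for_model=func["description"],
    parameters=json.dumps(func["parameters"], ensure_ascii=False),
)
print(tool_text)  # injected into the system part of the ReAct instruction

# How a ReAct-formatted completion is mapped back to an OpenAI-style function_call.
completion = (
    "Thought: I can use get_current_weather.\n"
    "Action: get_current_weather\n"
    'Action Input: {"location": "Boston, MA"}'
)
i = completion.rfind("\nAction:")
j = completion.rfind("\nAction Input:")
name = completion[i + len("\nAction:"): j].strip()
arguments = completion[j + len("\nAction Input:"):].strip()
print({"name": name, "arguments": arguments})  # returned with finish_reason="function_call"
```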