Merge pull request #259 from JianxinMa/main

add function calling support
Yang An committed by GitHub · commit 82c52f7d01

README.md
@@ -321,7 +321,7 @@ openai.api_key = "none"
# create a request activating streaming response
for chunk in openai.ChatCompletion.create(
model="Qwen-7B",
model="Qwen",
messages=[
{"role": "user", "content": "你好"}
],
@@ -333,7 +333,7 @@ for chunk in openai.ChatCompletion.create(
# create a request not activating streaming response
response = openai.ChatCompletion.create(
model="Qwen-7B",
model="Qwen",
messages=[
{"role": "user", "content": "你好"}
],
@@ -349,6 +349,8 @@ print(response.choices[0].message.content)
<br>
<p>
Function calling is also supported (but only when `stream=False` for the moment). See the [example usage](examples/function_call_examples.py).
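A minimal sketch of a function-calling request against the local server (the `get_current_weather` schema is the illustrative one used in the linked examples, not a server built-in):

```python
import openai

openai.api_base = "http://localhost:8000/v1"
openai.api_key = "none"

# Describe the callable function with a GPT-style JSON schema.
functions = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
            },
            "required": ["location"],
        },
    }
]

response = openai.ChatCompletion.create(
    model="Qwen",
    messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
    functions=functions,
    stream=False,  # function calling currently requires stream=False
)
# If the model decides to call the tool, the name and JSON arguments are in
# response.choices[0].message.function_call.
print(response.choices[0].message)
```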
## Deployment
Running the model on CPU is simple; you only need to specify the device:
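A minimal sketch, assuming the standard `transformers` loading pattern used elsewhere in this README:

```python
from transformers import AutoModelForCausalLM

# Load the chat model on CPU by pinning the device map.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True
).eval()
```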
@@ -372,21 +374,21 @@ Then you can run the 7B chat model on 2 GPUs using the above scripts.
Qwen-7B-Chat is specifically optimized for tool usage, including APIs, databases, and models, so that users can build their own Qwen-7B-based LangChain applications, agents, and code interpreters. In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool usage capabilities, we find that Qwen-7B reaches stable performance.
| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
| :------------ | :-----------------------: | :----------------------: | :----------------------: |
|:-----------------| :-----------------------: | :----------------------: | :----------------------: |
| GPT-4 | 95% | **0.90** | 15% |
| GPT-3.5 | 85% | 0.88 | 75% |
| **Qwen-7B** | **99%** | 0.89 | **9.7%** |
| **Qwen-7B-Chat** | **99%** | 0.89 | **9.7%** |
For how to write and use prompts for ReAct Prompting, please refer to [the ReAct examples](examples/react_prompt.md). Using tools enables the model to perform tasks better; an illustrative exchange is shown below.
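An illustrative ReAct exchange (the `google_search` tool and the question are made up for the example; the format follows `REACT_INSTRUCTION` in `openai_api.py`):

```
Question: Who is Jay Chou?
Thought: I should look up this person with the search tool.
Action: google_search
Action Input: {"search_query": "Jay Chou"}
Observation: Jay Chou is a Taiwanese singer.
Thought: I now know the final answer
Final Answer: Jay Chou is a Taiwanese singer.
```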
Additionally, we provide experimental results demonstrating the model's capability to act as an agent. See [Hugging Face Agent](https://huggingface.co/docs/transformers/transformers_agents) for more information. Its performance on the run-mode benchmark provided by Hugging Face is as follows (a minimal usage sketch follows the table):
| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
| :---------------- | :----------------: | :-----------: | :---------: |
|:-----------------| :----------------: | :-----------: | :---------: |
| GPT-4 | **100** | **100** | **97.41** |
| GPT-3.5 | 95.37 | 96.30 | 87.04 |
| StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
| **Qwen-7B** | 90.74 | 92.59 | 74.07 |
| **Qwen-7B-Chat** | 90.74 | 92.59 | 74.07 |
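For reference, a Hugging Face agent is constructed as below (a sketch following the linked documentation; the StarCoder endpoint is the docs' own example, and evaluating Qwen-7B-Chat this way would require serving it behind a compatible endpoint, which this PR does not cover):

```python
from transformers import HfAgent

# An agent backed by a remote LLM inference endpoint (example from the HF docs).
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.run("Draw me a picture of rivers and lakes.")
```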
<br>

README_CN.md
@@ -327,7 +327,7 @@ openai.api_key = "none"
# create a request with streaming response
for chunk in openai.ChatCompletion.create(
model="Qwen-7B",
model="Qwen",
messages=[
{"role": "user", "content": "你好"}
],
@@ -339,7 +339,7 @@ for chunk in openai.ChatCompletion.create(
# create a request without streaming response
response = openai.ChatCompletion.create(
model="Qwen-7B",
model="Qwen",
messages=[
{"role": "user", "content": "你好"}
],
@@ -355,6 +355,8 @@ print(response.choices[0].message.content)
<br>
<p>
The API also supports function calling, though for now it only takes effect when `stream=False`. See the [function calling examples](examples/function_call_examples.py) for usage.
## 部署
Running the model on CPU is simple, as shown below:
@@ -378,10 +380,10 @@ model = load_model_on_gpus('Qwen/Qwen-7B-Chat', num_gpus=2)
Qwen-7B-Chat is optimized for calling tools such as APIs, databases, and models, so that users can build Qwen-7B-based LangChain applications, agents, and even a code interpreter. On our open-sourced [evaluation benchmark](eval/EVALUATION.md) for tool-calling ability, we find that Qwen-7B-Chat achieves stable performance.
| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
|:------------|:----------------------:|:----------------------:|:----------------------:|
|:-----------------|:----------------------:|:----------------------:|:----------------------:|
| GPT-4 | 95% | **0.90** | 15% |
| GPT-3.5 | 85% | 0.88 | 75% |
| **Qwen-7B** | **99%** | 0.89 | **9.7%** |
| **Qwen-7B-Chat** | **99%** | 0.89 | **9.7%** |
We provide documentation on how to write your prompt according to the principles of ReAct Prompting.
@@ -390,11 +392,11 @@ For how to write and use prompts for ReAct Prompting, please refer to [the ReAct
In addition, we provide experimental results demonstrating our model's ability to act as an agent. See the [documentation](https://huggingface.co/docs/transformers/transformers_agents) for more information. The model's performance on the evaluation benchmark provided by Hugging Face is as follows:
| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
|:---------------|:---------------:|:-----------:|:---------:|
|GPT-4 | **100** | **100** | **97.41** |
|GPT-3.5 | 95.37 | 96.30 | 87.04 |
|StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
| **Qwen-7B** | 90.74 | 92.59 | 74.07 |
|:-----------------|:---------------:|:-----------:|:---------:|
| GPT-4 | **100** | **100** | **97.41** |
| GPT-3.5 | 95.37 | 96.30 | 87.04 |
| StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
| **Qwen-7B-Chat** | 90.74 | 92.59 | 74.07 |
<br>

README_JA.md
@@ -331,7 +331,7 @@ openai.api_key = "none"
# create a request activating streaming response
for chunk in openai.ChatCompletion.create(
model="Qwen-7B",
model="Qwen",
messages=[
{"role": "user", "content": "你好"}
],
@@ -342,7 +342,7 @@ for chunk in openai.ChatCompletion.create(
# create a request not activating streaming response
response = openai.ChatCompletion.create(
model="Qwen-7B",
model="Qwen",
messages=[
{"role": "user", "content": "你好"}
],
@@ -382,21 +382,21 @@ Qwen-7B-Chat is specifically optimized for tool usage, including APIs, databases, and …
| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
|:------------|:----------------------:|:----------------------:|:----------------------:|
|:-----------------|:----------------------:|:----------------------:|:----------------------:|
| GPT-4 | 95% | **0.90** | 15% |
| GPT-3.5 | 85% | 0.88 | 75% |
| **Qwen-7B** | **99%** | 0.89 | **9.7%** |
| **Qwen-7B-Chat** | **99%** | 0.89 | **9.7%** |
For how to write and use ReAct prompts, see [the ReAct examples](examples/react_prompt.md). Using tools enables the model to perform tasks better.
In addition, we provide experimental results demonstrating the model's capability as an agent. See [Hugging Face Agent](https://huggingface.co/docs/transformers/transformers_agents) for details. Its performance on the run-mode benchmark provided by Hugging Face is as follows:
| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
|:---------------|:---------------:|:-----------:|:---------:|
|GPT-4 | **100** | **100** | **97.41** |
|GPT-3.5 | 95.37 | 96.30 | 87.04 |
|StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
| **Qwen-7B** | 90.74 | 92.59 | 74.07 |
|:-----------------|:---------------:|:-----------:|:---------:|
| GPT-4 | **100** | **100** | **97.41** |
| GPT-3.5 | 95.37 | 96.30 | 87.04 |
| StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
| **Qwen-7B-Chat** | 90.74 | 92.59 | 74.07 |
<br>

examples/function_call_examples.py
@@ -0,0 +1,240 @@
# Reference: https://openai.com/blog/function-calling-and-other-api-updates
import openai
# To start an OpenAI-like Qwen server, use the following commands:
# git clone https://github.com/QwenLM/Qwen-7B;
# cd Qwen-7B;
# pip install fastapi uvicorn openai pydantic sse_starlette;
# python openai_api.py;
#
# Then configure the api_base and api_key in your client:
openai.api_base = "http://localhost:8000/v1"
openai.api_key = "none"
def call_qwen(messages, functions=None):
print(messages)
if functions:
response = openai.ChatCompletion.create(
model="Qwen", messages=messages, functions=functions
)
else:
response = openai.ChatCompletion.create(model="Qwen", messages=messages)
print(response)
print(response.choices[0].message.content)
return response
def test_1():
messages = [{"role": "user", "content": "你好"}]
call_qwen(messages)
messages.append({"role": "assistant", "content": "你好!很高兴为你提供帮助。"})
messages.append({"role": "user", "content": "给我讲一个年轻人奋斗创业最终取得成功的故事。故事只能有一句话。"})
call_qwen(messages)
messages.append(
{
"role": "assistant",
"content": "故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。李明想要成为一名成功的企业家。……",
}
)
messages.append({"role": "user", "content": "给这个故事起一个标题"})
call_qwen(messages)
def test_2():
functions = [
{
"name_for_human": "谷歌搜索",
"name_for_model": "google_search",
"description_for_model": "谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。"
+ " Format the arguments as a JSON object.",
"parameters": [
{
"name": "search_query",
"description": "搜索关键词或短语",
"required": True,
"schema": {"type": "string"},
}
],
},
{
"name_for_human": "文生图",
"name_for_model": "image_gen",
"description_for_model": "文生图是一个AI绘画图像生成服务输入文本描述返回根据文本作画得到的图片的URL。"
+ " Format the arguments as a JSON object.",
"parameters": [
{
"name": "prompt",
"description": "英文关键词,描述了希望图像具有什么内容",
"required": True,
"schema": {"type": "string"},
}
],
},
]
messages = [{"role": "user", "content": "你好"}]
call_qwen(messages, functions)
messages.append(
{"role": "assistant", "content": "你好!很高兴见到你。有什么我可以帮忙的吗?"},
)
messages.append({"role": "user", "content": "谁是周杰伦"})
call_qwen(messages, functions)
messages.append(
{
"role": "assistant",
"content": "Thought: 我应该使用Google搜索查找相关信息。",
"function_call": {
"name": "google_search",
"arguments": '{"search_query": "周杰伦"}',
},
}
)
messages.append(
{
"role": "function",
"name": "google_search",
"content": "Jay Chou is a Taiwanese singer.",
}
)
call_qwen(messages, functions)
messages.append(
{
"role": "assistant",
"content": "周杰伦Jay Chou是一位来自台湾的歌手。",
},
)
messages.append({"role": "user", "content": "他老婆是谁"})
call_qwen(messages, functions)
messages.append(
{
"role": "assistant",
"content": "Thought: 我应该使用Google搜索查找相关信息。",
"function_call": {
"name": "google_search",
"arguments": '{"search_query": "周杰伦 老婆"}',
},
}
)
messages.append(
{"role": "function", "name": "google_search", "content": "Hannah Quinlivan"}
)
call_qwen(messages, functions)
messages.append(
{
"role": "assistant",
"content": "周杰伦的老婆是Hannah Quinlivan。",
},
)
messages.append({"role": "user", "content": "给我画个可爱的小猫吧,最好是黑猫"})
call_qwen(messages, functions)
messages.append(
{
"role": "assistant",
"content": "Thought: 我应该使用文生图API来生成一张可爱的小猫图片。",
"function_call": {
"name": "image_gen",
"arguments": '{"prompt": "cute black cat"}',
},
}
)
messages.append(
{
"role": "function",
"name": "image_gen",
"content": '{"image_url": "https://image.pollinations.ai/prompt/cute%20black%20cat"}',
}
)
call_qwen(messages, functions)
def test_3():
functions = [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
]
messages = [
{
"role": "user",
# Note: The current version of Qwen-7B-Chat (as of 2023.08) performs okay with Chinese tool-use prompts,
# but performs terribly when it comes to English tool-use prompts, due to a mistake in data collection.
"content": "波士顿天气如何?",
}
]
call_qwen(messages, functions)
messages.append(
{
"role": "assistant",
"content": None,
"function_call": {
"name": "get_current_weather",
"arguments": '{"location": "Boston, MA"}',
},
},
)
messages.append(
{
"role": "function",
"name": "get_current_weather",
"content": '{"temperature": "22", "unit": "celsius", "description": "Sunny"}',
}
)
call_qwen(messages, functions)
def test_4():
from langchain.chat_models import ChatOpenAI
from langchain.agents import load_tools, initialize_agent, AgentType
llm = ChatOpenAI(
model_name="Qwen",
openai_api_base="http://localhost:8000/v1",
openai_api_key="EMPTY",
streaming=False,
)
tools = load_tools(
["arxiv"],
)
agent_chain = initialize_agent(
tools,
llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
)
# TODO: The performance is okay with Chinese prompts, but not so good when it comes to English.
agent_chain.run("查一下论文 1605.08386 的信息")
if __name__ == "__main__":
print("### Test Case 1 - No Function Calling (普通问答、无函数调用) ###")
test_1()
print("### Test Case 2 - Use Qwen-Style Functions (函数调用,千问格式) ###")
test_2()
print("### Test Case 3 - Use GPT-Style Functions (函数调用GPT格式) ###")
test_3()
print("### Test Case 4 - Use LangChain (接入Langchain) ###")
test_4()

openai_api.py
@@ -3,18 +3,22 @@
# Usage: python openai_api.py
# Visit http://localhost:8000/docs for documents.
from argparse import ArgumentParser
import re
import copy
import json
import time
from argparse import ArgumentParser
from contextlib import asynccontextmanager
from typing import Dict, List, Literal, Optional, Union
import torch
import uvicorn
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
@asynccontextmanager
@@ -52,8 +56,9 @@ class ModelList(BaseModel):
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system"]
content: str
role: Literal["user", "assistant", "system", "function"]
content: Optional[str]
function_call: Optional[Dict] = None
class DeltaMessage(BaseModel):
@@ -64,17 +69,18 @@ class DeltaMessage(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
functions: Optional[List[Dict]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
max_length: Optional[int] = None
stream: Optional[bool] = False
stop: Optional[List[str]] = []
stop: Optional[List[str]] = None
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length"]
finish_reason: Literal["stop", "length", "function_call"]
class ChatCompletionResponseStreamChoice(BaseModel):
@@ -86,7 +92,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
class ChatCompletionResponse(BaseModel):
model: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
choices: List[
Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@@ -97,70 +105,319 @@ async def list_models():
return ModelList(data=[model_card])
# To work around that unpleasant leading-\n tokenization issue!
def add_extra_stop_words(stop_words):
if stop_words:
_stop_words = []
_stop_words.extend(stop_words)
for x in stop_words:
s = x.lstrip("\n")
if s and (s not in _stop_words):
_stop_words.append(s)
return _stop_words
return stop_words
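# For example, add_extra_stop_words(["\nObservation:"]) returns
# ["\nObservation:", "Observation:"], so generation still stops when the
# tokenizer merges the leading "\n" into the preceding token.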
def trim_stop_words(response, stop_words):
if stop_words:
for stop in stop_words:
idx = response.find(stop)
if idx != -1:
response = response[:idx]
return response
TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:
{tools_text}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tools_name_text}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!"""
_TEXT_COMPLETION_CMD = object()
#
# Temporarily, the system role does not work as expected.
# We advise that you write the setups for role-play in your query,
# i.e., use the user role instead of the system role.
#
# TODO: Use real system role when the model is ready.
#
def parse_messages(messages, functions):
if all(m.role != "user" for m in messages):
raise HTTPException(
status_code=400,
detail=f"Invalid request: Expecting at least one user message.",
)
messages = copy.deepcopy(messages)
default_system = "You are a helpful assistant."
system = ""
if messages[0].role == "system":
system = messages.pop(0).content.lstrip("\n").rstrip()
if system == default_system:
system = ""
if functions:
tools_text = []
tools_name_text = []
for func_info in functions:
name = func_info.get("name", "")
name_m = func_info.get("name_for_model", name)
name_h = func_info.get("name_for_human", name)
desc = func_info.get("description", "")
desc_m = func_info.get("description_for_model", desc)
tool = TOOL_DESC.format(
name_for_model=name_m,
name_for_human=name_h,
# Hint: You can add the following format requirements in description:
# "Format the arguments as a JSON object."
# "Enclose the code within triple backticks (`) at the beginning and end of the code."
description_for_model=desc_m,
parameters=json.dumps(func_info["parameters"], ensure_ascii=False),
)
tools_text.append(tool)
tools_name_text.append(name_m)
tools_text = "\n\n".join(tools_text)
tools_name_text = ", ".join(tools_name_text)
system += "\n\n" + REACT_INSTRUCTION.format(
tools_text=tools_text,
tools_name_text=tools_name_text,
)
system = system.lstrip("\n").rstrip()
dummy_thought = {
"en": "\nThought: I now know the final answer.\nFinal answer: ",
"zh": "\nThought: 我会作答了。\nFinal answer: ",
}
_messages = messages
messages = []
for m_idx, m in enumerate(_messages):
role, content, func_call = m.role, m.content, m.function_call
if content:
content = content.lstrip("\n").rstrip()
if role == "function":
if (len(messages) == 0) or (messages[-1].role != "assistant"):
raise HTTPException(
status_code=400,
detail=f"Invalid request: Expecting role assistant before role function.",
)
messages[-1].content += f"\nObservation: {content}"
if m_idx == len(_messages) - 1:
messages[-1].content += "\nThought:"
elif role == "assistant":
if len(messages) == 0:
raise HTTPException(
status_code=400,
detail=f"Invalid request: Expecting role user before role assistant.",
)
last_msg = messages[-1].content
last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0
if func_call is None:
if functions:
content = dummy_thought["zh" if last_msg_has_zh else "en"] + content
else:
f_name, f_args = func_call["name"], func_call["arguments"]
if not content:
if last_msg_has_zh:
content = f"Thought: 我可以使用 {f_name} API。"
else:
content = f"Thought: I can use {f_name}."
content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}"
if messages[-1].role == "user":
messages.append(
ChatMessage(role="assistant", content=content.lstrip("\n").rstrip())
)
else:
messages[-1].content += content
elif role == "user":
messages.append(
ChatMessage(role="user", content=content.lstrip("\n").rstrip())
)
else:
raise HTTPException(
status_code=400, detail=f"Invalid request: Incorrect role {role}."
)
query = _TEXT_COMPLETION_CMD
if messages[-1].role == "user":
query = messages[-1].content
messages = messages[:-1]
if len(messages) % 2 != 0:
raise HTTPException(status_code=400, detail="Invalid request")
history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
for i in range(0, len(messages), 2):
if messages[i].role == "user" and messages[i + 1].role == "assistant":
usr_msg = messages[i].content.lstrip("\n").rstrip()
bot_msg = messages[i + 1].content.lstrip("\n").rstrip()
if system and (i == len(messages) - 2):
usr_msg = f"{system}\n\nQuestion: {usr_msg}"
system = ""
for t in dummy_thought.values():
t = t.lstrip("\n")
if bot_msg.startswith(t) and ("\nAction: " in bot_msg):
bot_msg = bot_msg[len(t) :]
history.append([usr_msg, bot_msg])
else:
raise HTTPException(
status_code=400,
detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.",
)
if system:
assert query is not _TEXT_COMPLETION_CMD
query = f"{system}\n\nQuestion: {query}"
return query, history
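# Example: messages [user "hi", assistant "hello", user "weather?"] with no
# functions yield query="weather?" and history=[["hi", "hello"]]. When
# functions are given, the ReAct instruction is prepended (as
# "{system}\n\nQuestion: ...") to the last history turn, or to the query
# itself if there is no history.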
def parse_response(response):
func_name, func_args = "", ""
i = response.rfind("\nAction:")
j = response.rfind("\nAction Input:")
k = response.rfind("\nObservation:")
if 0 <= i < j: # If the text has `Action` and `Action input`,
if k < j: # but does not contain `Observation`,
# then it is likely that `Observation` is omitted by the LLM,
# because the output text may have discarded the stop word.
response = response.rstrip() + "\nObservation:" # Add it back.
k = response.rfind("\nObservation:")
func_name = response[i + len("\nAction:") : j].strip()
func_args = response[j + len("\nAction Input:") : k].strip()
if func_name:
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(
role="assistant",
content=response[:i],
function_call={"name": func_name, "arguments": func_args},
),
finish_reason="function_call",
)
return choice_data
z = response.rfind("\nFinal Answer: ")
if z >= 0:
response = response[z + len("\nFinal Answer: ") :]
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
return choice_data
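# Example: a completion ending in "\nAction: google_search\nAction Input: {...}"
# comes back as a message whose function_call is set and whose finish_reason is
# "function_call"; a completion containing "\nFinal Answer: ..." is trimmed to
# the text after that marker and returned with finish_reason="stop".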
# completion mode, not chat mode
def text_complete_last_message(history, stop_words_ids):
im_start = "<|im_start|>"
im_end = "<|im_end|>"
prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
for i, (query, response) in enumerate(history):
query = query.lstrip("\n").rstrip()
response = response.lstrip("\n").rstrip()
prompt += f"\n{im_start}user\n{query}{im_end}"
prompt += f"\n{im_start}assistant\n{response}{im_end}"
prompt = prompt[: -len(im_end)]
_stop_words_ids = [tokenizer.encode(im_end)]
if stop_words_ids:
for s in stop_words_ids:
_stop_words_ids.append(s)
stop_words_ids = _stop_words_ids
input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
output = model.generate(input_ids, stop_words_ids=stop_words_ids).tolist()[0]
output = tokenizer.decode(output, errors="ignore")
assert output.startswith(prompt)
output = output[len(prompt) :]
output = trim_stop_words(output, ["<|endoftext|>", im_end])
print(f"<completion>\n{prompt}\n<!-- *** -->\n{output}\n</completion>")
return output
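# The prompt assembled above follows Qwen's ChatML layout, e.g.:
#   <|im_start|>system\nYou are a helpful assistant.<|im_end|>
#   <|im_start|>user\n{query}<|im_end|>
#   <|im_start|>assistant\n{response}<|im_end|>
# with the final <|im_end|> stripped so the model continues the last assistant
# message instead of starting a new turn.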
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer
if request.messages[-1].role != "user":
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
stop_words = request.stop
stop_words.extend(list(map(lambda x: x[1:], filter(lambda x: x.startswith("\n"), stop_words))))
prev_messages = request.messages[:-1]
# Temporarily, the system role does not work as expected. We advise that you write the setups for role-play in your query.
# if len(prev_messages) > 0 and prev_messages[0].role == "system":
# query = prev_messages.pop(0).content + query
history = []
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
history.append([prev_messages[i].content, prev_messages[i+1].content])
else:
raise HTTPException(status_code=400, detail="Invalid request.")
else:
raise HTTPException(status_code=400, detail="Invalid request.")
stop_words = add_extra_stop_words(request.stop)
if request.functions:
stop_words = stop_words or []
if "Observation:" not in stop_words:
stop_words.append("Observation:")
query, history = parse_messages(request.messages, request.functions)
if request.stream:
if request.functions:
raise HTTPException(
status_code=400,
detail="Invalid request: Function calling is not yet implemented for stream mode.",
)
generate = predict(query, history, request.model, stop_words)
return EventSourceResponse(generate, media_type="text/event-stream")
if stop_words:
react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words]
response, _ = model.chat(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens)
for stop_ in stop_words:
if response.endswith(stop_):
response = response[:response.find(stop_)]
stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
if query is _TEXT_COMPLETION_CMD:
response = text_complete_last_message(history, stop_words_ids=stop_words_ids)
else:
response, _ = model.chat(
tokenizer,
query,
history=history,
stop_words_ids=stop_words_ids,
append_history=False,
)
print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
response = trim_stop_words(response, stop_words)
if request.functions:
choice_data = parse_response(response)
else:
response, _ = model.chat(tokenizer, query, history=history)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop"
finish_reason="stop",
)
return ChatCompletionResponse(
model=request.model, choices=[choice_data], object="chat.completion"
)
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
async def predict(query: str, history: List[List[str]], model_id: str, stop_words: List[str]):
async def predict(
query: str, history: List[List[str]], model_id: str, stop_words: List[str]
):
global model, tokenizer
assert stop_words == [], "in stream format, stop word is output"
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None
index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
)
chunk = ChatCompletionResponse(
model=model_id, choices=[choice_data], object="chat.completion.chunk"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
current_length = 0
stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
if stop_words:
react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words]
response_generator = model.chat_stream(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens)
else:
response_generator = model.chat_stream(tokenizer, query, history=history)
# TODO: It's a little bit tricky to trim stop words in the stream mode.
raise HTTPException(
status_code=400,
detail="Invalid request: custom stop words are not yet supported for stream mode.",
)
response_generator = model.chat_stream(
tokenizer, query, history=history, stop_words_ids=stop_words_ids
)
for new_response in response_generator:
if len(new_response) == current_length:
continue
@@ -169,32 +426,41 @@ async def predict(query: str, history: List[List[str]], model_id: str, stop_word
current_length = len(new_response)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=new_text),
finish_reason=None
index=0, delta=DeltaMessage(content=new_text), finish_reason=None
)
chunk = ChatCompletionResponse(
model=model_id, choices=[choice_data], object="chat.completion.chunk"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
index=0, delta=DeltaMessage(), finish_reason="stop"
)
chunk = ChatCompletionResponse(
model=model_id, choices=[choice_data], object="chat.completion.chunk"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
yield '[DONE]'
yield "[DONE]"
def _get_args():
parser = ArgumentParser()
parser.add_argument("-c", "--checkpoint-path", type=str, default='QWen/QWen-7B-Chat',
help="Checkpoint name or path, default to %(default)r")
parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
parser.add_argument("--server-port", type=int, default=8000,
help="Demo server port.")
parser.add_argument("--server-name", type=str, default="127.0.0.1",
help="Demo server name.")
parser.add_argument(
"-c",
"--checkpoint-path",
type=str,
default="QWen/QWen-7B-Chat",
help="Checkpoint name or path, default to %(default)r",
)
parser.add_argument(
"--cpu-only", action="store_true", help="Run demo with CPU only"
)
parser.add_argument(
"--server-port", type=int, default=8000, help="Demo server port."
)
parser.add_argument(
"--server-name", type=str, default="127.0.0.1", help="Demo server name."
)
args = parser.parse_args()
return args
@@ -204,7 +470,9 @@ if __name__ == "__main__":
args = _get_args()
tokenizer = AutoTokenizer.from_pretrained(
args.checkpoint_path, trust_remote_code=True, resume_download=True,
args.checkpoint_path,
trust_remote_code=True,
resume_download=True,
)
if args.cpu_only:
@@ -220,7 +488,9 @@ ).eval()
).eval()
model.generation_config = GenerationConfig.from_pretrained(
args.checkpoint_path, trust_remote_code=True, resume_download=True,
args.checkpoint_path,
trust_remote_code=True,
resume_download=True,
)
uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)
