update openai_api.py

main
yangapku 1 year ago
parent 29fea23f87
commit ab109ced9f

@@ -1,4 +1,6 @@
# Reference: https://openai.com/blog/function-calling-and-other-api-updates
import json
from pprint import pprint
import openai
@@ -9,216 +11,223 @@ import openai
# python openai_api.py;
#
# Then configure the api_base and api_key in your client:
openai.api_base = "http://localhost:8000/v1"
openai.api_key = "none"
openai.api_base = 'http://localhost:8000/v1'
openai.api_key = 'none'
def call_qwen(messages, functions=None):
print(messages)
print('input:')
pprint(messages, indent=2)
if functions:
response = openai.ChatCompletion.create(
model="Qwen", messages=messages, functions=functions
)
response = openai.ChatCompletion.create(model='Qwen',
messages=messages,
functions=functions)
else:
response = openai.ChatCompletion.create(model="Qwen", messages=messages)
print(response)
print(response.choices[0].message.content)
response = openai.ChatCompletion.create(model='Qwen',
messages=messages)
response = response.choices[0]['message']
response = json.loads(json.dumps(response,
ensure_ascii=False)) # fix zh rendering
print('output:')
pprint(response, indent=2)
print()
return response


def test_1():
    messages = [{'role': 'user', 'content': '你好'}]
    call_qwen(messages)
    messages.append({'role': 'assistant', 'content': '你好!很高兴为你提供帮助。'})

    messages.append({
        'role': 'user',
        'content': '给我讲一个年轻人奋斗创业最终取得成功的故事。故事只能有一句话。'
    })
    call_qwen(messages)
    messages.append({
        'role': 'assistant',
        'content': '故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。李明想要成为一名成功的企业家。……',
    })

    messages.append({'role': 'user', 'content': '给这个故事起一个标题'})
    call_qwen(messages)


def test_2():
    functions = [
        {
            'name_for_human': '谷歌搜索',
            'name_for_model': 'google_search',
            'description_for_model': '谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。'
            ' Format the arguments as a JSON object.',
            'parameters': [{
                'name': 'search_query',
                'description': '搜索关键词或短语',
                'required': True,
                'schema': {'type': 'string'},
            }],
        },
        {
            'name_for_human': '文生图',
            'name_for_model': 'image_gen',
            'description_for_model': '文生图是一个AI绘画图像生成服务输入文本描述返回根据文本作画得到的图片的URL。'
            ' Format the arguments as a JSON object.',
            'parameters': [{
                'name': 'prompt',
                'description': '英文关键词,描述了希望图像具有什么内容',
                'required': True,
                'schema': {'type': 'string'},
            }],
        },
    ]

    messages = [{'role': 'user', 'content': '(请不要调用工具)\n\n你好'}]
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': '你好!很高兴见到你。有什么我可以帮忙的吗?'
    })

    messages.append({'role': 'user', 'content': '搜索一下谁是周杰伦'})
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': '我应该使用Google搜索查找相关信息。',
        'function_call': {
            'name': 'google_search',
            'arguments': '{"search_query": "周杰伦"}',
        },
    })

    messages.append({
        'role': 'function',
        'name': 'google_search',
        'content': 'Jay Chou is a Taiwanese singer.',
    })
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': '周杰伦Jay Chou是一位来自台湾的歌手。',
    })

    messages.append({'role': 'user', 'content': '搜索一下他老婆是谁'})
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': '我应该使用Google搜索查找相关信息。',
        'function_call': {
            'name': 'google_search',
            'arguments': '{"search_query": "周杰伦 老婆"}',
        },
    })

    messages.append({
        'role': 'function',
        'name': 'google_search',
        'content': 'Hannah Quinlivan'
    })
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': '周杰伦的老婆是Hannah Quinlivan。',
    })

    messages.append({'role': 'user', 'content': '用文生图工具画个可爱的小猫吧,最好是黑猫'})
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': '我应该使用文生图API来生成一张可爱的小猫图片。',
        'function_call': {
            'name': 'image_gen',
            'arguments': '{"prompt": "cute black cat"}',
        },
    })

    messages.append({
        'role': 'function',
        'name': 'image_gen',
        'content': '{"image_url": "https://image.pollinations.ai/prompt/cute%20black%20cat"}',
    })
    call_qwen(messages, functions)
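

# In the transcript above the function results are hard-coded. As a sketch of
# how a client might automate the loop instead, here is a hypothetical helper
# (run_with_tools and the impls table are illustrative and not part of the
# repo; real implementations would call the actual search/drawing APIs):
def run_with_tools(messages, functions):
    impls = {
        'google_search':
        lambda search_query: 'Jay Chou is a Taiwanese singer.',
    }
    while True:
        response = call_qwen(messages, functions)
        messages.append(response)
        func_call = response.get('function_call')
        if not func_call:  # no tool requested: the final answer is ready
            return response['content']
        func_args = json.loads(func_call['arguments'])
        messages.append({
            'role': 'function',
            'name': func_call['name'],
            'content': impls[func_call['name']](**func_args),
        })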


def test_3():
    functions = [{
        'name': 'get_current_weather',
        'description': 'Get the current weather in a given location.',
        'parameters': {
            'type': 'object',
            'properties': {
                'location': {
                    'type': 'string',
                    'description': 'The city and state, e.g. San Francisco, CA',
                },
                'unit': {
                    'type': 'string',
                    'enum': ['celsius', 'fahrenheit']
                },
            },
            'required': ['location'],
        },
    }]

    messages = [{
        'role': 'user',
        # Note: The current version of Qwen-7B-Chat (as of 2023.08) performs okay with Chinese tool-use prompts,
        # but performs terribly when it comes to English tool-use prompts, due to a mistake in data collecting.
        'content': '波士顿天气如何?',
    }]
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': None,
        'function_call': {
            'name': 'get_current_weather',
            'arguments': '{"location": "Boston, MA"}',
        },
    })

    messages.append({
        'role': 'function',
        'name': 'get_current_weather',
        'content': '{"temperature": "22", "unit": "celsius", "description": "Sunny"}',
    })
    call_qwen(messages, functions)


def test_4():
    from langchain.agents import AgentType, initialize_agent, load_tools
    from langchain.chat_models import ChatOpenAI

    llm = ChatOpenAI(
        model_name='Qwen',
        openai_api_base='http://localhost:8000/v1',
        openai_api_key='EMPTY',
        streaming=False,
    )
    tools = load_tools(['arxiv'])
    agent_chain = initialize_agent(
        tools,
        llm,
@@ -226,15 +235,15 @@ def test_4():
        verbose=True,
    )
    # TODO: The performance is okay with Chinese prompts, but not so good when it comes to English.
    agent_chain.run('查一下论文 1605.08386 的信息')


if __name__ == '__main__':
    print('### Test Case 1 - No Function Calling (普通问答、无函数调用) ###')
    test_1()
    print('### Test Case 2 - Use Qwen-Style Functions (函数调用,千问格式) ###')
    test_2()
    print('### Test Case 3 - Use GPT-Style Functions (函数调用GPT格式) ###')
    test_3()
    print('### Test Case 4 - Use LangChain (接入Langchain) ###')
    test_4()

@@ -1,14 +1,16 @@
# coding=utf-8
# Implements API for Qwen-7B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Requirement:
#   pip install "openai<1.0"
# Usage:
#   python openai_api.py
# Visit http://localhost:8000/docs for documents.

import base64
import copy
import json
import time
from argparse import ArgumentParser
from contextlib import asynccontextmanager
from pprint import pprint
from typing import Dict, List, Literal, Optional, Union

import torch
@@ -17,20 +19,22 @@ from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


class BasicAuthMiddleware(BaseHTTPMiddleware):

    def __init__(self, app, username: str, password: str):
        super().__init__(app)
        self.required_credentials = base64.b64encode(
            f'{username}:{password}'.encode()).decode()

    async def dispatch(self, request: Request, call_next):
        authorization: str = request.headers.get('Authorization')
        if authorization:
            try:
                schema, credentials = authorization.split()
@@ -42,12 +46,14 @@ class BasicAuthMiddleware(BaseHTTPMiddleware):
        headers = {'WWW-Authenticate': 'Basic'}
        return Response(status_code=401, headers=headers)
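
# For reference, a minimal client-side sketch (not part of this server) of
# authenticating against this middleware, assuming the server was started
# with a hypothetical `--api-auth user:pass`:
#
#     import requests
#     resp = requests.post(
#         'http://localhost:8000/v1/chat/completions',
#         auth=('user', 'pass'),  # sent as an 'Authorization: Basic ...' header
#         json={'model': 'Qwen',
#               'messages': [{'role': 'user', 'content': '你好'}]},
#     )
#     print(resp.json()['choices'][0]['message']['content'])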


def _gc(forced: bool = False):
    global args
    if args.disable_gc and not forced:
        return

    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
@@ -63,36 +69,36 @@ app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)


class ModelCard(BaseModel):
    id: str
    object: str = 'model'
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = 'owner'
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = 'list'
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal['user', 'assistant', 'system', 'function']
    content: Optional[str]
    function_call: Optional[Dict] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal['user', 'assistant', 'system']] = None
    content: Optional[str] = None
@@ -102,6 +108,7 @@ class ChatCompletionRequest(BaseModel):
    functions: Optional[List[Dict]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None
@@ -109,29 +116,28 @@ class ChatCompletionRequest(BaseModel):


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal['stop', 'length', 'function_call']


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal['stop', 'length']]


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal['chat.completion', 'chat.completion.chunk']
    choices: List[Union[ChatCompletionResponseChoice,
                        ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.get("/v1/models", response_model=ModelList)
@app.get('/v1/models', response_model=ModelList)
async def list_models():
global model_args
model_card = ModelCard(id="gpt-3.5-turbo")
model_card = ModelCard(id='gpt-3.5-turbo')
return ModelList(data=[model_card])
@@ -141,7 +147,7 @@ def add_extra_stop_words(stop_words):
        _stop_words = []
        _stop_words.extend(stop_words)
        for x in stop_words:
            s = x.lstrip('\n')
            if s and (s not in _stop_words):
                _stop_words.append(s)
        return _stop_words
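
# For example (illustrative): add_extra_stop_words(['\nObservation:']) returns
# ['\nObservation:', 'Observation:'], so the stop word still fires when the
# model emits it without a leading newline.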
@@ -157,7 +163,10 @@ def trim_stop_words(response, stop_words):
    return response
TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
TOOL_DESC = (
'{name_for_model}: Call this tool to interact with the {name_for_human} API.'
' What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}'
)
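
# Rendered for the google_search tool from the client demo, TOOL_DESC yields a
# line of roughly this shape (the parameters list is JSON-dumped by
# parse_messages below; middle parts elided here):
#
#   google_search: Call this tool to interact with the 谷歌搜索 API. What is
#   the 谷歌搜索 API useful for? 谷歌搜索是一个通用搜索引擎,... Parameters:
#   [{"name": "search_query", "description": "搜索关键词或短语", ...}]
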
REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:
@@ -179,37 +188,28 @@ Begin!
_TEXT_COMPLETION_CMD = object()


def parse_messages(messages, functions):
    if all(m.role != 'user' for m in messages):
        raise HTTPException(
            status_code=400,
            detail='Invalid request: Expecting at least one user message.',
        )

    messages = copy.deepcopy(messages)
    if messages[0].role == 'system':
        system = messages.pop(0).content.lstrip('\n').rstrip()
    else:
        system = 'You are a helpful assistant.'
    if functions:
        tools_text = []
        tools_name_text = []
        for func_info in functions:
            name = func_info.get('name', '')
            name_m = func_info.get('name_for_model', name)
            name_h = func_info.get('name_for_human', name)
            desc = func_info.get('description', '')
            desc_m = func_info.get('description_for_model', desc)
            tool = TOOL_DESC.format(
                name_for_model=name_m,
                name_for_human=name_h,
@@ -217,149 +217,151 @@ def parse_messages(messages, functions):
                # "Format the arguments as a JSON object."
                # "Enclose the code within triple backticks (`) at the beginning and end of the code."
                description_for_model=desc_m,
                parameters=json.dumps(func_info['parameters'],
                                      ensure_ascii=False),
            )
            tools_text.append(tool)
            tools_name_text.append(name_m)
        tools_text = '\n\n'.join(tools_text)
        tools_name_text = ', '.join(tools_name_text)
        instruction = (REACT_INSTRUCTION.format(
            tools_text=tools_text,
            tools_name_text=tools_name_text,
        ).lstrip('\n').rstrip())
    else:
        instruction = ''
    messages_with_fncall = messages
    messages = []
    for m_idx, m in enumerate(messages_with_fncall):
        role, content, func_call = m.role, m.content, m.function_call
        content = content or ''
        content = content.lstrip('\n').rstrip()
        if role == 'function':
            if (len(messages) == 0) or (messages[-1].role != 'assistant'):
                raise HTTPException(
                    status_code=400,
                    detail='Invalid request: Expecting role assistant before role function.',
                )
            messages[-1].content += f'\nObservation: {content}'
            if m_idx == len(messages_with_fncall) - 1:
                # add a prefix for text completion
                messages[-1].content += '\nThought:'
        elif role == 'assistant':
            if len(messages) == 0:
                raise HTTPException(
                    status_code=400,
                    detail='Invalid request: Expecting role user before role assistant.',
                )
            if func_call is None:
                if functions:
                    content = f'Thought: I now know the final answer.\nFinal Answer: {content}'
            else:
                f_name, f_args = func_call['name'], func_call['arguments']
                if not content.startswith('Thought:'):
                    content = f'Thought: {content}'
                content = f'{content}\nAction: {f_name}\nAction Input: {f_args}'
            if messages[-1].role == 'user':
                messages.append(
                    ChatMessage(role='assistant',
                                content=content.lstrip('\n').rstrip()))
            else:
                messages[-1].content += '\n' + content
        elif role == 'user':
            messages.append(
                ChatMessage(role='user',
                            content=content.lstrip('\n').rstrip()))
        else:
            raise HTTPException(
                status_code=400,
                detail=f'Invalid request: Incorrect role {role}.')
    query = _TEXT_COMPLETION_CMD
    if messages[-1].role == 'user':
        query = messages[-1].content
        messages = messages[:-1]

    if len(messages) % 2 != 0:
        raise HTTPException(status_code=400, detail='Invalid request')

    history = []  # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
    for i in range(0, len(messages), 2):
        if messages[i].role == 'user' and messages[i + 1].role == 'assistant':
            usr_msg = messages[i].content.lstrip('\n').rstrip()
            bot_msg = messages[i + 1].content.lstrip('\n').rstrip()
            if instruction and (i == len(messages) - 2):
                usr_msg = f'{instruction}\n\nQuestion: {usr_msg}'
                instruction = ''
            history.append([usr_msg, bot_msg])
        else:
            raise HTTPException(
                status_code=400,
                detail='Invalid request: Expecting exactly one user (or function) role before every assistant role.',
            )

    if instruction:
        assert query is not _TEXT_COMPLETION_CMD
        query = f'{instruction}\n\nQuestion: {query}'
    return query, history, system
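
# Illustrative example of the flattening above. Given this request:
#   user:      '搜索一下谁是周杰伦'
#   assistant: content='我应该使用Google搜索查找相关信息。',
#              function_call={'name': 'google_search',
#                             'arguments': '{"search_query": "周杰伦"}'}
#   function:  'Jay Chou is a Taiwanese singer.'
# the assistant and function turns fold into one assistant message:
#   Thought: 我应该使用Google搜索查找相关信息。
#   Action: google_search
#   Action Input: {"search_query": "周杰伦"}
#   Observation: Jay Chou is a Taiwanese singer.
#   Thought:
# and, because the folded message is last, the request falls through to the
# text-completion path (_TEXT_COMPLETION_CMD) with the ReAct instruction
# prepended to the question of the last history turn.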


def parse_response(response):
    func_name, func_args = '', ''
    i = response.find('\nAction:')
    j = response.find('\nAction Input:')
    k = response.find('\nObservation:')
    if 0 <= i < j:  # If the text has `Action` and `Action input`,
        if k < j:  # but does not contain `Observation`,
            # then it is likely that `Observation` is omitted by the LLM,
            # because the output text may have discarded the stop word.
            response = response.rstrip() + '\nObservation:'  # Add it back.
            k = response.find('\nObservation:')
        func_name = response[i + len('\nAction:'):j].strip()
        func_args = response[j + len('\nAction Input:'):k].strip()

    if func_name:
        response = response[:i]
        t = response.find('Thought: ')
        if t >= 0:
            response = response[t + len('Thought: '):]
        response = response.strip()
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(
                role='assistant',
                content=response,
                function_call={
                    'name': func_name,
                    'arguments': func_args
                },
            ),
            finish_reason='function_call',
        )
        return choice_data
    z = response.rfind('\nFinal Answer: ')
    if z >= 0:
        response = response[z + len('\nFinal Answer: '):]
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=response),
        finish_reason='stop',
    )
    return choice_data
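
# A quick sanity check of the parsing above (illustrative, not part of the
# file):
#
#     sample = ('Thought: 我应该使用Google搜索查找相关信息。\n'
#               'Action: google_search\n'
#               'Action Input: {"search_query": "周杰伦"}')
#     choice = parse_response(sample)
#     # choice.finish_reason == 'function_call'
#     # choice.message.content == '我应该使用Google搜索查找相关信息。'
#     # choice.message.function_call == {'name': 'google_search',
#     #                                  'arguments': '{"search_query": "周杰伦"}'}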


# completion mode, not chat mode
def text_complete_last_message(history, stop_words_ids, gen_kwargs, system):
    im_start = '<|im_start|>'
    im_end = '<|im_end|>'
    prompt = f'{im_start}system\n{system}{im_end}'
    for i, (query, response) in enumerate(history):
        query = query.lstrip('\n').rstrip()
        response = response.lstrip('\n').rstrip()
        prompt += f'\n{im_start}user\n{query}{im_end}'
        prompt += f'\n{im_start}assistant\n{response}{im_end}'
    prompt = prompt[:-len(im_end)]

    _stop_words_ids = [tokenizer.encode(im_end)]
@@ -369,20 +371,24 @@ def text_complete_last_message(history, stop_words_ids, gen_kwargs):
        stop_words_ids = _stop_words_ids

    input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
    output = model.generate(input_ids,
                            stop_words_ids=stop_words_ids,
                            **gen_kwargs).tolist()[0]
    output = tokenizer.decode(output, errors='ignore')
    assert output.startswith(prompt)
    output = output[len(prompt):]
    output = trim_stop_words(output, ['<|endoftext|>', im_end])
    print(f'<completion>\n{prompt}\n<!-- *** -->\n{output}\n</completion>')
    return output
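
# For orientation, the prompt assembled above looks roughly like this for the
# google_search example (the final <|im_end|> is trimmed so the model
# continues the last assistant turn):
#
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   Answer the following questions as best you can. ... Question: 搜索一下谁是周杰伦<|im_end|>
#   <|im_start|>assistant
#   Thought: 我应该使用Google搜索查找相关信息。
#   Action: google_search
#   Action Input: {"search_query": "周杰伦"}
#   Observation: Jay Chou is a Taiwanese singer.
#   Thought: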
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
@app.post('/v1/chat/completions', response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer
gen_kwargs = {}
if request.top_k is not None:
gen_kwargs['top_k'] = request.top_k
if request.temperature is not None:
if request.temperature < 0.01:
gen_kwargs['top_k'] = 1 # greedy decoding
@@ -395,32 +401,46 @@ async def create_chat_completion(request: ChatCompletionRequest):
    stop_words = add_extra_stop_words(request.stop)
    if request.functions:
        stop_words = stop_words or []
        if 'Observation:' not in stop_words:
            stop_words.append('Observation:')

    query, history, system = parse_messages(request.messages,
                                            request.functions)

    if request.stream:
        if request.functions:
            raise HTTPException(
                status_code=400,
                detail='Invalid request: Function calling is not yet implemented for stream mode.',
            )
        generate = predict(query,
                           history,
                           request.model,
                           stop_words,
                           gen_kwargs,
                           system=system)
        return EventSourceResponse(generate, media_type='text/event-stream')

    stop_words_ids = [tokenizer.encode(s)
                      for s in stop_words] if stop_words else None
    if query is _TEXT_COMPLETION_CMD:
        response = text_complete_last_message(history,
                                              stop_words_ids=stop_words_ids,
                                              gen_kwargs=gen_kwargs,
                                              system=system)
    else:
        response, _ = model.chat(
            tokenizer,
            query,
            history=history,
            system=system,
            stop_words_ids=stop_words_ids,
            **gen_kwargs,
        )
        print('<chat>')
        pprint(history, indent=2)
        print(f'{query}\n<!-- *** -->\n{response}\n</chat>')
    _gc()

    response = trim_stop_words(response, stop_words)
@@ -429,12 +449,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
    else:
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role='assistant', content=response),
            finish_reason='stop',
        )

    return ChatCompletionResponse(model=request.model,
                                  choices=[choice_data],
                                  object='chat.completion')

def _dump_json(data: BaseModel, *args, **kwargs) -> str:
@@ -445,28 +465,37 @@ def _dump_json(data: BaseModel, *args, **kwargs) -> str:


async def predict(
    query: str,
    history: List[List[str]],
    model_id: str,
    stop_words: List[str],
    gen_kwargs: Dict,
    system: str,
):
    global model, tokenizer
    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(role='assistant'), finish_reason=None)
    chunk = ChatCompletionResponse(model=model_id,
                                   choices=[choice_data],
                                   object='chat.completion.chunk')
    yield '{}'.format(_dump_json(chunk, exclude_unset=True))

    current_length = 0
    stop_words_ids = [tokenizer.encode(s)
                      for s in stop_words] if stop_words else None
    if stop_words:
        # TODO: It's a little bit tricky to trim stop words in the stream mode.
        raise HTTPException(
            status_code=400,
            detail='Invalid request: custom stop words are not yet supported for stream mode.',
        )

    response_generator = model.chat_stream(tokenizer,
                                           query,
                                           history=history,
                                           stop_words_ids=stop_words_ids,
                                           system=system,
                                           **gen_kwargs)
    for new_response in response_generator:
        if len(new_response) == current_length:
            continue
@@ -475,21 +504,20 @@ async def predict(
        new_text = new_response[current_length:]
        current_length = len(new_response)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=DeltaMessage(content=new_text), finish_reason=None)
        chunk = ChatCompletionResponse(model=model_id,
                                       choices=[choice_data],
                                       object='chat.completion.chunk')
        yield '{}'.format(_dump_json(chunk, exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(index=0,
                                                     delta=DeltaMessage(),
                                                     finish_reason='stop')
    chunk = ChatCompletionResponse(model=model_id,
                                   choices=[choice_data],
                                   object='chat.completion.chunk')
    yield '{}'.format(_dump_json(chunk, exclude_unset=True))
    yield '[DONE]'

    _gc()
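
# Client-side sketch (not part of this server) of consuming the stream with
# the openai<1.0 client:
#
#     import openai
#     openai.api_base = 'http://localhost:8000/v1'
#     openai.api_key = 'none'
#     for chunk in openai.ChatCompletion.create(
#             model='Qwen',
#             messages=[{'role': 'user', 'content': '你好'}],
#             stream=True):
#         print(chunk.choices[0].delta.get('content', ''), end='', flush=True)
#     print()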
@@ -497,36 +525,39 @@ async def predict(


def _get_args():
    parser = ArgumentParser()
    parser.add_argument(
        '-c',
        '--checkpoint-path',
        type=str,
        default='Qwen/Qwen-7B-Chat',
        help='Checkpoint name or path, default to %(default)r',
    )
    parser.add_argument('--api-auth', help='API authentication credentials')
    parser.add_argument('--cpu-only',
                        action='store_true',
                        help='Run demo with CPU only')
    parser.add_argument('--server-port',
                        type=int,
                        default=8000,
                        help='Demo server port.')
    parser.add_argument(
        '--server-name',
        type=str,
        default='127.0.0.1',
        help='Demo server name. Default: 127.0.0.1, which is only visible from the local computer.'
        ' If you want other computers to access your server, use 0.0.0.0 instead.',
    )
    parser.add_argument(
        '--disable-gc',
        action='store_true',
        help='Disable GC after each response generated.',
    )

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = _get_args()

    tokenizer = AutoTokenizer.from_pretrained(
@@ -536,14 +567,14 @@ if __name__ == '__main__':
    )

    if args.api_auth:
        app.add_middleware(BasicAuthMiddleware,
                           username=args.api_auth.split(':')[0],
                           password=args.api_auth.split(':')[1])

    if args.cpu_only:
        device_map = 'cpu'
    else:
        device_map = 'auto'

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
