update openai_api.py

Branch: main
Author: yangapku (1 year ago)
Parent: 29fea23f87
Commit: ab109ced9f

@@ -1,4 +1,6 @@
# Reference: https://openai.com/blog/function-calling-and-other-api-updates
import json
from pprint import pprint
import openai
@@ -9,216 +11,223 @@ import openai
# python openai_api.py;
#
# Then configure the api_base and api_key in your client:
openai.api_base = 'http://localhost:8000/v1'
openai.api_key = 'none'
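
# Note: this client targets the pre-1.0 openai SDK (pip install "openai<1.0");
# the openai.ChatCompletion interface used below was removed in openai>=1.0.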


def call_qwen(messages, functions=None):
    print('input:')
    pprint(messages, indent=2)
    if functions:
        response = openai.ChatCompletion.create(model='Qwen',
                                                messages=messages,
                                                functions=functions)
    else:
        response = openai.ChatCompletion.create(model='Qwen',
                                                messages=messages)
    response = response.choices[0]['message']
    response = json.loads(json.dumps(response,
                                     ensure_ascii=False))  # fix zh rendering
    print('output:')
    pprint(response, indent=2)
    print()
    return response


def test_1():
    messages = [{'role': 'user', 'content': '你好'}]
    call_qwen(messages)
    messages.append({'role': 'assistant', 'content': '你好!很高兴为你提供帮助。'})

    messages.append({
        'role': 'user',
        'content': '给我讲一个年轻人奋斗创业最终取得成功的故事。故事只能有一句话。'
    })
    call_qwen(messages)
    messages.append({
        'role': 'assistant',
        'content': '故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。李明想要成为一名成功的企业家。……',
    })

    messages.append({'role': 'user', 'content': '给这个故事起一个标题'})
    call_qwen(messages)


def test_2():
    functions = [
        {
            'name_for_human': '谷歌搜索',
            'name_for_model': 'google_search',
            'description_for_model':
                '谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。'
                ' Format the arguments as a JSON object.',
            'parameters': [{
                'name': 'search_query',
                'description': '搜索关键词或短语',
                'required': True,
                'schema': {'type': 'string'},
            }],
        },
        {
            'name_for_human': '文生图',
            'name_for_model': 'image_gen',
            'description_for_model':
                '文生图是一个AI绘画图像生成服务输入文本描述返回根据文本作画得到的图片的URL。'
                ' Format the arguments as a JSON object.',
            'parameters': [{
                'name': 'prompt',
                'description': '英文关键词,描述了希望图像具有什么内容',
                'required': True,
                'schema': {'type': 'string'},
            }],
        },
    ]

    messages = [{'role': 'user', 'content': '(请不要调用工具)\n\n你好'}]
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': '你好!很高兴见到你。有什么我可以帮忙的吗?'
    })

    messages.append({'role': 'user', 'content': '搜索一下谁是周杰伦'})
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': '我应该使用Google搜索查找相关信息。',
        'function_call': {
            'name': 'google_search',
            'arguments': '{"search_query": "周杰伦"}',
        },
    })

    messages.append({
        'role': 'function',
        'name': 'google_search',
        'content': 'Jay Chou is a Taiwanese singer.',
    })
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': '周杰伦Jay Chou是一位来自台湾的歌手。',
    })

    messages.append({'role': 'user', 'content': '搜索一下他老婆是谁'})
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': '我应该使用Google搜索查找相关信息。',
        'function_call': {
            'name': 'google_search',
            'arguments': '{"search_query": "周杰伦 老婆"}',
        },
    })

    messages.append({
        'role': 'function',
        'name': 'google_search',
        'content': 'Hannah Quinlivan'
    })
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': '周杰伦的老婆是Hannah Quinlivan。',
    })

    messages.append({'role': 'user', 'content': '用文生图工具画个可爱的小猫吧,最好是黑猫'})
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': '我应该使用文生图API来生成一张可爱的小猫图片。',
        'function_call': {
            'name': 'image_gen',
            'arguments': '{"prompt": "cute black cat"}',
        },
    })

    messages.append({
        'role': 'function',
        'name': 'image_gen',
        'content': '{"image_url": "https://image.pollinations.ai/prompt/cute%20black%20cat"}',
    })
    call_qwen(messages, functions)
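
# The shape of a Qwen-style tool-call reply mirrors the assistant turns appended
# above (illustrative; the exact wording varies between runs):
#   {'role': 'assistant',
#    'content': '我应该使用Google搜索查找相关信息。',
#    'function_call': {'name': 'google_search',
#                      'arguments': '{"search_query": "周杰伦"}'}}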


def test_3():
    functions = [{
        'name': 'get_current_weather',
        'description': 'Get the current weather in a given location.',
        'parameters': {
            'type': 'object',
            'properties': {
                'location': {
                    'type': 'string',
                    'description': 'The city and state, e.g. San Francisco, CA',
                },
                'unit': {
                    'type': 'string',
                    'enum': ['celsius', 'fahrenheit']
                },
            },
            'required': ['location'],
        },
    }]

    messages = [{
        'role': 'user',
        # Note: The current version of Qwen-7B-Chat (as of 2023.08) performs okay with Chinese tool-use prompts,
        # but performs terribly when it comes to English tool-use prompts, due to a mistake in data collecting.
        'content': '波士顿天气如何?',
    }]
    call_qwen(messages, functions)
    messages.append({
        'role': 'assistant',
        'content': None,
        'function_call': {
            'name': 'get_current_weather',
            'arguments': '{"location": "Boston, MA"}',
        },
    })

    messages.append({
        'role': 'function',
        'name': 'get_current_weather',
        'content': '{"temperature": "22", "unit": "celsius", "description": "Sunny"}',
    })
    call_qwen(messages, functions)
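
# GPT-style schemas work against the same server because parse_messages in
# openai_api.py falls back from 'name_for_model' / 'description_for_model'
# to plain 'name' / 'description' when rendering the ReAct prompt.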


def test_4():
    from langchain.agents import AgentType, initialize_agent, load_tools
    from langchain.chat_models import ChatOpenAI

    llm = ChatOpenAI(
        model_name='Qwen',
        openai_api_base='http://localhost:8000/v1',
        openai_api_key='EMPTY',
        streaming=False,
    )
    tools = load_tools(['arxiv'])
    agent_chain = initialize_agent(
        tools,
        llm,
@@ -226,15 +235,15 @@ def test_4():
        verbose=True,
    )
    # TODO: The performance is okay with Chinese prompts, but not so good when it comes to English.
    agent_chain.run('查一下论文 1605.08386 的信息')


if __name__ == '__main__':
    print('### Test Case 1 - No Function Calling (普通问答、无函数调用) ###')
    test_1()
    print('### Test Case 2 - Use Qwen-Style Functions (函数调用,千问格式) ###')
    test_2()
    print('### Test Case 3 - Use GPT-Style Functions (函数调用GPT格式) ###')
    test_3()
    print('### Test Case 4 - Use LangChain (接入Langchain) ###')
    test_4()

@@ -1,14 +1,16 @@
# Requirement:
#   pip install "openai<1.0"
# Usage:
#   python openai_api.py
# Visit http://localhost:8000/docs for documents.
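#
# Optional flags (see _get_args below), e.g.:
#   python openai_api.py --server-name 0.0.0.0 --server-port 8000 --api-auth user:pass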

import base64
import copy
import json
import time
from argparse import ArgumentParser
from contextlib import asynccontextmanager
from pprint import pprint
from typing import Dict, List, Literal, Optional, Union

import torch
@@ -17,20 +19,22 @@ from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


class BasicAuthMiddleware(BaseHTTPMiddleware):

    def __init__(self, app, username: str, password: str):
        super().__init__(app)
        self.required_credentials = base64.b64encode(
            f'{username}:{password}'.encode()).decode()

    async def dispatch(self, request: Request, call_next):
        authorization: str = request.headers.get('Authorization')
        if authorization:
            try:
                schema, credentials = authorization.split()
@@ -42,12 +46,14 @@ class BasicAuthMiddleware(BaseHTTPMiddleware):
        headers = {'WWW-Authenticate': 'Basic'}
        return Response(status_code=401, headers=headers)
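
# A minimal sketch of a client request that satisfies this middleware, assuming
# the server was started with `--api-auth user:pass` (hypothetical credentials):
#
#   import base64
#   import requests
#   token = base64.b64encode(b'user:pass').decode()
#   requests.get('http://localhost:8000/v1/models',
#                headers={'Authorization': f'Basic {token}'})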


def _gc(forced: bool = False):
    global args
    if args.disable_gc and not forced:
        return

    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
@@ -63,36 +69,36 @@ app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)


class ModelCard(BaseModel):
    id: str
    object: str = 'model'
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = 'owner'
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = 'list'
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal['user', 'assistant', 'system', 'function']
    content: Optional[str]
    function_call: Optional[Dict] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal['user', 'assistant', 'system']] = None
    content: Optional[str] = None
@@ -102,6 +108,7 @@ class ChatCompletionRequest(BaseModel):
    functions: Optional[List[Dict]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None
@@ -109,29 +116,28 @@ class ChatCompletionRequest(BaseModel):
class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage  # Union[ChatMessage] collapses to ChatMessage anyway
    finish_reason: Literal['stop', 'length', 'function_call']


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal['stop', 'length']]


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal['chat.completion', 'chat.completion.chunk']
    choices: List[Union[ChatCompletionResponseChoice,
                        ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


@app.get('/v1/models', response_model=ModelList)
async def list_models():
    global model_args
    model_card = ModelCard(id='gpt-3.5-turbo')
    return ModelList(data=[model_card])
@@ -141,7 +147,7 @@ def add_extra_stop_words(stop_words):
        _stop_words = []
        _stop_words.extend(stop_words)
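        # besides each stop word itself, register the variant with leading
        # newlines stripped, since the model may emit it without them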
        for x in stop_words:
            s = x.lstrip('\n')
            if s and (s not in _stop_words):
                _stop_words.append(s)
        return _stop_words
@@ -157,7 +163,10 @@ def trim_stop_words(response, stop_words):
    return response


TOOL_DESC = (
    '{name_for_model}: Call this tool to interact with the {name_for_human} API.'
    ' What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}'
)
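
# Rendered for a tool like the google_search example in the client script,
# TOOL_DESC yields one line per tool, roughly (illustrative):
#   'google_search: Call this tool to interact with the 谷歌搜索 API. What is the
#    谷歌搜索 API useful for? ... Parameters: [{"name": "search_query", ...}]'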

REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:
@@ -179,37 +188,28 @@ Begin!"""
_TEXT_COMPLETION_CMD = object()


def parse_messages(messages, functions):
    if all(m.role != 'user' for m in messages):
        raise HTTPException(
            status_code=400,
            detail='Invalid request: Expecting at least one user message.',
        )

    messages = copy.deepcopy(messages)
    if messages[0].role == 'system':
        system = messages.pop(0).content.lstrip('\n').rstrip()
    else:
        system = 'You are a helpful assistant.'

    if functions:
        tools_text = []
        tools_name_text = []
        for func_info in functions:
            name = func_info.get('name', '')
            name_m = func_info.get('name_for_model', name)
            name_h = func_info.get('name_for_human', name)
            desc = func_info.get('description', '')
            desc_m = func_info.get('description_for_model', desc)
            tool = TOOL_DESC.format(
                name_for_model=name_m,
                name_for_human=name_h,
@@ -217,149 +217,151 @@ def parse_messages(messages, functions):
# "Format the arguments as a JSON object." # "Format the arguments as a JSON object."
# "Enclose the code within triple backticks (`) at the beginning and end of the code." # "Enclose the code within triple backticks (`) at the beginning and end of the code."
description_for_model=desc_m, description_for_model=desc_m,
parameters=json.dumps(func_info["parameters"], ensure_ascii=False), parameters=json.dumps(func_info['parameters'],
ensure_ascii=False),
) )
tools_text.append(tool) tools_text.append(tool)
tools_name_text.append(name_m) tools_name_text.append(name_m)
tools_text = "\n\n".join(tools_text) tools_text = '\n\n'.join(tools_text)
tools_name_text = ", ".join(tools_name_text) tools_name_text = ', '.join(tools_name_text)
system += "\n\n" + REACT_INSTRUCTION.format( instruction = (REACT_INSTRUCTION.format(
tools_text=tools_text, tools_text=tools_text,
tools_name_text=tools_name_text, tools_name_text=tools_name_text,
) ).lstrip('\n').rstrip())
system = system.lstrip("\n").rstrip() else:
instruction = ''
dummy_thought = {
"en": "\nThought: I now know the final answer.\nFinal answer: ",
"zh": "\nThought: 我会作答了。\nFinal answer: ",
}
_messages = messages messages_with_fncall = messages
messages = [] messages = []
for m_idx, m in enumerate(_messages): for m_idx, m in enumerate(messages_with_fncall):
role, content, func_call = m.role, m.content, m.function_call role, content, func_call = m.role, m.content, m.function_call
if content: content = content or ''
content = content.lstrip("\n").rstrip() content = content.lstrip('\n').rstrip()
if role == "function": if role == 'function':
if (len(messages) == 0) or (messages[-1].role != "assistant"): if (len(messages) == 0) or (messages[-1].role != 'assistant'):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Invalid request: Expecting role assistant before role function.", detail=
) 'Invalid request: Expecting role assistant before role function.',
messages[-1].content += f"\nObservation: {content}" )
if m_idx == len(_messages) - 1: messages[-1].content += f'\nObservation: {content}'
messages[-1].content += "\nThought:" if m_idx == len(messages_with_fncall) - 1:
elif role == "assistant": # add a prefix for text completion
messages[-1].content += '\nThought:'
elif role == 'assistant':
if len(messages) == 0: if len(messages) == 0:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Invalid request: Expecting role user before role assistant.", detail=
'Invalid request: Expecting role user before role assistant.',
) )
last_msg = messages[-1].content
last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0
if func_call is None: if func_call is None:
if functions: if functions:
content = dummy_thought["zh" if last_msg_has_zh else "en"] + content content = f'Thought: I now know the final answer.\nFinal Answer: {content}'
else:
f_name, f_args = func_call["name"], func_call["arguments"]
if not content:
if last_msg_has_zh:
content = f"Thought: 我可以使用 {f_name} API。"
else: else:
content = f"Thought: I can use {f_name}." f_name, f_args = func_call['name'], func_call['arguments']
content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}" if not content.startswith('Thought:'):
if messages[-1].role == "user": content = f'Thought: {content}'
content = f'{content}\nAction: {f_name}\nAction Input: {f_args}'
if messages[-1].role == 'user':
messages.append( messages.append(
ChatMessage(role="assistant", content=content.lstrip("\n").rstrip()) ChatMessage(role='assistant',
) content=content.lstrip('\n').rstrip()))
else: else:
messages[-1].content += content messages[-1].content += '\n' + content
elif role == "user": elif role == 'user':
messages.append( messages.append(
ChatMessage(role="user", content=content.lstrip("\n").rstrip()) ChatMessage(role='user',
) content=content.lstrip('\n').rstrip()))
else: else:
raise HTTPException( raise HTTPException(
status_code=400, detail=f"Invalid request: Incorrect role {role}." status_code=400,
) detail=f'Invalid request: Incorrect role {role}.')
query = _TEXT_COMPLETION_CMD query = _TEXT_COMPLETION_CMD
if messages[-1].role == "user": if messages[-1].role == 'user':
query = messages[-1].content query = messages[-1].content
messages = messages[:-1] messages = messages[:-1]
if len(messages) % 2 != 0: if len(messages) % 2 != 0:
raise HTTPException(status_code=400, detail="Invalid request") raise HTTPException(status_code=400, detail='Invalid request')
history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)] history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
for i in range(0, len(messages), 2): for i in range(0, len(messages), 2):
if messages[i].role == "user" and messages[i + 1].role == "assistant": if messages[i].role == 'user' and messages[i + 1].role == 'assistant':
usr_msg = messages[i].content.lstrip("\n").rstrip() usr_msg = messages[i].content.lstrip('\n').rstrip()
bot_msg = messages[i + 1].content.lstrip("\n").rstrip() bot_msg = messages[i + 1].content.lstrip('\n').rstrip()
if system and (i == len(messages) - 2): if instruction and (i == len(messages) - 2):
usr_msg = f"{system}\n\nQuestion: {usr_msg}" usr_msg = f'{instruction}\n\nQuestion: {usr_msg}'
system = "" instruction = ''
for t in dummy_thought.values():
t = t.lstrip("\n")
if bot_msg.startswith(t) and ("\nAction: " in bot_msg):
bot_msg = bot_msg[len(t) :]
history.append([usr_msg, bot_msg]) history.append([usr_msg, bot_msg])
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.", detail=
'Invalid request: Expecting exactly one user (or function) role before every assistant role.',
) )
if system: if instruction:
assert query is not _TEXT_COMPLETION_CMD assert query is not _TEXT_COMPLETION_CMD
query = f"{system}\n\nQuestion: {query}" query = f'{instruction}\n\nQuestion: {query}'
return query, history return query, history, system
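
# Example of the folding above (illustrative): the exchange
#   user:      '搜索一下谁是周杰伦'
#   assistant: function_call {'name': 'google_search', 'arguments': '{"search_query": "周杰伦"}'}
#   function:  'Jay Chou is a Taiwanese singer.'
# becomes a single assistant turn of the form
#   'Thought: ...\nAction: google_search\nAction Input: {...}\nObservation: Jay Chou is a Taiwanese singer.\nThought:'
# with the trailing '\nThought:' appended only when the function result is the
# last message, so completion mode can pick up right after the Observation.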


def parse_response(response):
    func_name, func_args = '', ''
    i = response.find('\nAction:')
    j = response.find('\nAction Input:')
    k = response.find('\nObservation:')
    if 0 <= i < j:  # If the text has `Action` and `Action Input`,
        if k < j:  # but does not contain `Observation`,
            # then it is likely that `Observation` is omitted by the LLM,
            # because the output text may have discarded the stop word.
            response = response.rstrip() + '\nObservation:'  # Add it back.
            k = response.find('\nObservation:')
        func_name = response[i + len('\nAction:'):j].strip()
        func_args = response[j + len('\nAction Input:'):k].strip()

    if func_name:
        response = response[:i]
        t = response.find('Thought: ')
        if t >= 0:
            response = response[t + len('Thought: '):]
        response = response.strip()
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(
                role='assistant',
                content=response,
                function_call={
                    'name': func_name,
                    'arguments': func_args
                },
            ),
            finish_reason='function_call',
        )
        return choice_data

    z = response.rfind('\nFinal Answer: ')
    if z >= 0:
        response = response[z + len('\nFinal Answer: '):]
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=response),
        finish_reason='stop',
    )
    return choice_data
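
# Worked example (illustrative): for a raw continuation such as
#   'Thought: 我应该使用Google搜索。\nAction: google_search\nAction Input: {"search_query": "周杰伦"}\nObservation:'
# parse_response() strips the 'Thought: ' prefix, keeps the remaining thought as
# the message content, and returns finish_reason='function_call' with
# function_call={'name': 'google_search', 'arguments': '{"search_query": "周杰伦"}'}.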


# completion mode, not chat mode
def text_complete_last_message(history, stop_words_ids, gen_kwargs, system):
    im_start = '<|im_start|>'
    im_end = '<|im_end|>'
    prompt = f'{im_start}system\n{system}{im_end}'
    for i, (query, response) in enumerate(history):
        query = query.lstrip('\n').rstrip()
        response = response.lstrip('\n').rstrip()
        prompt += f'\n{im_start}user\n{query}{im_end}'
        prompt += f'\n{im_start}assistant\n{response}{im_end}'
    prompt = prompt[:-len(im_end)]
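    # the trailing <|im_end|> is stripped so the prompt ends inside the last
    # assistant turn, letting the model continue that message instead of
    # starting a new one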

    _stop_words_ids = [tokenizer.encode(im_end)]
@@ -369,20 +371,24 @@ def text_complete_last_message(history, stop_words_ids, gen_kwargs):
        stop_words_ids = _stop_words_ids

    input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
    output = model.generate(input_ids,
                            stop_words_ids=stop_words_ids,
                            **gen_kwargs).tolist()[0]
    output = tokenizer.decode(output, errors='ignore')
    assert output.startswith(prompt)
    output = output[len(prompt):]
    output = trim_stop_words(output, ['<|endoftext|>', im_end])
    print(f'<completion>\n{prompt}\n<!-- *** -->\n{output}\n</completion>')
    return output


@app.post('/v1/chat/completions', response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    gen_kwargs = {}
    if request.top_k is not None:
        gen_kwargs['top_k'] = request.top_k
    if request.temperature is not None:
        if request.temperature < 0.01:
            gen_kwargs['top_k'] = 1  # greedy decoding
@@ -395,32 +401,46 @@ async def create_chat_completion(request: ChatCompletionRequest):
    stop_words = add_extra_stop_words(request.stop)
    if request.functions:
        stop_words = stop_words or []
        if 'Observation:' not in stop_words:
            stop_words.append('Observation:')

    query, history, system = parse_messages(request.messages,
                                            request.functions)

    if request.stream:
        if request.functions:
            raise HTTPException(
                status_code=400,
                detail='Invalid request: Function calling is not yet implemented for stream mode.',
            )
        generate = predict(query,
                           history,
                           request.model,
                           stop_words,
                           gen_kwargs,
                           system=system)
        return EventSourceResponse(generate, media_type='text/event-stream')

    stop_words_ids = [tokenizer.encode(s)
                      for s in stop_words] if stop_words else None
    if query is _TEXT_COMPLETION_CMD:
        response = text_complete_last_message(history,
                                              stop_words_ids=stop_words_ids,
                                              gen_kwargs=gen_kwargs,
                                              system=system)
    else:
        response, _ = model.chat(
            tokenizer,
            query,
            history=history,
            system=system,
            stop_words_ids=stop_words_ids,
            **gen_kwargs,
        )
        print('<chat>')
        pprint(history, indent=2)
        print(f'{query}\n<!-- *** -->\n{response}\n</chat>')
    _gc()

    response = trim_stop_words(response, stop_words)
@@ -429,12 +449,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
    else:
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role='assistant', content=response),
            finish_reason='stop',
        )
    return ChatCompletionResponse(model=request.model,
                                  choices=[choice_data],
                                  object='chat.completion')
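
# A quick smoke test against this endpoint, assuming the default host and port:
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H 'Content-Type: application/json' \
#     -d '{"model": "Qwen", "messages": [{"role": "user", "content": "你好"}]}'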


def _dump_json(data: BaseModel, *args, **kwargs) -> str:
@@ -445,28 +465,37 @@ def _dump_json(data: BaseModel, *args, **kwargs) -> str:
async def predict(
    query: str,
    history: List[List[str]],
    model_id: str,
    stop_words: List[str],
    gen_kwargs: Dict,
    system: str,
):
    global model, tokenizer

    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(role='assistant'), finish_reason=None)
    chunk = ChatCompletionResponse(model=model_id,
                                   choices=[choice_data],
                                   object='chat.completion.chunk')
    yield '{}'.format(_dump_json(chunk, exclude_unset=True))

    current_length = 0
    stop_words_ids = [tokenizer.encode(s)
                      for s in stop_words] if stop_words else None
    if stop_words:
        # TODO: It's a little bit tricky to trim stop words in the stream mode.
        raise HTTPException(
            status_code=400,
            detail='Invalid request: custom stop words are not yet supported for stream mode.',
        )

    response_generator = model.chat_stream(tokenizer,
                                           query,
                                           history=history,
                                           stop_words_ids=stop_words_ids,
                                           system=system,
                                           **gen_kwargs)
    for new_response in response_generator:
        if len(new_response) == current_length:
            continue
@@ -475,21 +504,20 @@ async def predict(
        current_length = len(new_response)
        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=DeltaMessage(content=new_text), finish_reason=None)
        chunk = ChatCompletionResponse(model=model_id,
                                       choices=[choice_data],
                                       object='chat.completion.chunk')
        yield '{}'.format(_dump_json(chunk, exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(index=0,
                                                     delta=DeltaMessage(),
                                                     finish_reason='stop')
    chunk = ChatCompletionResponse(model=model_id,
                                   choices=[choice_data],
                                   object='chat.completion.chunk')
    yield '{}'.format(_dump_json(chunk, exclude_unset=True))
    yield '[DONE]'

    _gc()
@@ -497,36 +525,39 @@ async def predict(


def _get_args():
    parser = ArgumentParser()
    parser.add_argument(
        '-c',
        '--checkpoint-path',
        type=str,
        default='Qwen/Qwen-7B-Chat',
        help='Checkpoint name or path, default to %(default)r',
    )
    parser.add_argument('--api-auth', help='API authentication credentials')
    parser.add_argument('--cpu-only',
                        action='store_true',
                        help='Run demo with CPU only')
    parser.add_argument('--server-port',
                        type=int,
                        default=8000,
                        help='Demo server port.')
    parser.add_argument(
        '--server-name',
        type=str,
        default='127.0.0.1',
        help='Demo server name. Default: 127.0.0.1, which is only visible from the local computer.'
        ' If you want other computers to access your server, use 0.0.0.0 instead.',
    )
    parser.add_argument(
        '--disable-gc',
        action='store_true',
        help='Disable GC after each response generated.',
    )

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = _get_args()

    tokenizer = AutoTokenizer.from_pretrained(
@@ -536,14 +567,14 @@ if __name__ == "__main__":
    )

    if args.api_auth:
        app.add_middleware(BasicAuthMiddleware,
                           username=args.api_auth.split(':')[0],
                           password=args.api_auth.split(':')[1])

    if args.cpu_only:
        device_map = 'cpu'
    else:
        device_map = 'auto'

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,