@@ -3,18 +3,22 @@
 # Usage: python openai_api.py
 # Visit http://localhost:8000/docs for documents.
 
-from argparse import ArgumentParser
+import re
+import copy
 import json
 import time
+from argparse import ArgumentParser
+from contextlib import asynccontextmanager
+from typing import Dict, List, Literal, Optional, Union
 
 import torch
 import uvicorn
-from pydantic import BaseModel, Field
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
-from contextlib import asynccontextmanager
-from typing import Any, Dict, List, Literal, Optional, Union
-from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
+from pydantic import BaseModel, Field
+from sse_starlette.sse import EventSourceResponse
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from transformers.generation import GenerationConfig
-from sse_starlette.sse import ServerSentEvent, EventSourceResponse
 
 
 @asynccontextmanager
@@ -52,8 +56,9 @@ class ModelList(BaseModel):
 
 class ChatMessage(BaseModel):
-    role: Literal["user", "assistant", "system"]
-    content: str
+    role: Literal["user", "assistant", "system", "function"]
+    content: Optional[str]
+    function_call: Optional[Dict] = None
 
 
 class DeltaMessage(BaseModel):
@@ -64,17 +69,18 @@ class DeltaMessage(BaseModel):
 
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
+    functions: Optional[List[Dict]] = None
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     max_length: Optional[int] = None
     stream: Optional[bool] = False
-    stop: Optional[List[str]] = []
+    stop: Optional[List[str]] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    finish_reason: Literal["stop", "length"]
+    finish_reason: Literal["stop", "length", "function_call"]
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -86,7 +92,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     model: str
     object: Literal["chat.completion", "chat.completion.chunk"]
-    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+    choices: List[
+        Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
+    ]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@@ -97,70 +105,319 @@ async def list_models():
     return ModelList(data=[model_card])
+
+
+# To work around that unpleasant leading-\n tokenization issue!
+def add_extra_stop_words(stop_words):
+    if stop_words:
+        _stop_words = []
+        _stop_words.extend(stop_words)
+        for x in stop_words:
+            s = x.lstrip("\n")
+            if s and (s not in _stop_words):
+                _stop_words.append(s)
+        return _stop_words
+    return stop_words
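+# For example, add_extra_stop_words(["\nObservation:"]) returns
+# ["\nObservation:", "Observation:"], so generation also stops when the model
+# emits the stop word without its leading newline.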
+
+
+def trim_stop_words(response, stop_words):
+    if stop_words:
+        for stop in stop_words:
+            idx = response.find(stop)
+            if idx != -1:
+                response = response[:idx]
+    return response
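+# For example, trim_stop_words("It is 42.\nObservation:", ["Observation:"])
+# keeps only the text before the first stop word: "It is 42.\n".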
+
+
+TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
+
+REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:
+
+{tools_text}
+
+Use the following format:
+
+Question: the input question you must answer
+Thought: you should always think about what to do
+Action: the action to take, should be one of [{tools_name_text}]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin!"""
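+# For a hypothetical function named "get_weather", TOOL_DESC renders roughly as:
+#   get_weather: Call this tool to interact with the get_weather API. What is
+#   the get_weather API useful for? <description> Parameters: <JSON schema>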
+
+
+_TEXT_COMPLETION_CMD = object()
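+# parse_messages() returns this sentinel instead of a query string when the
+# last message comes from the assistant, i.e. the client wants the model to
+# continue an unfinished reply rather than answer a new turn.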
+
+
+#
+# Temporarily, the system role does not work as expected.
+# We advise that you write the setups for role-play in your query,
+# i.e., use the user role instead of the system role.
+#
+# TODO: Use real system role when the model is ready.
+#
+def parse_messages(messages, functions):
+    if all(m.role != "user" for m in messages):
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid request: Expecting at least one user message.",
+        )
+
+    messages = copy.deepcopy(messages)
+    default_system = "You are a helpful assistant."
+    system = ""
+    if messages[0].role == "system":
+        system = messages.pop(0).content.lstrip("\n").rstrip()
+        if system == default_system:
+            system = ""
+
+    if functions:
+        tools_text = []
+        tools_name_text = []
+        for func_info in functions:
+            name = func_info.get("name", "")
+            name_m = func_info.get("name_for_model", name)
+            name_h = func_info.get("name_for_human", name)
+            desc = func_info.get("description", "")
+            desc_m = func_info.get("description_for_model", desc)
+            tool = TOOL_DESC.format(
+                name_for_model=name_m,
+                name_for_human=name_h,
+                # Hint: You can add the following format requirements in description:
+                #   "Format the arguments as a JSON object."
+                #   "Enclose the code within triple backticks (`) at the beginning and end of the code."
+                description_for_model=desc_m,
+                parameters=json.dumps(func_info["parameters"], ensure_ascii=False),
+            )
+            tools_text.append(tool)
+            tools_name_text.append(name_m)
+        tools_text = "\n\n".join(tools_text)
+        tools_name_text = ", ".join(tools_name_text)
+        system += "\n\n" + REACT_INSTRUCTION.format(
+            tools_text=tools_text,
+            tools_name_text=tools_name_text,
+        )
+        system = system.lstrip("\n").rstrip()
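+
+    # When functions are enabled but an assistant turn carries no
+    # function_call, one of these canned "Thought:" prefixes is prepended so
+    # the ReAct transcript keeps its Thought/Action structure; "zh" vs "en" is
+    # chosen by whether the preceding user message contains Chinese characters.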
+    dummy_thought = {
+        "en": "\nThought: I now know the final answer.\nFinal answer: ",
+        "zh": "\nThought: 我会作答了。\nFinal answer: ",
+    }
+
+    _messages = messages
+    messages = []
+    for m_idx, m in enumerate(_messages):
+        role, content, func_call = m.role, m.content, m.function_call
+        if content:
+            content = content.lstrip("\n").rstrip()
+        if role == "function":
+            if (len(messages) == 0) or (messages[-1].role != "assistant"):
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Invalid request: Expecting role assistant before role function.",
+                )
+            messages[-1].content += f"\nObservation: {content}"
+            if m_idx == len(_messages) - 1:
+                messages[-1].content += "\nThought:"
+        elif role == "assistant":
+            if len(messages) == 0:
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Invalid request: Expecting role user before role assistant.",
+                )
+            last_msg = messages[-1].content
+            last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0
+            if func_call is None:
+                if functions:
+                    content = dummy_thought["zh" if last_msg_has_zh else "en"] + content
+            else:
+                f_name, f_args = func_call["name"], func_call["arguments"]
+                if not content:
+                    if last_msg_has_zh:
+                        content = f"Thought: 我可以使用 {f_name} API。"
+                    else:
+                        content = f"Thought: I can use {f_name}."
+                content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}"
+            if messages[-1].role == "user":
+                messages.append(
+                    ChatMessage(role="assistant", content=content.lstrip("\n").rstrip())
+                )
+            else:
+                messages[-1].content += content
+        elif role == "user":
+            messages.append(
+                ChatMessage(role="user", content=content.lstrip("\n").rstrip())
+            )
+        else:
+            raise HTTPException(
+                status_code=400, detail=f"Invalid request: Incorrect role {role}."
+            )
+
+    query = _TEXT_COMPLETION_CMD
+    if messages[-1].role == "user":
+        query = messages[-1].content
+        messages = messages[:-1]
+
+    if len(messages) % 2 != 0:
+        raise HTTPException(status_code=400, detail="Invalid request")
+
+    history = []  # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
+    for i in range(0, len(messages), 2):
+        if messages[i].role == "user" and messages[i + 1].role == "assistant":
+            usr_msg = messages[i].content.lstrip("\n").rstrip()
+            bot_msg = messages[i + 1].content.lstrip("\n").rstrip()
+            if system and (i == len(messages) - 2):
+                usr_msg = f"{system}\n\nQuestion: {usr_msg}"
+                system = ""
+            for t in dummy_thought.values():
+                t = t.lstrip("\n")
+                if bot_msg.startswith(t) and ("\nAction: " in bot_msg):
+                    bot_msg = bot_msg[len(t):]
+            history.append([usr_msg, bot_msg])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.",
+            )
+    if system:
+        assert query is not _TEXT_COMPLETION_CMD
+        query = f"{system}\n\nQuestion: {query}"
+    return query, history
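+# Example: for messages [user "What is 2+2?", assistant "4", user "And 3+3?"]
+# and no functions, parse_messages returns query="And 3+3?" and
+# history=[["What is 2+2?", "4"]].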
+
+
+def parse_response(response):
+    func_name, func_args = "", ""
+    i = response.rfind("\nAction:")
+    j = response.rfind("\nAction Input:")
+    k = response.rfind("\nObservation:")
+    if 0 <= i < j:  # If the text has `Action` and `Action input`,
+        if k < j:  # but does not contain `Observation`,
+            # then it is likely that `Observation` is omitted by the LLM,
+            # because the output text may have discarded the stop word.
+            response = response.rstrip() + "\nObservation:"  # Add it back.
+            k = response.rfind("\nObservation:")
+        func_name = response[i + len("\nAction:") : j].strip()
+        func_args = response[j + len("\nAction Input:") : k].strip()
+    if func_name:
+        choice_data = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(
+                role="assistant",
+                content=response[:i],
+                function_call={"name": func_name, "arguments": func_args},
+            ),
+            finish_reason="function_call",
+        )
+        return choice_data
+    z = response.rfind("\nFinal Answer: ")
+    if z >= 0:
+        response = response[z + len("\nFinal Answer: ") :]
+    choice_data = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(role="assistant", content=response),
+        finish_reason="stop",
+    )
+    return choice_data
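+# Example: a model output of
+#   "Thought: I can use get_weather.\nAction: get_weather\nAction Input: {...}"
+# (get_weather being a hypothetical tool) parses into
+# function_call={"name": "get_weather", "arguments": "{...}"} with
+# finish_reason="function_call"; the text before "\nAction:" stays as content.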
+
+
+# completion mode, not chat mode
+def text_complete_last_message(history, stop_words_ids):
+    im_start = "<|im_start|>"
+    im_end = "<|im_end|>"
+    prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
+    for i, (query, response) in enumerate(history):
+        query = query.lstrip("\n").rstrip()
+        response = response.lstrip("\n").rstrip()
+        prompt += f"\n{im_start}user\n{query}{im_end}"
+        prompt += f"\n{im_start}assistant\n{response}{im_end}"
+    prompt = prompt[: -len(im_end)]
+
+    _stop_words_ids = [tokenizer.encode(im_end)]
+    if stop_words_ids:
+        for s in stop_words_ids:
+            _stop_words_ids.append(s)
+    stop_words_ids = _stop_words_ids
+
+    input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
+    output = model.generate(input_ids, stop_words_ids=stop_words_ids).tolist()[0]
+    output = tokenizer.decode(output, errors="ignore")
+    assert output.startswith(prompt)
+    output = output[len(prompt) :]
+    output = trim_stop_words(output, ["<|endoftext|>", im_end])
+    print(f"<completion>\n{prompt}\n<!-- *** -->\n{output}\n</completion>")
+    return output
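+# This path serves requests whose final message is an assistant message: the
+# ChatML prompt is rebuilt with the last assistant turn left open (the trailing
+# <|im_end|> is stripped), so the model simply continues that turn.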
+
+
 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(request: ChatCompletionRequest):
     global model, tokenizer
-    if request.messages[-1].role != "user":
-        raise HTTPException(status_code=400, detail="Invalid request")
-    query = request.messages[-1].content
-
-    stop_words = request.stop
-    stop_words.extend(list(map(lambda x: x[1:], filter(lambda x: x.startswith("\n"), stop_words))))
-    prev_messages = request.messages[:-1]
-    # Temporarily, the system role does not work as expected. We advise that you write the setups for role-play in your query.
-    # if len(prev_messages) > 0 and prev_messages[0].role == "system":
-    #     query = prev_messages.pop(0).content + query
-
-    history = []
-    if len(prev_messages) % 2 == 0:
-        for i in range(0, len(prev_messages), 2):
-            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
-                history.append([prev_messages[i].content, prev_messages[i + 1].content])
-            else:
-                raise HTTPException(status_code=400, detail="Invalid request.")
-    else:
-        raise HTTPException(status_code=400, detail="Invalid request.")
+    stop_words = add_extra_stop_words(request.stop)
+    if request.functions:
+        stop_words = stop_words or []
+        if "Observation:" not in stop_words:
+            stop_words.append("Observation:")
+
+    query, history = parse_messages(request.messages, request.functions)
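+    # When functions are supplied, query/history now carry the ReAct-formatted
+    # prompt and "Observation:" is among the stop words, so generation halts
+    # right after the model emits an Action/Action Input pair; the client is
+    # expected to run the function and reply with a role="function" message.
+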
     if request.stream:
+        if request.functions:
+            raise HTTPException(
+                status_code=400,
+                detail="Invalid request: Function calling is not yet implemented for stream mode.",
+            )
         generate = predict(query, history, request.model, stop_words)
         return EventSourceResponse(generate, media_type="text/event-stream")
-    if stop_words:
-        react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words]
-        response, _ = model.chat(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens)
-        for stop_ in stop_words:
-            if response.endswith(stop_):
-                response = response[:response.find(stop_)]
+    stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
+    if query is _TEXT_COMPLETION_CMD:
+        response = text_complete_last_message(history, stop_words_ids=stop_words_ids)
+    else:
+        response, _ = model.chat(
+            tokenizer,
+            query,
+            history=history,
+            stop_words_ids=stop_words_ids,
+            append_history=False,
+        )
+        print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
+    response = trim_stop_words(response, stop_words)
+    if request.functions:
+        choice_data = parse_response(response)
     else:
-        response, _ = model.chat(tokenizer, query, history=history)
         choice_data = ChatCompletionResponseChoice(
             index=0,
             message=ChatMessage(role="assistant", content=response),
-            finish_reason="stop"
+            finish_reason="stop",
         )
-    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
+    return ChatCompletionResponse(
+        model=request.model, choices=[choice_data], object="chat.completion"
+    )
 
 
-async def predict(query: str, history: List[List[str]], model_id: str, stop_words: List[str]):
+async def predict(
+    query: str, history: List[List[str]], model_id: str, stop_words: List[str]
+):
     global model, tokenizer
-    assert stop_words == [], "in stream format, stop word is output"
     choice_data = ChatCompletionResponseStreamChoice(
-        index=0,
-        delta=DeltaMessage(role="assistant"),
-        finish_reason=None
+        index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionResponse(
+        model=model_id, choices=[choice_data], object="chat.completion.chunk"
+    )
     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
 
     current_length = 0
+    stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
     if stop_words:
-        react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words]
-        response_generator = model.chat_stream(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens)
-    else:
-        response_generator = model.chat_stream(tokenizer, query, history=history)
+        # TODO: It's a little bit tricky to trim stop words in the stream mode.
+        raise HTTPException(
+            status_code=400,
+            detail="Invalid request: custom stop words are not yet supported for stream mode.",
+        )
+    response_generator = model.chat_stream(
+        tokenizer, query, history=history, stop_words_ids=stop_words_ids
+    )
     for new_response in response_generator:
         if len(new_response) == current_length:
             continue
@@ -169,32 +426,41 @@ async def predict(query: str, history: List[List[str]], model_id: str, stop_word
         current_length = len(new_response)
 
         choice_data = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=DeltaMessage(content=new_text),
-            finish_reason=None
+            index=0, delta=DeltaMessage(content=new_text), finish_reason=None
         )
-        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionResponse(
+            model=model_id, choices=[choice_data], object="chat.completion.chunk"
+        )
         yield "{}".format(chunk.model_dump_json(exclude_unset=True))
 
     choice_data = ChatCompletionResponseStreamChoice(
-        index=0,
-        delta=DeltaMessage(),
-        finish_reason="stop"
+        index=0, delta=DeltaMessage(), finish_reason="stop"
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionResponse(
+        model=model_id, choices=[choice_data], object="chat.completion.chunk"
+    )
     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
-    yield '[DONE]'
+    yield "[DONE]"
 
 
 def _get_args():
     parser = ArgumentParser()
-    parser.add_argument("-c", "--checkpoint-path", type=str, default='QWen/QWen-7B-Chat',
-                        help="Checkpoint name or path, default to %(default)r")
-    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
-    parser.add_argument("--server-port", type=int, default=8000,
-                        help="Demo server port.")
-    parser.add_argument("--server-name", type=str, default="127.0.0.1",
-                        help="Demo server name.")
+    parser.add_argument(
+        "-c",
+        "--checkpoint-path",
+        type=str,
+        default="QWen/QWen-7B-Chat",
+        help="Checkpoint name or path, default to %(default)r",
+    )
+    parser.add_argument(
+        "--cpu-only", action="store_true", help="Run demo with CPU only"
+    )
+    parser.add_argument(
+        "--server-port", type=int, default=8000, help="Demo server port."
+    )
+    parser.add_argument(
+        "--server-name", type=str, default="127.0.0.1", help="Demo server name."
+    )
 
     args = parser.parse_args()
     return args
@@ -204,7 +470,9 @@ if __name__ == "__main__":
     args = _get_args()
 
     tokenizer = AutoTokenizer.from_pretrained(
-        args.checkpoint_path, trust_remote_code=True, resume_download=True,
+        args.checkpoint_path,
+        trust_remote_code=True,
+        resume_download=True,
     )
     if args.cpu_only:
@@ -220,7 +488,9 @@ if __name__ == "__main__":
     ).eval()
 
     model.generation_config = GenerationConfig.from_pretrained(
-        args.checkpoint_path, trust_remote_code=True, resume_download=True,
+        args.checkpoint_path,
+        trust_remote_code=True,
+        resume_download=True,
     )
 
     uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)