# coding=utf-8
# Requirement:
#   Implements API for Qwen-7B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
#   pip install "openai<1.0"
# Usage:
#   python openai_api.py
# Visit http://localhost:8000/docs for documents.
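
# A minimal client sketch (assumes the server below is running on
# localhost:8000 with default settings; the model name is arbitrary since this
# server serves a single model):
#
#   import openai
#   openai.api_base = 'http://localhost:8000/v1'
#   openai.api_key = 'none'  # any non-empty string; only openai-python checks it
#   resp = openai.ChatCompletion.create(
#       model='Qwen-7B-Chat',
#       messages=[{'role': 'user', 'content': 'Hello!'}])
#   print(resp.choices[0].message.content)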

import base64
import copy
import json
import time
from argparse import ArgumentParser
from contextlib import asynccontextmanager
from pprint import pprint
from typing import Dict, List, Literal, Optional, Union

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


class BasicAuthMiddleware(BaseHTTPMiddleware):

    def __init__(self, app, username: str, password: str):
        super().__init__(app)
        self.required_credentials = base64.b64encode(
            f'{username}:{password}'.encode()).decode()

    async def dispatch(self, request: Request, call_next):
        authorization: str = request.headers.get('Authorization')
        if authorization:
            try:
                schema, credentials = authorization.split()
                if credentials == self.required_credentials:
                    return await call_next(request)
            except Exception:
                pass

        headers = {'WWW-Authenticate': 'Basic'}
        return Response(status_code=401, headers=headers)
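

# With --api-auth user:pass (made-up credentials), clients must send an HTTP
# Basic header on every request, e.g.:
#   Authorization: Basic dXNlcjpwYXNz   # == base64.b64encode(b'user:pass')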


def _gc(forced: bool = False):
    global args
    if args.disable_gc and not forced:
        return

    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    yield
    _gc(forced=True)


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)


class ModelCard(BaseModel):
    id: str
    object: str = 'model'
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = 'owner'
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = 'list'
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal['user', 'assistant', 'system', 'function']
    content: Optional[str]
    function_call: Optional[Dict] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal['user', 'assistant', 'system']] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    functions: Optional[List[Dict]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal['stop', 'length', 'function_call']


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal['stop', 'length']]


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal['chat.completion', 'chat.completion.chunk']
    choices: List[Union[ChatCompletionResponseChoice,
                        ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


@app.get('/v1/models', response_model=ModelList)
async def list_models():
    global model_args
    model_card = ModelCard(id='gpt-3.5-turbo')
    return ModelList(data=[model_card])


# To work around that unpleasant leading-\n tokenization issue!
def add_extra_stop_words(stop_words):
    if stop_words:
        _stop_words = []
        _stop_words.extend(stop_words)
        for x in stop_words:
            s = x.lstrip('\n')
            if s and (s not in _stop_words):
                _stop_words.append(s)
        return _stop_words
    return stop_words


def trim_stop_words(response, stop_words):
    if stop_words:
        for stop in stop_words:
            idx = response.find(stop)
            if idx != -1:
                response = response[:idx]
    return response


TOOL_DESC = (
    '{name_for_model}: Call this tool to interact with the {name_for_human} API.'
    ' What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}'
)

REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:

{tools_text}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tools_name_text}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!"""

_TEXT_COMPLETION_CMD = object()
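
# Illustrative shape of one entry in the OpenAI-style `functions` list that
# parse_messages() consumes below (names and values are made up):
#   {
#       'name': 'get_weather',
#       'description': 'Query the current weather for a city.',
#       'parameters': {'type': 'object',
#                      'properties': {'city': {'type': 'string'}}},
#   }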


def parse_messages(messages, functions):
    if all(m.role != 'user' for m in messages):
        raise HTTPException(
            status_code=400,
            detail='Invalid request: Expecting at least one user message.',
        )

    messages = copy.deepcopy(messages)
    if messages[0].role == 'system':
        system = messages.pop(0).content.lstrip('\n').rstrip()
    else:
        system = 'You are a helpful assistant.'

    if functions:
        tools_text = []
        tools_name_text = []
        for func_info in functions:
            name = func_info.get('name', '')
            name_m = func_info.get('name_for_model', name)
            name_h = func_info.get('name_for_human', name)
            desc = func_info.get('description', '')
            desc_m = func_info.get('description_for_model', desc)
            tool = TOOL_DESC.format(
                name_for_model=name_m,
                name_for_human=name_h,
                # Hint: You can add the following format requirements in description:
                #   "Format the arguments as a JSON object."
                #   "Enclose the code within triple backticks (`) at the beginning and end of the code."
                description_for_model=desc_m,
                parameters=json.dumps(func_info['parameters'],
                                      ensure_ascii=False),
            )
            tools_text.append(tool)
            tools_name_text.append(name_m)
        tools_text = '\n\n'.join(tools_text)
        tools_name_text = ', '.join(tools_name_text)
        instruction = (REACT_INSTRUCTION.format(
            tools_text=tools_text,
            tools_name_text=tools_name_text,
        ).lstrip('\n').rstrip())
    else:
        instruction = ''

    messages_with_fncall = messages
    messages = []
    for m_idx, m in enumerate(messages_with_fncall):
        role, content, func_call = m.role, m.content, m.function_call
        content = content or ''
        content = content.lstrip('\n').rstrip()
        if role == 'function':
            if (len(messages) == 0) or (messages[-1].role != 'assistant'):
                raise HTTPException(
                    status_code=400,
                    detail='Invalid request: Expecting role assistant before role function.',
                )
            messages[-1].content += f'\nObservation: {content}'
            if m_idx == len(messages_with_fncall) - 1:
                # add a prefix for text completion
                messages[-1].content += '\nThought:'
        elif role == 'assistant':
            if len(messages) == 0:
                raise HTTPException(
                    status_code=400,
                    detail='Invalid request: Expecting role user before role assistant.',
                )
            if func_call is None:
                if functions:
                    content = f'Thought: I now know the final answer.\nFinal Answer: {content}'
            else:
                f_name, f_args = func_call['name'], func_call['arguments']
                if not content.startswith('Thought:'):
                    content = f'Thought: {content}'
                content = f'{content}\nAction: {f_name}\nAction Input: {f_args}'
            if messages[-1].role == 'user':
                messages.append(
                    ChatMessage(role='assistant',
                                content=content.lstrip('\n').rstrip()))
            else:
                messages[-1].content += '\n' + content
        elif role == 'user':
            messages.append(
                ChatMessage(role='user',
                            content=content.lstrip('\n').rstrip()))
        else:
            raise HTTPException(
                status_code=400,
                detail=f'Invalid request: Incorrect role {role}.')

    query = _TEXT_COMPLETION_CMD
    if messages[-1].role == 'user':
        query = messages[-1].content
        messages = messages[:-1]

    if len(messages) % 2 != 0:
        raise HTTPException(status_code=400, detail='Invalid request')

    history = []  # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
    for i in range(0, len(messages), 2):
        if messages[i].role == 'user' and messages[i + 1].role == 'assistant':
            usr_msg = messages[i].content.lstrip('\n').rstrip()
            bot_msg = messages[i + 1].content.lstrip('\n').rstrip()
            if instruction and (i == len(messages) - 2):
                usr_msg = f'{instruction}\n\nQuestion: {usr_msg}'
                instruction = ''
            history.append([usr_msg, bot_msg])
        else:
            raise HTTPException(
                status_code=400,
                detail='Invalid request: Expecting exactly one user (or function) role before every assistant role.',
            )
    if instruction:
        assert query is not _TEXT_COMPLETION_CMD
        query = f'{instruction}\n\nQuestion: {query}'
    return query, history, system
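
# Shape of the transformation above (hypothetical two-turn request):
#   [system, user1, assistant1, user2]
#   -> system='...', history=[[user1, assistant1]], query=user2
# If the final message is an assistant or function turn, query becomes
# _TEXT_COMPLETION_CMD and the caller falls back to completion mode.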


def parse_response(response):
    func_name, func_args = '', ''
    i = response.find('\nAction:')
    j = response.find('\nAction Input:')
    k = response.find('\nObservation:')
    if 0 <= i < j:  # If the text has `Action` and `Action Input`,
        if k < j:  # but does not contain `Observation`,
            # then it is likely that `Observation` is omitted by the LLM,
            # because the output text may have discarded the stop word.
            response = response.rstrip() + '\nObservation:'  # Add it back.
            k = response.find('\nObservation:')
        func_name = response[i + len('\nAction:'):j].strip()
        func_args = response[j + len('\nAction Input:'):k].strip()
    if func_name:
        response = response[:i]
        t = response.find('Thought: ')
        if t >= 0:
            response = response[t + len('Thought: '):]
        response = response.strip()
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(
                role='assistant',
                content=response,
                function_call={
                    'name': func_name,
                    'arguments': func_args
                },
            ),
            finish_reason='function_call',
        )
        return choice_data
    z = response.rfind('\nFinal Answer: ')
    if z >= 0:
        response = response[z + len('\nFinal Answer: '):]
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=response),
        finish_reason='stop',
    )
    return choice_data
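
# Parsing sketch (made-up model output):
#   'Thought: I can use get_weather.\nAction: get_weather\nAction Input: {"city": "Beijing"}'
#   -> finish_reason='function_call', content='I can use get_weather.',
#      function_call={'name': 'get_weather', 'arguments': '{"city": "Beijing"}'}
# Without an Action block, any trailing '\nFinal Answer: ...' is extracted and
# returned with finish_reason='stop'.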


# completion mode, not chat mode
def text_complete_last_message(history, stop_words_ids, gen_kwargs, system):
    im_start = '<|im_start|>'
    im_end = '<|im_end|>'
    prompt = f'{im_start}system\n{system}{im_end}'
    for i, (query, response) in enumerate(history):
        query = query.lstrip('\n').rstrip()
        response = response.lstrip('\n').rstrip()
        prompt += f'\n{im_start}user\n{query}{im_end}'
        prompt += f'\n{im_start}assistant\n{response}{im_end}'
    # Drop the final <|im_end|> so the model continues the last assistant turn.
    prompt = prompt[:-len(im_end)]

    _stop_words_ids = [tokenizer.encode(im_end)]
    if stop_words_ids:
        for s in stop_words_ids:
            _stop_words_ids.append(s)
    stop_words_ids = _stop_words_ids

    input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
    output = model.generate(input_ids,
                            stop_words_ids=stop_words_ids,
                            **gen_kwargs).tolist()[0]
    output = tokenizer.decode(output, errors='ignore')
    assert output.startswith(prompt)
    output = output[len(prompt):]
    output = trim_stop_words(output, ['<|endoftext|>', im_end])
    print(f'<completion>\n{prompt}\n<!-- *** -->\n{output}\n</completion>')
    return output


@app.post('/v1/chat/completions', response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    gen_kwargs = {}
    if request.top_k is not None:
        gen_kwargs['top_k'] = request.top_k
    if request.temperature is not None:
        if request.temperature < 0.01:
            gen_kwargs['top_k'] = 1  # greedy decoding
        else:
            # Not recommended. Please tune top_p instead.
            gen_kwargs['temperature'] = request.temperature
    if request.top_p is not None:
        gen_kwargs['top_p'] = request.top_p

    stop_words = add_extra_stop_words(request.stop)
    if request.functions:
        stop_words = stop_words or []
        if 'Observation:' not in stop_words:
            stop_words.append('Observation:')

    query, history, system = parse_messages(request.messages,
                                            request.functions)

    if request.stream:
        if request.functions:
            raise HTTPException(
                status_code=400,
                detail='Invalid request: Function calling is not yet implemented for stream mode.',
            )
        generate = predict(query,
                           history,
                           request.model,
                           stop_words,
                           gen_kwargs,
                           system=system)
        return EventSourceResponse(generate, media_type='text/event-stream')

    stop_words_ids = [tokenizer.encode(s)
                      for s in stop_words] if stop_words else None
    if query is _TEXT_COMPLETION_CMD:
        response = text_complete_last_message(history,
                                              stop_words_ids=stop_words_ids,
                                              gen_kwargs=gen_kwargs,
                                              system=system)
    else:
        response, _ = model.chat(
            tokenizer,
            query,
            history=history,
            system=system,
            stop_words_ids=stop_words_ids,
            **gen_kwargs,
        )
        print('<chat>')
        pprint(history, indent=2)
        print(f'{query}\n<!-- *** -->\n{response}\n</chat>')
    _gc()

    response = trim_stop_words(response, stop_words)
    if request.functions:
        choice_data = parse_response(response)
    else:
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role='assistant', content=response),
            finish_reason='stop',
        )
    return ChatCompletionResponse(model=request.model,
                                  choices=[choice_data],
                                  object='chat.completion')
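
# Function-calling round trip (hypothetical tool): the client sends
# functions=[{'name': 'get_weather', ...}] and may get back
# finish_reason='function_call'; it then runs the tool and appends a
# role='function' message carrying the observation, repeating until the model
# answers with finish_reason='stop'.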


def _dump_json(data: BaseModel, *args, **kwargs) -> str:
    try:
        return data.model_dump_json(*args, **kwargs)
    except AttributeError:  # pydantic<2.0.0
        return data.json(*args, **kwargs)  # noqa


async def predict(
    query: str,
    history: List[List[str]],
    model_id: str,
    stop_words: List[str],
    gen_kwargs: Dict,
    system: str,
):
    global model, tokenizer
    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(role='assistant'), finish_reason=None)
    chunk = ChatCompletionResponse(model=model_id,
                                   choices=[choice_data],
                                   object='chat.completion.chunk')
    yield '{}'.format(_dump_json(chunk, exclude_unset=True))

    current_length = 0
    stop_words_ids = [tokenizer.encode(s)
                      for s in stop_words] if stop_words else None
    if stop_words:
        # TODO: It's a little bit tricky to trim stop words in the stream mode.
        raise HTTPException(
            status_code=400,
            detail='Invalid request: custom stop words are not yet supported for stream mode.',
        )
    response_generator = model.chat_stream(tokenizer,
                                           query,
                                           history=history,
                                           stop_words_ids=stop_words_ids,
                                           system=system,
                                           **gen_kwargs)
    for new_response in response_generator:
        if len(new_response) == current_length:
            continue
        new_text = new_response[current_length:]
        current_length = len(new_response)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=DeltaMessage(content=new_text), finish_reason=None)
        chunk = ChatCompletionResponse(model=model_id,
                                       choices=[choice_data],
                                       object='chat.completion.chunk')
        yield '{}'.format(_dump_json(chunk, exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(index=0,
                                                     delta=DeltaMessage(),
                                                     finish_reason='stop')
    chunk = ChatCompletionResponse(model=model_id,
                                   choices=[choice_data],
                                   object='chat.completion.chunk')
    yield '{}'.format(_dump_json(chunk, exclude_unset=True))
    yield '[DONE]'

    _gc()
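
# EventSourceResponse frames each yielded string as an SSE `data:` event, so a
# client sees OpenAI-style chunks, e.g. (abridged, content made up):
#   data: {"model": "...", "choices": [{"delta": {"content": "Hi"}}], ...}
#   data: [DONE]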


def _get_args():
    parser = ArgumentParser()
    parser.add_argument(
        '-c',
        '--checkpoint-path',
        type=str,
        default='Qwen/Qwen-7B-Chat',
        help='Checkpoint name or path, default to %(default)r',
    )
    parser.add_argument('--api-auth', help='API authentication credentials')
    parser.add_argument('--cpu-only',
                        action='store_true',
                        help='Run demo with CPU only')
    parser.add_argument('--server-port',
                        type=int,
                        default=8000,
                        help='Demo server port.')
    parser.add_argument(
        '--server-name',
        type=str,
        default='127.0.0.1',
        help='Demo server name. Default: 127.0.0.1, which is only visible from the local computer.'
        ' If you want other computers to access your server, use 0.0.0.0 instead.',
    )
    parser.add_argument(
        '--disable-gc',
        action='store_true',
        help='Disable GC after each response generated.',
    )

    args = parser.parse_args()
    return args
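
# Example invocations (the credentials below are made up):
#   python openai_api.py -c Qwen/Qwen-7B-Chat --server-name 0.0.0.0
#   python openai_api.py --cpu-only --api-auth user:pass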


if __name__ == '__main__':
    args = _get_args()

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        resume_download=True,
    )

    if args.api_auth:
        app.add_middleware(BasicAuthMiddleware,
                           username=args.api_auth.split(':')[0],
                           password=args.api_auth.split(':')[1])

    if args.cpu_only:
        device_map = 'cpu'
    else:
        device_map = 'auto'

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,