@ -1,14 +1,16 @@
# coding=utf-8
# Implements API for Qwen-7B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Requirement:
#   pip install "openai<1.0"
# Usage:
#   python openai_api.py
# Visit http://localhost:8000/docs for documents.
import re
import base64
import copy
import json
import time
from argparse import ArgumentParser
from contextlib import asynccontextmanager
from pprint import pprint
from typing import Dict , List , Literal , Optional , Union
import torch
@ -17,20 +19,22 @@ from fastapi import FastAPI, HTTPException
from fastapi . middleware . cors import CORSMiddleware
from pydantic import BaseModel , Field
from sse_starlette . sse import EventSourceResponse
from transformers import AutoTokenizer , AutoModelForCausalLM
from transformers . generation import GenerationConfig
from starlette . middleware . base import BaseHTTPMiddleware
from starlette . requests import Request
from starlette . responses import Response
import base64
from transformers import AutoModelForCausalLM , AutoTokenizer
from transformers . generation import GenerationConfig
class BasicAuthMiddleware ( BaseHTTPMiddleware ) :
def __init__ ( self , app , username : str , password : str ) :
super ( ) . __init__ ( app )
self . required_credentials = base64 . b64encode ( f " { username } : { password } " . encode ( ) ) . decode ( )
self . required_credentials = base64 . b64encode (
f ' { username } : { password } ' . encode ( ) ) . decode ( )
async def dispatch ( self , request : Request , call_next ) :
authorization : str = request . headers . get ( " Authorization " )
authorization : str = request . headers . get ( ' Authorization ' )
if authorization :
try :
schema , credentials = authorization . split ( )
@ -38,16 +42,18 @@ class BasicAuthMiddleware(BaseHTTPMiddleware):
return await call_next ( request )
except ValueError :
pass
headers = { ' WWW-Authenticate ' : ' Basic ' }
return Response ( status_code = 401 , headers = headers )
def _gc(forced: bool = False):
    """Run a garbage-collection pass and release cached CUDA memory.

    Does nothing when the module-level ``args.disable_gc`` flag is set,
    unless ``forced`` is true.

    Args:
        forced: Collect even when ``args.disable_gc`` is set.
    """
    # `args` is only read here, so no `global` declaration is required.
    if args.disable_gc and not forced:
        return

    import gc

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
@ -63,36 +69,36 @@ app = FastAPI(lifespan=lifespan)
# Register CORS handling on the application. Each keyword is passed exactly
# once (a repeated keyword argument is a syntax error).
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive policy -- confirm this is intended for your deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
class ModelCard(BaseModel):
    """Description of a single model, as returned inside a ``ModelList``."""

    # Model identifier; the only field without a default.
    id: str
    object: str = 'model'
    # Defaults to the current Unix time (whole seconds) at instantiation.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = 'owner'
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None
class ModelList(BaseModel):
    """Container for ``ModelCard`` entries (response model of ``/v1/models``)."""

    object: str = 'list'
    data: List[ModelCard] = []
class ChatMessage(BaseModel):
    """One message of a chat conversation."""

    # Author of the message; 'function' carries the result of a tool call.
    role: Literal['user', 'assistant', 'system', 'function']
    # Text of the message; may be None.
    content: Optional[str]
    # Set on assistant messages that request a function call; elsewhere in
    # this file it is read via the keys 'name' and 'arguments'.
    function_call: Optional[Dict] = None
class DeltaMessage ( BaseModel ) :
role : Optional [ Literal [ " user " , " assistant " , " system " ] ] = None
role : Optional [ Literal [ ' user ' , ' assistant ' , ' system ' ] ] = None
content : Optional [ str ] = None
@ -102,6 +108,7 @@ class ChatCompletionRequest(BaseModel):
functions : Optional [ List [ Dict ] ] = None
temperature : Optional [ float ] = None
top_p : Optional [ float ] = None
top_k : Optional [ int ] = None
max_length : Optional [ int ] = None
stream : Optional [ bool ] = False
stop : Optional [ List [ str ] ] = None
@ -109,29 +116,28 @@ class ChatCompletionRequest(BaseModel):
class ChatCompletionResponseChoice(BaseModel):
    """A single choice of a non-streaming chat completion."""

    index: int
    # A one-member ``Union[ChatMessage]`` is just ``ChatMessage``; spell it so.
    message: ChatMessage
    finish_reason: Literal['stop', 'length', 'function_call']
class ChatCompletionResponseStreamChoice(BaseModel):
    """A single choice of one streamed chat-completion chunk."""

    index: int
    # Incremental payload carried by this chunk.
    delta: DeltaMessage
    # May be None (see the intermediate chunks built in ``predict``).
    finish_reason: Optional[Literal['stop', 'length']]
class ChatCompletionResponse(BaseModel):
    """Response body shared by full completions and streamed chunks."""

    model: str
    # 'chat.completion' for a full response, 'chat.completion.chunk' for a
    # streamed piece.
    object: Literal['chat.completion', 'chat.completion.chunk']
    choices: List[Union[ChatCompletionResponseChoice,
                        ChatCompletionResponseStreamChoice]]
    # Defaults to the current Unix time (whole seconds) at instantiation.
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.get('/v1/models', response_model=ModelList)
async def list_models():
    """Return the list of served models.

    A single card with the fixed id ``gpt-3.5-turbo`` is always returned.
    """
    model_card = ModelCard(id='gpt-3.5-turbo')
    return ModelList(data=[model_card])
@ -141,7 +147,7 @@ def add_extra_stop_words(stop_words):
_stop_words = [ ]
_stop_words . extend ( stop_words )
for x in stop_words :
s = x . lstrip ( " \n " )
s = x . lstrip ( ' \n ' )
if s and ( s not in _stop_words ) :
_stop_words . append ( s )
return _stop_words
@ -157,7 +163,10 @@ def trim_stop_words(response, stop_words):
return response
# Template describing one tool to the model. Placeholders: name_for_model,
# name_for_human, description_for_model, parameters.
TOOL_DESC = (
    '{name_for_model}: Call this tool to interact with the {name_for_human} API.'
    ' What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}'
)
REACT_INSTRUCTION = """ Answer the following questions as best you can. You have access to the following APIs:
@ -179,37 +188,28 @@ Begin!"""
_TEXT_COMPLETION_CMD = object ( )
#
# Temporarily, the system role does not work as expected.
# We advise that you write the setups for role-play in your query,
# i.e., use the user role instead of the system role.
#
# TODO: Use real system role when the model is ready.
#
def parse_messages ( messages , functions ) :
if all ( m . role != " user " for m in messages ) :
if all ( m . role != ' user ' for m in messages ) :
raise HTTPException (
status_code = 400 ,
detail = f" Invalid request: Expecting at least one user message. " ,
detail = ' Invalid request: Expecting at least one user message. ' ,
)
messages = copy . deepcopy ( messages )
default_system = " You are a helpful assistant. "
system = " "
if messages [ 0 ] . role == " system " :
system = messages . pop ( 0 ) . content . lstrip ( " \n " ) . rstrip ( )
if system == default_system :
system = " "
if messages [ 0 ] . role == ' system ' :
system = messages . pop ( 0 ) . content . lstrip ( ' \n ' ) . rstrip ( )
else :
system = ' You are a helpful assistant. '
if functions :
tools_text = [ ]
tools_name_text = [ ]
for func_info in functions :
name = func_info . get ( " name " , " " )
name_m = func_info . get ( " name_for_model " , name )
name_h = func_info . get ( " name_for_human " , name )
desc = func_info . get ( " description " , " " )
desc_m = func_info . get ( " description_for_model " , desc )
name = func_info . get ( ' name ' , ' ' )
name_m = func_info . get ( ' name_for_model ' , name )
name_h = func_info . get ( ' name_for_human ' , name )
desc = func_info . get ( ' description ' , ' ' )
desc_m = func_info . get ( ' description_for_model ' , desc )
tool = TOOL_DESC . format (
name_for_model = name_m ,
name_for_human = name_h ,
@ -217,150 +217,152 @@ def parse_messages(messages, functions):
# "Format the arguments as a JSON object."
# "Enclose the code within triple backticks (`) at the beginning and end of the code."
description_for_model = desc_m ,
parameters = json . dumps ( func_info [ " parameters " ] , ensure_ascii = False ) ,
parameters = json . dumps ( func_info [ ' parameters ' ] ,
ensure_ascii = False ) ,
)
tools_text . append ( tool )
tools_name_text . append ( name_m )
tools_text = " \n \n " . join ( tools_text )
tools_name_text = " , " . join ( tools_name_text )
system + = " \n \n " + REACT_INSTRUCTION . format (
tools_text = ' \n \n ' . join ( tools_text )
tools_name_text = ' , ' . join ( tools_name_text )
instruction = ( REACT_INSTRUCTION . format (
tools_text = tools_text ,
tools_name_text = tools_name_text ,
)
system = system . lstrip ( " \n " ) . rstrip ( )
dummy_thought = {
" en " : " \n Thought: I now know the final answer. \n Final answer: " ,
" zh " : " \n Thought: 我会作答了。 \n Final answer: " ,
}
) . lstrip ( ' \n ' ) . rstrip ( ) )
else :
instruction = ' '
_ messages = messages
messages_with_fncall = messages
messages = [ ]
for m_idx , m in enumerate ( _ messages) :
for m_idx , m in enumerate ( messages_with_fncall ) :
role , content , func_call = m . role , m . content , m . function_call
if content :
content = content . lstrip ( " \n " ) . rstrip ( )
if role == " function " :
if ( len ( messages ) == 0 ) or ( messages [ - 1 ] . role != " assistant " ) :
content = content or ' '
content = content . lstrip ( ' \n ' ) . rstrip ( )
if role == ' function ' :
if ( len ( messages ) == 0 ) or ( messages [ - 1 ] . role != ' assistant ' ) :
raise HTTPException (
status_code = 400 ,
detail = f " Invalid request: Expecting role assistant before role function. " ,
detail =
' Invalid request: Expecting role assistant before role function. ' ,
)
messages [ - 1 ] . content + = f " \n Observation: { content } "
if m_idx == len ( _messages ) - 1 :
messages [ - 1 ] . content + = " \n Thought: "
elif role == " assistant " :
messages [ - 1 ] . content + = f ' \n Observation: { content } '
if m_idx == len ( messages_with_fncall ) - 1 :
# add a prefix for text completion
messages [ - 1 ] . content + = ' \n Thought: '
elif role == ' assistant ' :
if len ( messages ) == 0 :
raise HTTPException (
status_code = 400 ,
detail = f " Invalid request: Expecting role user before role assistant. " ,
detail =
' Invalid request: Expecting role user before role assistant. ' ,
)
last_msg = messages [ - 1 ] . content
last_msg_has_zh = len ( re . findall ( r " [ \ u4e00- \ u9fff]+ " , last_msg ) ) > 0
if func_call is None :
if functions :
content = dummy_thought [ " zh " if last_msg_has_zh else " en " ] + content
content = f ' Thought: I now know the final answer. \n Final Answer: { content } '
else :
f_name , f_args = func_call [ " name " ] , func_call [ " arguments " ]
if not content :
if last_msg_has_zh :
content = f " Thought: 我可以使用 { f_name } API。 "
else :
content = f " Thought: I can use { f_name } . "
content = f " \n { content } \n Action: { f_name } \n Action Input: { f_args } "
if messages [ - 1 ] . role == " user " :
f_name , f_args = func_call [ ' name ' ] , func_call [ ' arguments ' ]
if not content . startswith ( ' Thought: ' ) :
content = f ' Thought: { content } '
content = f ' { content } \n Action: { f_name } \n Action Input: { f_args } '
if messages [ - 1 ] . role == ' user ' :
messages . append (
ChatMessage ( role = " assistant " , content = content . lstrip ( " \n " ) . rstrip ( ) )
)
ChatMessage ( role = ' assistant ' ,
content = content . lstrip ( ' \n ' ) . rstrip ( ) ) )
else :
messages [ - 1 ] . content + = content
elif role == " user " :
messages [ - 1 ] . content + = ' \n ' + content
elif role == ' user ' :
messages . append (
ChatMessage ( role = " user " , content = content . lstrip ( " \n " ) . rstrip ( ) )
)
ChatMessage ( role = ' user ' ,
content = content . lstrip ( ' \n ' ) . rstrip ( ) ) )
else :
raise HTTPException (
status_code = 400 , detail = f " Invalid request: Incorrect role { role } . "
)
status_code = 400 ,
detail = f ' Invalid request: Incorrect role { role } . ' )
query = _TEXT_COMPLETION_CMD
if messages [ - 1 ] . role == " user " :
if messages [ - 1 ] . role == ' user ' :
query = messages [ - 1 ] . content
messages = messages [ : - 1 ]
if len ( messages ) % 2 != 0 :
raise HTTPException ( status_code = 400 , detail = " Invalid request " )
raise HTTPException ( status_code = 400 , detail = ' Invalid request ' )
history = [ ] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
for i in range ( 0 , len ( messages ) , 2 ) :
if messages [ i ] . role == " user " and messages [ i + 1 ] . role == " assistant " :
usr_msg = messages [ i ] . content . lstrip ( " \n " ) . rstrip ( )
bot_msg = messages [ i + 1 ] . content . lstrip ( " \n " ) . rstrip ( )
if system and ( i == len ( messages ) - 2 ) :
usr_msg = f " { system } \n \n Question: { usr_msg } "
system = " "
for t in dummy_thought . values ( ) :
t = t . lstrip ( " \n " )
if bot_msg . startswith ( t ) and ( " \n Action: " in bot_msg ) :
bot_msg = bot_msg [ len ( t ) : ]
if messages [ i ] . role == ' user ' and messages [ i + 1 ] . role == ' assistant ' :
usr_msg = messages [ i ] . content . lstrip ( ' \n ' ) . rstrip ( )
bot_msg = messages [ i + 1 ] . content . lstrip ( ' \n ' ) . rstrip ( )
if instruction and ( i == len ( messages ) - 2 ) :
usr_msg = f ' { instruction } \n \n Question: { usr_msg } '
instruction = ' '
history . append ( [ usr_msg , bot_msg ] )
else :
raise HTTPException (
status_code = 400 ,
detail = " Invalid request: Expecting exactly one user (or function) role before every assistant role. " ,
detail =
' Invalid request: Expecting exactly one user (or function) role before every assistant role. ' ,
)
if system :
if instruction :
assert query is not _TEXT_COMPLETION_CMD
query = f " { system } \n \n Question: { query } "
return query , history
query = f ' { instruction } \n \n Question: { query } '
return query , history , system
def parse_response(response):
    """Turn raw ReAct-style model output into a chat-completion choice.

    When the text contains an ``Action:`` line followed by an
    ``Action Input:`` line, a choice with ``finish_reason='function_call'``
    is returned. Its ``function_call`` holds the parsed name and arguments,
    and its content is the text before ``Action:`` with everything up to and
    including the first ``Thought: `` marker removed.

    Otherwise a ``finish_reason='stop'`` choice is returned whose content is
    the text after the last ``Final Answer: `` marker, or the whole text when
    that marker is absent.

    Args:
        response: Text generated by the model.

    Returns:
        A ``ChatCompletionResponseChoice`` with index 0.
    """
    action_tag = '\nAction:'
    input_tag = '\nAction Input:'
    observation_tag = '\nObservation:'
    thought_tag = 'Thought: '
    answer_tag = '\nFinal Answer: '

    func_name, func_args = '', ''
    i = response.find(action_tag)
    j = response.find(input_tag)
    if 0 <= i < j:  # The text has both `Action` and `Action Input`.
        # Search for `Observation` only after `Action Input`: an earlier
        # occurrence must not be used as the end of the arguments, otherwise
        # the arguments slice comes out empty.
        k = response.find(observation_tag, j)
        if k < 0:
            # `Observation` is likely missing because the output text
            # discarded the stop word; add it back.
            response = response.rstrip() + observation_tag
            k = response.find(observation_tag, j)
        func_name = response[i + len(action_tag):j].strip()
        func_args = response[j + len(input_tag):k].strip()

    if func_name:
        # Keep only the reasoning that precedes the action, without its tag.
        response = response[:i]
        t = response.find(thought_tag)
        if t >= 0:
            response = response[t + len(thought_tag):]
        response = response.strip()
        return ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(
                role='assistant',
                content=response,
                function_call={
                    'name': func_name,
                    'arguments': func_args,
                },
            ),
            finish_reason='function_call',
        )

    z = response.rfind(answer_tag)
    if z >= 0:
        response = response[z + len(answer_tag):]
    return ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=response),
        finish_reason='stop',
    )
# completion mode, not chat mode
def text_complete_last_message ( history , stop_words_ids , gen_kwargs ):
im_start = " <|im_start|> "
im_end = " <|im_end|> "
prompt = f " { im_start } system \n You are a helpful assistant. { im_end } "
def text_complete_last_message ( history , stop_words_ids , gen_kwargs , system ):
im_start = ' <|im_start|> '
im_end = ' <|im_end|> '
prompt = f ' { im_start } system \n { system } { im_end } '
for i , ( query , response ) in enumerate ( history ) :
query = query . lstrip ( " \n " ) . rstrip ( )
response = response . lstrip ( " \n " ) . rstrip ( )
prompt + = f " \n { im_start } user \n { query } { im_end } "
prompt + = f " \n { im_start } assistant \n { response } { im_end } "
prompt = prompt [ : - len ( im_end ) ]
query = query . lstrip ( ' \n ' ) . rstrip ( )
response = response . lstrip ( ' \n ' ) . rstrip ( )
prompt + = f ' \n { im_start } user \n { query } { im_end } '
prompt + = f ' \n { im_start } assistant \n { response } { im_end } '
prompt = prompt [ : - len ( im_end ) ]
_stop_words_ids = [ tokenizer . encode ( im_end ) ]
if stop_words_ids :
@ -369,20 +371,24 @@ def text_complete_last_message(history, stop_words_ids, gen_kwargs):
stop_words_ids = _stop_words_ids
input_ids = torch . tensor ( [ tokenizer . encode ( prompt ) ] ) . to ( model . device )
output = model . generate ( input_ids , stop_words_ids = stop_words_ids , * * gen_kwargs ) . tolist ( ) [ 0 ]
output = tokenizer . decode ( output , errors = " ignore " )
output = model . generate ( input_ids ,
stop_words_ids = stop_words_ids ,
* * gen_kwargs ) . tolist ( ) [ 0 ]
output = tokenizer . decode ( output , errors = ' ignore ' )
assert output . startswith ( prompt )
output = output [ len ( prompt ) : ]
output = trim_stop_words ( output , [ " <|endoftext|> " , im_end ] )
print ( f " <completion> \n { prompt } \n <!-- *** --> \n { output } \n </completion> " )
output = output [ len ( prompt ) : ]
output = trim_stop_words ( output , [ ' <|endoftext|> ' , im_end ] )
print ( f ' <completion> \n { prompt } \n <!-- *** --> \n { output } \n </completion> ' )
return output
@app.post ( " /v1/chat/completions " , response_model = ChatCompletionResponse )
@app.post ( ' /v1/chat/completions ' , response_model = ChatCompletionResponse )
async def create_chat_completion ( request : ChatCompletionRequest ) :
global model , tokenizer
gen_kwargs = { }
if request . top_k is not None :
gen_kwargs [ ' top_k ' ] = request . top_k
if request . temperature is not None :
if request . temperature < 0.01 :
gen_kwargs [ ' top_k ' ] = 1 # greedy decoding
@ -395,32 +401,46 @@ async def create_chat_completion(request: ChatCompletionRequest):
stop_words = add_extra_stop_words ( request . stop )
if request . functions :
stop_words = stop_words or [ ]
if " Observation: " not in stop_words :
stop_words . append ( " Observation: " )
if ' Observation: ' not in stop_words :
stop_words . append ( ' Observation: ' )
query , history = parse_messages ( request . messages , request . functions )
query , history , system = parse_messages ( request . messages ,
request . functions )
if request . stream :
if request . functions :
raise HTTPException (
status_code = 400 ,
detail = " Invalid request: Function calling is not yet implemented for stream mode. " ,
detail =
' Invalid request: Function calling is not yet implemented for stream mode. ' ,
)
generate = predict ( query , history , request . model , stop_words , gen_kwargs )
return EventSourceResponse ( generate , media_type = " text/event-stream " )
stop_words_ids = [ tokenizer . encode ( s ) for s in stop_words ] if stop_words else None
generate = predict ( query ,
history ,
request . model ,
stop_words ,
gen_kwargs ,
system = system )
return EventSourceResponse ( generate , media_type = ' text/event-stream ' )
stop_words_ids = [ tokenizer . encode ( s )
for s in stop_words ] if stop_words else None
if query is _TEXT_COMPLETION_CMD :
response = text_complete_last_message ( history , stop_words_ids = stop_words_ids , gen_kwargs = gen_kwargs )
response = text_complete_last_message ( history ,
stop_words_ids = stop_words_ids ,
gen_kwargs = gen_kwargs ,
system = system )
else :
response , _ = model . chat (
tokenizer ,
query ,
history = history ,
system = system ,
stop_words_ids = stop_words_ids ,
* * gen_kwargs
* * gen_kwargs ,
)
print ( f " <chat> \n { history } \n { query } \n <!-- *** --> \n { response } \n </chat> " )
print ( ' <chat> ' )
pprint ( history , indent = 2 )
print ( f ' { query } \n <!-- *** --> \n { response } \n </chat> ' )
_gc ( )
response = trim_stop_words ( response , stop_words )
@ -429,12 +449,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
else :
choice_data = ChatCompletionResponseChoice (
index = 0 ,
message = ChatMessage ( role = " assistant " , content = response ) ,
finish_reason = " stop " ,
message = ChatMessage ( role = ' assistant ' , content = response ) ,
finish_reason = ' stop ' ,
)
return ChatCompletionResponse (
model = request . model , choices = [ choice_data ] , object = " chat.completion "
)
return ChatCompletionResponse ( model = request . model ,
choices = [ choice_data ] ,
object = ' chat.completion ' )
def _dump_json ( data : BaseModel , * args , * * kwargs ) - > str :
@ -445,28 +465,37 @@ def _dump_json(data: BaseModel, *args, **kwargs) -> str:
async def predict (
query : str , history : List [ List [ str ] ] , model_id : str , stop_words : List [ str ] , gen_kwargs : Dict ,
query : str ,
history : List [ List [ str ] ] ,
model_id : str ,
stop_words : List [ str ] ,
gen_kwargs : Dict ,
system : str ,
) :
global model , tokenizer
choice_data = ChatCompletionResponseStreamChoice (
index = 0 , delta = DeltaMessage ( role = " assistant " ) , finish_reason = None
)
chunk = ChatCompletionResponse (
model = model_id , choices = [ choice_data ] , object = " chat.completion.chunk "
)
yield " {} " . format ( _dump_json ( chunk , exclude_unset = True ) )
index = 0 , delta = DeltaMessage ( role = ' assistant ' ) , finish_reason = None )
chunk = ChatCompletionResponse ( model = model_id ,
choices = [ choice_data ] ,
object = ' chat.completion.chunk ' )
yield ' {} ' . format ( _dump_json ( chunk , exclude_unset = True ) )
current_length = 0
stop_words_ids = [ tokenizer . encode ( s ) for s in stop_words ] if stop_words else None
stop_words_ids = [ tokenizer . encode ( s )
for s in stop_words ] if stop_words else None
if stop_words :
# TODO: It's a little bit tricky to trim stop words in the stream mode.
raise HTTPException (
status_code = 400 ,
detail = " Invalid request: custom stop words are not yet supported for stream mode. " ,
detail =
' Invalid request: custom stop words are not yet supported for stream mode. ' ,
)
response_generator = model . chat_stream (
tokenizer , query , history = history , stop_words_ids = stop_words_ids , * * gen_kwargs
)
response_generator = model . chat_stream ( tokenizer ,
query ,
history = history ,
stop_words_ids = stop_words_ids ,
system = system ,
* * gen_kwargs )
for new_response in response_generator :
if len ( new_response ) == current_length :
continue
@ -475,21 +504,20 @@ async def predict(
current_length = len ( new_response )
choice_data = ChatCompletionResponseStreamChoice (
index = 0 , delta = DeltaMessage ( content = new_text ) , finish_reason = None
)
chunk = ChatCompletionResponse (
model = model_id , choices = [ choice_data ] , object = " chat.completion.chunk "
)
yield " {} " . format ( _dump_json ( chunk , exclude_unset = True ) )
choice_data = ChatCompletionResponseStreamChoice (
index = 0 , delta = DeltaMessage ( ) , finish_reason = " stop "
)
chunk = ChatCompletionResponse (
model = model_id , choices = [ choice_data ] , object = " chat.completion.chunk "
)
yield " {} " . format ( _dump_json ( chunk , exclude_unset = True ) )
yield " [DONE] "
index = 0 , delta = DeltaMessage ( content = new_text ) , finish_reason = None )
chunk = ChatCompletionResponse ( model = model_id ,
choices = [ choice_data ] ,
object = ' chat.completion.chunk ' )
yield ' {} ' . format ( _dump_json ( chunk , exclude_unset = True ) )
choice_data = ChatCompletionResponseStreamChoice ( index = 0 ,
delta = DeltaMessage ( ) ,
finish_reason = ' stop ' )
chunk = ChatCompletionResponse ( model = model_id ,
choices = [ choice_data ] ,
object = ' chat.completion.chunk ' )
yield ' {} ' . format ( _dump_json ( chunk , exclude_unset = True ) )
yield ' [DONE] '
_gc ( )
@ -497,36 +525,39 @@ async def predict(
def _get_args(argv=None):
    """Parse the server's command-line options.

    Each option is registered exactly once; registering the same option
    string twice makes argparse raise a conflict error.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the process command line.

    Returns:
        The parsed ``argparse.Namespace`` with attributes ``checkpoint_path``,
        ``api_auth``, ``cpu_only``, ``server_port``, ``server_name`` and
        ``disable_gc``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        '-c',
        '--checkpoint-path',
        type=str,
        default='Qwen/Qwen-7B-Chat',
        help='Checkpoint name or path, default to %(default)r',
    )
    parser.add_argument('--api-auth', help='API authentication credentials')
    parser.add_argument('--cpu-only',
                        action='store_true',
                        help='Run demo with CPU only')
    parser.add_argument('--server-port',
                        type=int,
                        default=8000,
                        help='Demo server port.')
    parser.add_argument(
        '--server-name',
        type=str,
        default='127.0.0.1',
        help=
        'Demo server name. Default: 127.0.0.1, which is only visible from the local computer.'
        ' If you want other computers to access your server, use 0.0.0.0 instead.',
    )
    parser.add_argument(
        '--disable-gc',
        action='store_true',
        help='Disable GC after each response generated.',
    )

    return parser.parse_args(argv)
if __name__ == " __main__ " :
if __name__ == ' __main__ ' :
args = _get_args ( )
tokenizer = AutoTokenizer . from_pretrained (
@ -536,14 +567,14 @@ if __name__ == "__main__":
)
if args . api_auth :
app . add_middleware (
BasicAuthMiddleware , username = args . api_auth . split ( " : " ) [ 0 ] , password = args . api_auth . split ( " : " ) [ 1 ]
)
app . add_middleware ( BasicAuthMiddleware ,
username = args . api_auth . split ( ' : ' ) [ 0 ] ,
password = args . api_auth . split ( ' : ' ) [ 1 ] )
if args . cpu_only :
device_map = " cpu "
device_map = ' cpu '
else :
device_map = " auto "
device_map = ' auto '
model = AutoModelForCausalLM . from_pretrained (
args . checkpoint_path ,