@@ -68,6 +68,7 @@ class ChatCompletionRequest(BaseModel):
     top_p: Optional[float] = None
     max_length: Optional[int] = None
     stream: Optional[bool] = False
+    stop: Optional[List[str]] = []
 
 
 class ChatCompletionResponseChoice(BaseModel):
@@ -103,7 +104,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if request.messages[-1].role != "user":
         raise HTTPException(status_code=400, detail="Invalid request")
     query = request.messages[-1].content
-
+    stop_words = request.stop
+    stop_words.extend(list(map(lambda x: x[1:], filter(lambda x: x.startswith("\n"), stop_words))))
     prev_messages = request.messages[:-1]
     # Temporarily, the system role does not work as expected. We advise that you write the setups for role-play in your query.
     # if len(prev_messages) > 0 and prev_messages[0].role == "system":
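A minimal sketch of what the new stop_words.extend(...) line above computes: for each stop word that starts with "\n", a copy without the leading newline is also registered, so matching works whether or not the model emits the newline. The sample values are illustrative, not part of the patch.

    stop_words = ["\nObservation:", "Final Answer:"]
    stop_words.extend(list(map(lambda x: x[1:], filter(lambda x: x.startswith("\n"), stop_words))))
    print(stop_words)  # ['\nObservation:', 'Final Answer:', 'Observation:']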
@@ -120,10 +122,18 @@ async def create_chat_completion(request: ChatCompletionRequest):
         raise HTTPException(status_code=400, detail="Invalid request.")
 
     if request.stream:
-        generate = predict(query, history, request.model)
+        generate = predict(query, history, request.model, stop_words)
         return EventSourceResponse(generate, media_type="text/event-stream")
 
-    response, _ = model.chat(tokenizer, query, history=history)
+    if stop_words:
+        react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words]
+        response, _ = model.chat(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens)
+        for stop_ in stop_words:
+            if response.endswith(stop_):
+                response = response[:response.find(stop_)]
+    else:
+        response, _ = model.chat(tokenizer, query, history=history)
+
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=response),
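A quick sketch of the trimming loop added in the non-streaming branch: when the generated text ends with a stop word, everything from that stop word onward is dropped. Note that str.find returns the first occurrence, so a response that also contains the stop word earlier is truncated at that earlier position. The example values are hypothetical.

    response = "Thought: I know the answer.\nObservation: done\nObservation:"
    for stop_ in ["\nObservation:"]:
        if response.endswith(stop_):
            response = response[:response.find(stop_)]
    print(response)  # 'Thought: I know the answer.'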
@@ -133,9 +143,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
     return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
 
 
-async def predict(query: str, history: List[List[str]], model_id: str):
+async def predict(query: str, history: List[List[str]], model_id: str, stop_words: List[str]):
     global model, tokenizer
-
+    assert stop_words == [], "in stream format, stop word is output"
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
         delta=DeltaMessage(role="assistant"),
@@ -145,8 +155,13 @@ async def predict(query: str, history: List[List[str]], model_id: str):
     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
 
     current_length = 0
+    if stop_words:
+        react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words]
+        response_generator = model.chat_stream(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens)
+    else:
+        response_generator = model.chat_stream(tokenizer, query, history=history)
 
-    for new_response in model.chat_stream(tokenizer, query, history):
+    for new_response in response_generator:
         if len(new_response) == current_length:
             continue
 
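For reference, a usage sketch exercising the new stop field through the OpenAI-compatible endpoint; the host, port, model name, and payload values are assumptions for illustration, not part of the patch. Since the patch asserts stop_words == [] inside predict, stop words are shown here with the non-streaming path.

    import requests

    # Hypothetical local deployment of the patched server.
    payload = {
        "model": "Qwen-7B-Chat",
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
        "stream": False,              # the new assert rejects stop words when streaming
        "stop": ["\nObservation:"],   # generation halts at (and is trimmed of) this marker
    }
    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
    print(resp.json()["choices"][0]["message"]["content"])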