add stop word on openai api ChatCompletion

Branch: main
Author: cyente, 1 year ago
Parent: e494489bf1 · Commit: a3a5b3de47

@@ -319,6 +319,7 @@ for chunk in openai.ChatCompletion.create(
         {"role": "user", "content": "你好"}
     ],
     stream=True
+    # Specifying stop words in streaming output format is not yet supported and is under development.
 ):
     if hasattr(chunk.choices[0].delta, "content"):
         print(chunk.choices[0].delta.content, end="", flush=True)
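Since the comment in this hunk says server-side stop words are not yet available in streaming mode, a client can approximate them by accumulating deltas and truncating locally. A minimal sketch under assumptions not in the commit: the helper name `stream_with_stop`, the `api_base`/`api_key` values, and the model id are all hypothetical placeholders for whatever the local openai_api.py server actually serves.

```python
import openai

openai.api_base = "http://localhost:8000/v1"  # assumption: local openai_api.py server
openai.api_key = "none"                        # assumption: server does not check the key

def stream_with_stop(messages, stop_words, model="Qwen-7B"):
    # Hypothetical client-side workaround: accumulate streamed deltas and
    # truncate locally once any stop word appears, since the server cannot yet.
    buffer = ""
    for chunk in openai.ChatCompletion.create(
        model=model,
        messages=messages,
        stream=True,
    ):
        if hasattr(chunk.choices[0].delta, "content"):
            buffer += chunk.choices[0].delta.content
            for stop_ in stop_words:
                if stop_ in buffer:
                    return buffer[:buffer.find(stop_)]
    return buffer

print(stream_with_stop([{"role": "user", "content": "你好"}], ["Observation:"]))
```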
@@ -329,7 +330,8 @@ response = openai.ChatCompletion.create(
     messages=[
         {"role": "user", "content": "你好"}
     ],
-    stream=False
+    stream=False,
+    stop=[] # You can add custom stop words here, e.g., stop=["Observation:"] for ReAct prompting.
 )
 print(response.choices[0].message.content)
 ```
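The `stop=["Observation:"]` suggestion above is the ReAct use case: the model should emit its `Thought:`/`Action:` and then halt so the caller can run the tool and splice in the real observation. A rough sketch of that loop; the `react_step` helper, the prompt format, and the model id are illustrative assumptions, not part of the commit.

```python
import openai

def react_step(prompt, model="Qwen-7B"):
    # One ReAct turn: stop before the model writes its own "Observation:",
    # so the real tool output can be inserted by the caller.
    response = openai.ChatCompletion.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=False,
        stop=["Observation:"],
    )
    return response.choices[0].message.content

prompt = "Question: What is 2 + 2?\nThought:"
step = react_step(prompt)        # generation ends right before "Observation:"
observation = "4"                # stand-in for a real tool call
prompt += step + "Observation: " + observation + "\nThought:"
```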

@@ -323,6 +323,7 @@ for chunk in openai.ChatCompletion.create(
         {"role": "user", "content": "你好"}
     ],
     stream=True
+    # Custom stop words are not yet supported in streaming output and are under development.
 ):
     if hasattr(chunk.choices[0].delta, "content"):
         print(chunk.choices[0].delta.content, end="", flush=True)
@@ -333,7 +334,8 @@ response = openai.ChatCompletion.create(
     messages=[
         {"role": "user", "content": "你好"}
     ],
-    stream=False
+    stream=False,
+    stop=[] # Add custom stop words here; e.g., ReAct prompting needs stop=["Observation:"].
 )
 print(response.choices[0].message.content)
 ```

@@ -68,6 +68,7 @@ class ChatCompletionRequest(BaseModel):
     top_p: Optional[float] = None
     max_length: Optional[int] = None
     stream: Optional[bool] = False
+    stop: Optional[List[str]] = []

 class ChatCompletionResponseChoice(BaseModel):
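With the new field, the request body may now carry `stop`. A minimal sketch of how it parses, trimmed to the relevant fields (Pydantic v2 is assumed, matching the `model_dump_json` call later in this diff):

```python
from typing import List, Optional
from pydantic import BaseModel

class ChatCompletionRequest(BaseModel):  # trimmed to the fields touched by this commit
    stream: Optional[bool] = False
    stop: Optional[List[str]] = []

req = ChatCompletionRequest.model_validate({"stream": False, "stop": ["Observation:"]})
print(req.stop)                       # ['Observation:']
print(ChatCompletionRequest().stop)   # [] - Pydantic copies mutable defaults per instance
```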
@@ -103,7 +104,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if request.messages[-1].role != "user":
         raise HTTPException(status_code=400, detail="Invalid request")
     query = request.messages[-1].content
+    stop_words = request.stop
+    stop_words.extend(list(map(lambda x: x[1:], filter(lambda x: x.startswith("\n"), stop_words))))
     prev_messages = request.messages[:-1]
     # Temporarily, the system role does not work as expected. We advise that you write the setups for role-play in your query.
     # if len(prev_messages) > 0 and prev_messages[0].role == "system":
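Note on the two added lines: the `extend` appends a newline-stripped copy of every stop word that begins with `"\n"`, so matching still fires when the model emits the word without a leading line break. It also mutates `request.stop` in place, which is per-request state and so does not leak across calls. The expansion in isolation:

```python
stop_words = ["\nObservation:", "Final Answer:"]
stop_words.extend(list(map(lambda x: x[1:], filter(lambda x: x.startswith("\n"), stop_words))))
print(stop_words)  # ['\nObservation:', 'Final Answer:', 'Observation:']
```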
@@ -120,10 +122,18 @@ async def create_chat_completion(request: ChatCompletionRequest):
         raise HTTPException(status_code=400, detail="Invalid request.")

     if request.stream:
-        generate = predict(query, history, request.model)
+        generate = predict(query, history, request.model, stop_words)
         return EventSourceResponse(generate, media_type="text/event-stream")

+    if stop_words:
+        react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words]
+        response, _ = model.chat(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens)
+        for stop_ in stop_words:
+            if response.endswith(stop_):
+                response = response[:response.find(stop_)]
-    response, _ = model.chat(tokenizer, query, history=history)
+    else:
+        response, _ = model.chat(tokenizer, query, history=history)
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=response),
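One subtlety in the truncation above: the guard uses `endswith`, but the slice uses `find`, which locates the *first* occurrence. If a stop word also appears earlier in the response, everything from that first occurrence onward is dropped; `rfind` would trim only the trailing match. A small repro:

```python
response = "Observation: cached\nThought: done\nObservation:"
stop_ = "Observation:"
if response.endswith(stop_):
    print(repr(response[:response.find(stop_)]))   # '' - cut at the first match
    print(repr(response[:response.rfind(stop_)]))  # keeps text before the trailing match
```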
@@ -133,9 +143,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
     return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")

-async def predict(query: str, history: List[List[str]], model_id: str):
+async def predict(query: str, history: List[List[str]], model_id: str, stop_words: List[str]):
     global model, tokenizer
+    assert stop_words == [], "Stop words are not yet supported in streaming mode."
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
         delta=DeltaMessage(role="assistant"),
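The `assert` above turns a streaming request that carries stop words into an unhandled `AssertionError`, i.e. a 500 from the server. A gentler alternative, sketched here and not part of the commit, would validate in `create_chat_completion` before the event stream starts, so the client sees a readable error:

```python
# Alternative sketch (not in the commit): reject early in create_chat_completion,
# before the SSE response begins, instead of asserting inside the generator.
if request.stream and stop_words:
    raise HTTPException(status_code=400, detail="stop is not yet supported with stream=True")
```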
@@ -145,8 +155,13 @@ async def predict(query: str, history: List[List[str]], model_id: str):
     yield "{}".format(chunk.model_dump_json(exclude_unset=True))

     current_length = 0
+    if stop_words:
+        react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words]
+        response_generator = model.chat_stream(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens)
+    else:
+        response_generator = model.chat_stream(tokenizer, query, history=history)
-    for new_response in model.chat_stream(tokenizer, query, history):
+    for new_response in response_generator:
         if len(new_response) == current_length:
             continue
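For reference, both branches pass `stop_words_ids` as a list of token-id lists, one entry per stop word, which is the shape the `tokenizer.encode` list comprehension in this diff produces. Roughly (the ids below are made up for illustration; real values depend on the tokenizer):

```python
stop_words = ["Observation:", "\nObservation:"]
react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words]
# e.g. [[37763, 367, 25], [198, 37763, 367, 25]] - a List[List[int]], one row per stop word
```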
