From a3a5b3de4774888e2f605129806226405070b44e Mon Sep 17 00:00:00 2001 From: cyente Date: Wed, 23 Aug 2023 16:30:53 +0800 Subject: [PATCH] add stop word on openai api ChatCompletion --- README.md | 6 ++++-- README_CN.md | 4 +++- openai_api.py | 27 +++++++++++++++++++++------ 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 1205e3b..63cc127 100644 --- a/README.md +++ b/README.md @@ -318,7 +318,8 @@ for chunk in openai.ChatCompletion.create( messages=[ {"role": "user", "content": "你好"} ], - stream=True + stream=True + # Specifying stop words in streaming output format is not yet supported and is under development. ): if hasattr(chunk.choices[0].delta, "content"): print(chunk.choices[0].delta.content, end="", flush=True) @@ -329,7 +330,8 @@ response = openai.ChatCompletion.create( messages=[ {"role": "user", "content": "你好"} ], - stream=False + stream=False, + stop=[] # You can add custom stop words here, e.g., stop=["Observation:"] for ReAct prompting. ) print(response.choices[0].message.content) ``` diff --git a/README_CN.md b/README_CN.md index 811d5e2..4764fe4 100644 --- a/README_CN.md +++ b/README_CN.md @@ -323,6 +323,7 @@ for chunk in openai.ChatCompletion.create( {"role": "user", "content": "你好"} ], stream=True + # 流式输出的自定义stopwords功能尚未支持,正在开发中 ): if hasattr(chunk.choices[0].delta, "content"): print(chunk.choices[0].delta.content, end="", flush=True) @@ -333,7 +334,8 @@ response = openai.ChatCompletion.create( messages=[ {"role": "user", "content": "你好"} ], - stream=False + stream=False, + stop=[] # 在此处添加自定义的stop words 例如ReAct prompting时需要增加: stop=["Observation:"]。 ) print(response.choices[0].message.content) ``` diff --git a/openai_api.py b/openai_api.py index da105f3..52da00b 100644 --- a/openai_api.py +++ b/openai_api.py @@ -68,6 +68,7 @@ class ChatCompletionRequest(BaseModel): top_p: Optional[float] = None max_length: Optional[int] = None stream: Optional[bool] = False + stop: Optional[List[str]] = [] class ChatCompletionResponseChoice(BaseModel): @@ -103,7 +104,8 @@ async def create_chat_completion(request: ChatCompletionRequest): if request.messages[-1].role != "user": raise HTTPException(status_code=400, detail="Invalid request") query = request.messages[-1].content - + stop_words = request.stop + stop_words.extend(list(map(lambda x: x[1:], filter(lambda x: x.startswith("\n"), stop_words)))) prev_messages = request.messages[:-1] # Temporarily, the system role does not work as expected. We advise that you write the setups for role-play in your query. # if len(prev_messages) > 0 and prev_messages[0].role == "system": @@ -120,10 +122,18 @@ async def create_chat_completion(request: ChatCompletionRequest): raise HTTPException(status_code=400, detail="Invalid request.") if request.stream: - generate = predict(query, history, request.model) + generate = predict(query, history, request.model, stop_words) return EventSourceResponse(generate, media_type="text/event-stream") - response, _ = model.chat(tokenizer, query, history=history) + if stop_words: + react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words] + response, _ = model.chat(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens) + for stop_ in stop_words: + if response.endswith(stop_): + response = response[:response.find(stop_)] + else: + response, _ = model.chat(tokenizer, query, history=history) + choice_data = ChatCompletionResponseChoice( index=0, message=ChatMessage(role="assistant", content=response), @@ -133,9 +143,9 @@ async def create_chat_completion(request: ChatCompletionRequest): return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion") -async def predict(query: str, history: List[List[str]], model_id: str): +async def predict(query: str, history: List[List[str]], model_id: str, stop_words: List[str]): global model, tokenizer - + assert stop_words == [], "in stream format, stop word is output" choice_data = ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage(role="assistant"), @@ -145,8 +155,13 @@ async def predict(query: str, history: List[List[str]], model_id: str): yield "{}".format(chunk.model_dump_json(exclude_unset=True)) current_length = 0 + if stop_words: + react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words] + response_generator = model.chat_stream(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens) + else: + response_generator = model.chat_stream(tokenizer, query, history=history) - for new_response in model.chat_stream(tokenizer, query, history): + for new_response in response_generator: if len(new_response) == current_length: continue