From a7be70ab412e8141fc60a35213423483130eb54e Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 31 Jan 2024 19:47:27 +0800
Subject: [PATCH] update openai_api: support stop words for streaming chat

---
 openai_api.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/openai_api.py b/openai_api.py
index fd8e635..0e127c6 100644
--- a/openai_api.py
+++ b/openai_api.py
@@ -483,20 +483,19 @@ async def predict(
     current_length = 0
     stop_words_ids = [tokenizer.encode(s)
                       for s in stop_words] if stop_words else None
-    if stop_words:
-        # TODO: It's a little bit tricky to trim stop words in the stream mode.
-        raise HTTPException(
-            status_code=400,
-            detail=
-            'Invalid request: custom stop words are not yet supported for stream mode.',
-        )
+
+    delay_token_num = max([len(x) for x in stop_words]) if stop_words else 0
     response_generator = model.chat_stream(tokenizer,
                                            query,
                                            history=history,
                                            stop_words_ids=stop_words_ids,
                                            system=system,
                                            **gen_kwargs)
-    for new_response in response_generator:
+    for _new_response in response_generator:
+        if len(_new_response) <= delay_token_num:
+            continue
+        new_response = _new_response[:-delay_token_num] if delay_token_num else _new_response
+
         if len(new_response) == current_length:
             continue
 
@@ -509,6 +508,18 @@ async def predict(
             choices=[choice_data],
             object='chat.completion.chunk')
         yield '{}'.format(_dump_json(chunk, exclude_unset=True))
+
+    if current_length != len(_new_response):
+        # Flush the delayed tail, trimming any stop word it may contain
+        delayed_text = _new_response[current_length:]
+        new_text = trim_stop_words(delayed_text, stop_words)
+        if len(new_text) > 0:
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=0, delta=DeltaMessage(content=new_text), finish_reason=None)
+            chunk = ChatCompletionResponse(model=model_id,
+                                           choices=[choice_data],
+                                           object='chat.completion.chunk')
+            yield '{}'.format(_dump_json(chunk, exclude_unset=True))
 
     choice_data = ChatCompletionResponseStreamChoice(index=0,
                                                      delta=DeltaMessage(),
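
For reference, the change implements a delayed-emission strategy: the stream
holds back as many characters as the longest stop word, so a stop word that is
still being generated is never partially sent to the client, and the held-back
tail is trimmed and flushed once generation finishes. The sketch below is a
minimal, self-contained illustration of that strategy, not the repository's
code: the cumulative `chunks` list stands in for `model.chat_stream` output,
and this `trim_stop_words` is an assumed re-implementation of the helper of
the same name that openai_api.py already calls.

```python
from typing import Iterable, Iterator, List


def trim_stop_words(text: str, stop_words: List[str]) -> str:
    # Assumed behavior: cut the text at the first occurrence of any stop word.
    for stop in stop_words:
        idx = text.find(stop)
        if idx != -1:
            text = text[:idx]
    return text


def stream_with_stop_words(cumulative_chunks: Iterable[str],
                           stop_words: List[str]) -> Iterator[str]:
    """Yield text deltas without ever emitting a complete or partial stop word.

    `cumulative_chunks` yields ever-growing responses, the same shape that
    model.chat_stream produces.
    """
    # Hold back as many characters as the longest stop word, so a stop word
    # that is still being generated can never leak into an emitted delta.
    delay_num = max((len(s) for s in stop_words), default=0)
    emitted = 0
    full = ''
    for full in cumulative_chunks:
        visible = full[:-delay_num] if delay_num else full
        if len(visible) <= emitted:
            continue
        yield visible[emitted:]
        emitted = len(visible)
    # Flush the held-back tail, minus any stop word it may contain.
    tail = trim_stop_words(full[emitted:], stop_words)
    if tail:
        yield tail


if __name__ == '__main__':
    # Cumulative outputs, as a chat_stream-style generator would return them.
    chunks = ['Hel', 'Hello wo', 'Hello world!<|im_end|>']
    print(repr(''.join(stream_with_stop_words(chunks, ['<|im_end|>']))))
    # prints 'Hello world!' -- the stop word is held back and trimmed
```

The final flush mirrors the post-loop block in the patch: the loop by
construction never emits the last few held-back characters, so whatever
remains after the last chunk is emitted only once it is known not to be (part
of) a stop word. Note that the hold-back is measured in characters of the
decoded text, matching the patch's `len(x)` over stop-word strings rather
than token counts, despite the name `delay_token_num`.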