Merge pull request #1038 from tuhahaha/main

update openai_api: support stop words for streaming chat
main
Jianxin Ma 12 months ago committed by GitHub
commit b792917925
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -483,20 +483,19 @@ async def predict(
current_length = 0 current_length = 0
stop_words_ids = [tokenizer.encode(s) stop_words_ids = [tokenizer.encode(s)
for s in stop_words] if stop_words else None for s in stop_words] if stop_words else None
if stop_words:
# TODO: It's a little bit tricky to trim stop words in the stream mode. delay_token_num = max([len(x) for x in stop_words])
raise HTTPException(
status_code=400,
detail=
'Invalid request: custom stop words are not yet supported for stream mode.',
)
response_generator = model.chat_stream(tokenizer, response_generator = model.chat_stream(tokenizer,
query, query,
history=history, history=history,
stop_words_ids=stop_words_ids, stop_words_ids=stop_words_ids,
system=system, system=system,
**gen_kwargs) **gen_kwargs)
for new_response in response_generator: for _new_response in response_generator:
if len(_new_response) <= delay_token_num:
continue
new_response = _new_response[:-delay_token_num]
if len(new_response) == current_length: if len(new_response) == current_length:
continue continue
@ -509,6 +508,18 @@ async def predict(
choices=[choice_data], choices=[choice_data],
object='chat.completion.chunk') object='chat.completion.chunk')
yield '{}'.format(_dump_json(chunk, exclude_unset=True)) yield '{}'.format(_dump_json(chunk, exclude_unset=True))
if current_length != len(_new_response):
# Determine whether to print the delay tokens
delayed_text = _new_response[current_length:]
new_text = trim_stop_words(delayed_text, stop_words)
if len(new_text) > 0:
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=DeltaMessage(content=new_text), finish_reason=None)
chunk = ChatCompletionResponse(model=model_id,
choices=[choice_data],
object='chat.completion.chunk')
yield '{}'.format(_dump_json(chunk, exclude_unset=True))
choice_data = ChatCompletionResponseStreamChoice(index=0, choice_data = ChatCompletionResponseStreamChoice(index=0,
delta=DeltaMessage(), delta=DeltaMessage(),

Loading…
Cancel
Save