From 0a020ea5eed6018c8897d982c6fdebc12a339f66 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=85=BC=E6=AC=A3?=
Date: Mon, 25 Sep 2023 00:37:00 +0800
Subject: [PATCH] openai_api: support temperature=0

---
 openai_api.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/openai_api.py b/openai_api.py
index f6683f8..814cff9 100644
--- a/openai_api.py
+++ b/openai_api.py
@@ -321,7 +321,7 @@ def parse_response(response):
 
 
 # completion mode, not chat mode
-def text_complete_last_message(history, stop_words_ids):
+def text_complete_last_message(history, stop_words_ids, gen_kwargs):
     im_start = "<|im_start|>"
     im_end = "<|im_end|>"
     prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
@@ -339,7 +339,7 @@ def text_complete_last_message(history, stop_words_ids):
         stop_words_ids = _stop_words_ids
 
     input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
-    output = model.generate(input_ids, stop_words_ids=stop_words_ids).tolist()[0]
+    output = model.generate(input_ids, stop_words_ids=stop_words_ids, **gen_kwargs).tolist()[0]
     output = tokenizer.decode(output, errors="ignore")
     assert output.startswith(prompt)
     output = output[len(prompt) :]
@@ -352,6 +352,16 @@ def text_complete_last_message(history, stop_words_ids):
 async def create_chat_completion(request: ChatCompletionRequest):
     global model, tokenizer
 
+    gen_kwargs = {}
+    if request.temperature is not None:
+        if request.temperature < 0.01:
+            gen_kwargs['top_k'] = 1  # greedy decoding
+        else:
+            # Not recommended. Please tune top_p instead.
+            gen_kwargs['temperature'] = request.temperature
+    if request.top_p is not None:
+        gen_kwargs['top_p'] = request.top_p
+
     stop_words = add_extra_stop_words(request.stop)
     if request.functions:
         stop_words = stop_words or []
@@ -366,12 +376,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
                 status_code=400,
                 detail="Invalid request: Function calling is not yet implemented for stream mode.",
             )
-        generate = predict(query, history, request.model, stop_words)
+        generate = predict(query, history, request.model, stop_words, gen_kwargs)
         return EventSourceResponse(generate, media_type="text/event-stream")
 
     stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
     if query is _TEXT_COMPLETION_CMD:
-        response = text_complete_last_message(history, stop_words_ids=stop_words_ids)
+        response = text_complete_last_message(history, stop_words_ids=stop_words_ids, gen_kwargs=gen_kwargs)
     else:
         response, _ = model.chat(
             tokenizer,
@@ -379,6 +389,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
             history=history,
             stop_words_ids=stop_words_ids,
             append_history=False,
+            **gen_kwargs
         )
     print(f"\n{history}\n{query}\n\n{response}\n")
     response = trim_stop_words(response, stop_words)
@@ -396,7 +407,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
 
 async def predict(
-    query: str, history: List[List[str]], model_id: str, stop_words: List[str]
+    query: str, history: List[List[str]], model_id: str, stop_words: List[str], gen_kwargs: Dict,
 ):
     global model, tokenizer
     choice_data = ChatCompletionResponseStreamChoice(
@@ -416,7 +427,7 @@ async def predict(
             detail="Invalid request: custom stop words are not yet supported for stream mode.",
         )
     response_generator = model.chat_stream(
-        tokenizer, query, history=history, stop_words_ids=stop_words_ids
+        tokenizer, query, history=history, stop_words_ids=stop_words_ids, **gen_kwargs
     )
     for new_response in response_generator:
         if len(new_response) == current_length:
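
Note: with this patch applied, a request with temperature=0 (or anything below
0.01) is mapped to top_k=1 on the server side, i.e. greedy decoding, so
repeated calls should return identical completions. A minimal client-side
sketch, assuming the server from this file is listening at
http://localhost:8000/v1 (the address is an assumption) and the pre-1.0
openai Python package is installed:

    import openai

    # Point the legacy (pre-1.0) openai client at the local server
    # (assumed address; adjust to wherever openai_api.py is serving).
    openai.api_base = "http://localhost:8000/v1"
    openai.api_key = "none"  # placeholder; key validation is assumed to be off

    # temperature=0 now hits the new top_k=1 branch (greedy decoding),
    # so running this twice should print the same completion both times.
    response = openai.ChatCompletion.create(
        model="Qwen",  # model name is illustrative
        messages=[{"role": "user", "content": "Briefly introduce yourself."}],
        temperature=0,
    )
    print(response.choices[0].message.content)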