openai_api: support temperature=0

main
兼欣 2 years ago
parent fb3180d8f0
commit 0a020ea5ee
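OpenAI's API treats temperature=0 as a request for deterministic output, but Hugging Face's sampler requires a strictly positive temperature, so the value cannot simply be forwarded. This commit collects per-request sampling parameters into a gen_kwargs dict, maps near-zero temperatures to greedy decoding (top_k=1), and threads gen_kwargs through all three generation paths: text completion, non-streaming chat, and streaming chat.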

@@ -321,7 +321,7 @@ def parse_response(response):
 # completion mode, not chat mode
-def text_complete_last_message(history, stop_words_ids):
+def text_complete_last_message(history, stop_words_ids, gen_kwargs):
     im_start = "<|im_start|>"
     im_end = "<|im_end|>"
     prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
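For context, the prompt this function assembles follows Qwen's ChatML convention. A hypothetical one-turn history would render roughly as below; the exact turn-joining whitespace comes from the rest of the function body, which this hunk truncates, so it is an assumption here:

```python
# Rough rendering of the ChatML completion prompt built above for a single
# pending user turn (illustrative only; the joining whitespace is assumed).
prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
    "\n<|im_start|>user\nTell me a joke.<|im_end|>"
    "\n<|im_start|>assistant\n"
)
```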
@@ -339,7 +339,7 @@ def text_complete_last_message(history, stop_words_ids):
         stop_words_ids = _stop_words_ids
     input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
-    output = model.generate(input_ids, stop_words_ids=stop_words_ids).tolist()[0]
+    output = model.generate(input_ids, stop_words_ids=stop_words_ids, **gen_kwargs).tolist()[0]
     output = tokenizer.decode(output, errors="ignore")
     assert output.startswith(prompt)
     output = output[len(prompt) :]
@@ -352,6 +352,16 @@ def text_complete_last_message(history, stop_words_ids):
 async def create_chat_completion(request: ChatCompletionRequest):
     global model, tokenizer
+    gen_kwargs = {}
+    if request.temperature is not None:
+        if request.temperature < 0.01:
+            gen_kwargs['top_k'] = 1  # greedy decoding
+        else:
+            # Not recommended. Please tune top_p instead.
+            gen_kwargs['temperature'] = request.temperature
+    if request.top_p is not None:
+        gen_kwargs['top_p'] = request.top_p
     stop_words = add_extra_stop_words(request.stop)
     if request.functions:
         stop_words = stop_words or []
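Pulled out of the request handler, the new mapping amounts to the standalone sketch below (build_gen_kwargs is a hypothetical name for illustration, not part of the commit). The 0.01 threshold treats any near-zero temperature as a request for determinism: top_k=1 keeps only the single most probable token at each step, i.e. greedy decoding.

```python
# Standalone sketch of the parameter mapping added above.
# build_gen_kwargs is a hypothetical helper name, not part of the commit.
def build_gen_kwargs(temperature=None, top_p=None):
    gen_kwargs = {}
    if temperature is not None:
        if temperature < 0.01:
            gen_kwargs["top_k"] = 1  # greedy decoding
        else:
            # Not recommended. Please tune top_p instead.
            gen_kwargs["temperature"] = temperature
    if top_p is not None:
        gen_kwargs["top_p"] = top_p
    return gen_kwargs

# The three cases: temperature=0 becomes greedy decoding, a positive
# temperature passes through alongside top_p, and omitted params add nothing.
assert build_gen_kwargs(temperature=0.0) == {"top_k": 1}
assert build_gen_kwargs(temperature=0.7, top_p=0.9) == {"temperature": 0.7, "top_p": 0.9}
assert build_gen_kwargs() == {}
```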
@@ -366,12 +376,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
                 status_code=400,
                 detail="Invalid request: Function calling is not yet implemented for stream mode.",
             )
-        generate = predict(query, history, request.model, stop_words)
+        generate = predict(query, history, request.model, stop_words, gen_kwargs)
         return EventSourceResponse(generate, media_type="text/event-stream")
     stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
     if query is _TEXT_COMPLETION_CMD:
-        response = text_complete_last_message(history, stop_words_ids=stop_words_ids)
+        response = text_complete_last_message(history, stop_words_ids=stop_words_ids, gen_kwargs=gen_kwargs)
     else:
         response, _ = model.chat(
             tokenizer,
@@ -379,6 +389,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
             history=history,
             stop_words_ids=stop_words_ids,
             append_history=False,
+            **gen_kwargs
         )
     print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
     response = trim_stop_words(response, stop_words)
@@ -396,7 +407,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 async def predict(
-    query: str, history: List[List[str]], model_id: str, stop_words: List[str]
+    query: str, history: List[List[str]], model_id: str, stop_words: List[str], gen_kwargs: Dict,
 ):
     global model, tokenizer
     choice_data = ChatCompletionResponseStreamChoice(
@@ -416,7 +427,7 @@ async def predict(
             detail="Invalid request: custom stop words are not yet supported for stream mode.",
         )
     response_generator = model.chat_stream(
-        tokenizer, query, history=history, stop_words_ids=stop_words_ids
+        tokenizer, query, history=history, stop_words_ids=stop_words_ids, **gen_kwargs
     )
     for new_response in response_generator:
         if len(new_response) == current_length:
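With gen_kwargs threaded through every path, a client can now ask the OpenAI-compatible endpoint for deterministic output. A minimal sketch using the legacy openai<1.0 client, matching the era of this commit; the base URL, API key, and model name are assumptions for illustration, not taken from the diff:

```python
# Minimal client sketch (assumed base URL, API key, and model name).
import openai

openai.api_base = "http://localhost:8000/v1"
openai.api_key = "none"  # placeholder; assumed not validated by this server

# temperature=0 now maps to greedy decoding (top_k=1) server-side, so
# repeated calls with the same prompt should return the same text.
response = openai.ChatCompletion.create(
    model="Qwen",
    messages=[{"role": "user", "content": "Say hello."}],
    temperature=0,
)
print(response.choices[0].message.content)
```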
