|
|
|
@ -321,7 +321,7 @@ def parse_response(response):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# completion mode, not chat mode
|
|
|
|
|
def text_complete_last_message(history, stop_words_ids):
|
|
|
|
|
def text_complete_last_message(history, stop_words_ids, gen_kwargs):
|
|
|
|
|
im_start = "<|im_start|>"
|
|
|
|
|
im_end = "<|im_end|>"
|
|
|
|
|
prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
|
|
|
|
@ -339,7 +339,7 @@ def text_complete_last_message(history, stop_words_ids):
|
|
|
|
|
stop_words_ids = _stop_words_ids
|
|
|
|
|
|
|
|
|
|
input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
|
|
|
|
|
output = model.generate(input_ids, stop_words_ids=stop_words_ids).tolist()[0]
|
|
|
|
|
output = model.generate(input_ids, stop_words_ids=stop_words_ids, **gen_kwargs).tolist()[0]
|
|
|
|
|
output = tokenizer.decode(output, errors="ignore")
|
|
|
|
|
assert output.startswith(prompt)
|
|
|
|
|
output = output[len(prompt) :]
|
|
|
|
@ -352,6 +352,16 @@ def text_complete_last_message(history, stop_words_ids):
|
|
|
|
|
async def create_chat_completion(request: ChatCompletionRequest):
|
|
|
|
|
global model, tokenizer
|
|
|
|
|
|
|
|
|
|
gen_kwargs = {}
|
|
|
|
|
if request.temperature is not None:
|
|
|
|
|
if request.temperature < 0.01:
|
|
|
|
|
gen_kwargs['top_k'] = 1 # greedy decoding
|
|
|
|
|
else:
|
|
|
|
|
# Not recommended. Please tune top_p instead.
|
|
|
|
|
gen_kwargs['temperature'] = request.temperature
|
|
|
|
|
if request.top_p is not None:
|
|
|
|
|
gen_kwargs['top_p'] = request.top_p
|
|
|
|
|
|
|
|
|
|
stop_words = add_extra_stop_words(request.stop)
|
|
|
|
|
if request.functions:
|
|
|
|
|
stop_words = stop_words or []
|
|
|
|
@ -366,12 +376,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|
|
|
|
status_code=400,
|
|
|
|
|
detail="Invalid request: Function calling is not yet implemented for stream mode.",
|
|
|
|
|
)
|
|
|
|
|
generate = predict(query, history, request.model, stop_words)
|
|
|
|
|
generate = predict(query, history, request.model, stop_words, gen_kwargs)
|
|
|
|
|
return EventSourceResponse(generate, media_type="text/event-stream")
|
|
|
|
|
|
|
|
|
|
stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
|
|
|
|
|
if query is _TEXT_COMPLETION_CMD:
|
|
|
|
|
response = text_complete_last_message(history, stop_words_ids=stop_words_ids)
|
|
|
|
|
response = text_complete_last_message(history, stop_words_ids=stop_words_ids, gen_kwargs=gen_kwargs)
|
|
|
|
|
else:
|
|
|
|
|
response, _ = model.chat(
|
|
|
|
|
tokenizer,
|
|
|
|
@ -379,6 +389,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|
|
|
|
history=history,
|
|
|
|
|
stop_words_ids=stop_words_ids,
|
|
|
|
|
append_history=False,
|
|
|
|
|
**gen_kwargs
|
|
|
|
|
)
|
|
|
|
|
print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
|
|
|
|
|
response = trim_stop_words(response, stop_words)
|
|
|
|
@ -396,7 +407,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def predict(
|
|
|
|
|
query: str, history: List[List[str]], model_id: str, stop_words: List[str]
|
|
|
|
|
query: str, history: List[List[str]], model_id: str, stop_words: List[str], gen_kwargs: Dict,
|
|
|
|
|
):
|
|
|
|
|
global model, tokenizer
|
|
|
|
|
choice_data = ChatCompletionResponseStreamChoice(
|
|
|
|
@ -416,7 +427,7 @@ async def predict(
|
|
|
|
|
detail="Invalid request: custom stop words are not yet supported for stream mode.",
|
|
|
|
|
)
|
|
|
|
|
response_generator = model.chat_stream(
|
|
|
|
|
tokenizer, query, history=history, stop_words_ids=stop_words_ids
|
|
|
|
|
tokenizer, query, history=history, stop_words_ids=stop_words_ids, **gen_kwargs
|
|
|
|
|
)
|
|
|
|
|
for new_response in response_generator:
|
|
|
|
|
if len(new_response) == current_length:
|
|
|
|
|