|
|
|
@ -4,7 +4,7 @@
|
|
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
"""A simple web interactive chat demo based on gradio."""
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
from argparse import ArgumentParser
|
|
|
|
|
|
|
|
|
|
import gradio as gr
|
|
|
|
@ -44,17 +44,29 @@ def _load_model_tokenizer(args):
|
|
|
|
|
else:
|
|
|
|
|
device_map = "auto"
|
|
|
|
|
|
|
|
|
|
qconfig_path = os.path.join(args.checkpoint_path, 'quantize_config.json')
|
|
|
|
|
if os.path.exists(qconfig_path):
|
|
|
|
|
from auto_gptq import AutoGPTQForCausalLM
|
|
|
|
|
model = AutoGPTQForCausalLM.from_quantized(
|
|
|
|
|
args.checkpoint_path,
|
|
|
|
|
device_map=device_map,
|
|
|
|
|
trust_remote_code=True,
|
|
|
|
|
resume_download=True,
|
|
|
|
|
use_safetensors=True,
|
|
|
|
|
).eval()
|
|
|
|
|
else:
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
args.checkpoint_path,
|
|
|
|
|
device_map=device_map,
|
|
|
|
|
trust_remote_code=True,
|
|
|
|
|
resume_download=True,
|
|
|
|
|
).eval()
|
|
|
|
|
model.generation_config = GenerationConfig.from_pretrained(
|
|
|
|
|
|
|
|
|
|
config = GenerationConfig.from_pretrained(
|
|
|
|
|
args.checkpoint_path, trust_remote_code=True, resume_download=True,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return model, tokenizer
|
|
|
|
|
return model, tokenizer, config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def postprocess(self, y):
|
|
|
|
@ -103,14 +115,14 @@ def _parse_text(text):
|
|
|
|
|
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _launch_demo(args, model, tokenizer):
|
|
|
|
|
def _launch_demo(args, model, tokenizer, config):
|
|
|
|
|
|
|
|
|
|
def predict(_query, _chatbot, _task_history):
|
|
|
|
|
print(f"User: {_parse_text(_query)}")
|
|
|
|
|
_chatbot.append((_parse_text(_query), ""))
|
|
|
|
|
full_response = ""
|
|
|
|
|
|
|
|
|
|
for response in model.chat_stream(tokenizer, _query, history=_task_history):
|
|
|
|
|
for response in model.chat_stream(tokenizer, _query, history=_task_history, generation_config=config):
|
|
|
|
|
_chatbot[-1] = (_parse_text(_query), _parse_text(response))
|
|
|
|
|
|
|
|
|
|
yield _chatbot
|
|
|
|
@ -183,9 +195,9 @@ including hate speech, violence, pornography, deception, etc. \
|
|
|
|
|
def main():
|
|
|
|
|
args = _get_args()
|
|
|
|
|
|
|
|
|
|
model, tokenizer = _load_model_tokenizer(args)
|
|
|
|
|
model, tokenizer, config = _load_model_tokenizer(args)
|
|
|
|
|
|
|
|
|
|
_launch_demo(args, model, tokenizer)
|
|
|
|
|
_launch_demo(args, model, tokenizer, config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|