From 562537e65afa8e3a94a61f74a79e2f948c360dbf Mon Sep 17 00:00:00 2001
From: JustinLin610
Date: Thu, 24 Aug 2023 12:05:19 +0800
Subject: [PATCH] update web demo

---
 README_JA.md |  1 -
 web_demo.py  | 38 +++++++++++++++++++++++++-------------
 2 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/README_JA.md b/README_JA.md
index 828229f..7bcc629 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -34,7 +34,6 @@ Qwen-7B is a large language model series proposed by Alibaba Cloud
 ## News
 
 * 2023.8.21 We released the Int4 quantized model for Qwen-7B-Chat, **Qwen-7B-Chat-Int4**, which lowers memory cost while improving inference speed, with no significant performance degradation in benchmark evaluation.
-* 2023.8.3 We released Qwen-7B and Qwen-7B-Chat on ModelScope and Hugging Face, along with a technical memo covering model details such as training setup and performance.
 
 ## Performance
 
diff --git a/web_demo.py b/web_demo.py
index bd25f7f..e94f6be 100755
--- a/web_demo.py
+++ b/web_demo.py
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 """A simple web interactive chat demo based on gradio."""
-
+import os
 from argparse import ArgumentParser
 
 import gradio as gr
@@ -44,17 +44,29 @@ def _load_model_tokenizer(args):
     else:
         device_map = "auto"
 
-    model = AutoModelForCausalLM.from_pretrained(
-        args.checkpoint_path,
-        device_map=device_map,
-        trust_remote_code=True,
-        resume_download=True,
-    ).eval()
-    model.generation_config = GenerationConfig.from_pretrained(
+    qconfig_path = os.path.join(args.checkpoint_path, 'quantize_config.json')
+    if os.path.exists(qconfig_path):
+        from auto_gptq import AutoGPTQForCausalLM
+        model = AutoGPTQForCausalLM.from_quantized(
+            args.checkpoint_path,
+            device_map=device_map,
+            trust_remote_code=True,
+            resume_download=True,
+            use_safetensors=True,
+        ).eval()
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            args.checkpoint_path,
+            device_map=device_map,
+            trust_remote_code=True,
+            resume_download=True,
+        ).eval()
+
+    config = GenerationConfig.from_pretrained(
         args.checkpoint_path, trust_remote_code=True, resume_download=True,
     )
 
-    return model, tokenizer
+    return model, tokenizer, config
 
 
 def postprocess(self, y):
@@ -103,14 +115,14 @@ def _parse_text(text):
     return text
 
 
-def _launch_demo(args, model, tokenizer):
+def _launch_demo(args, model, tokenizer, config):
     def predict(_query, _chatbot, _task_history):
         print(f"User: {_parse_text(_query)}")
         _chatbot.append((_parse_text(_query), ""))
         full_response = ""
 
-        for response in model.chat_stream(tokenizer, _query, history=_task_history):
+        for response in model.chat_stream(tokenizer, _query, history=_task_history, generation_config=config):
             _chatbot[-1] = (_parse_text(_query), _parse_text(response))
 
             yield _chatbot
 
@@ -183,9 +195,9 @@ including hate speech, violence, pornography, deception, etc. \
 
 def main():
     args = _get_args()
-    model, tokenizer = _load_model_tokenizer(args)
+    model, tokenizer, config = _load_model_tokenizer(args)
 
-    _launch_demo(args, model, tokenizer)
+    _launch_demo(args, model, tokenizer, config)
 
 
 if __name__ == '__main__':
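
Note for reviewers: the new loader branches on the presence of `quantize_config.json`, the marker file AutoGPTQ writes into GPTQ-quantized checkpoints such as Qwen-7B-Chat-Int4, and it now returns the `GenerationConfig` to the caller instead of assigning it to `model.generation_config`. Below is a minimal standalone sketch of that flow; the tokenizer call is assumed from the unshown top of `_load_model_tokenizer`, not part of the diff.

```python
# Sketch of the loading flow this patch introduces; the tokenizer loading is
# an assumption filled in from outside the hunk context.
import os

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


def load_model_tokenizer(checkpoint_path, device_map="auto"):
    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    # quantize_config.json only exists in GPTQ-quantized checkpoints, so its
    # presence selects the AutoGPTQ loader over the plain transformers one.
    qconfig_path = os.path.join(checkpoint_path, "quantize_config.json")
    if os.path.exists(qconfig_path):
        # Imported lazily: only needed when serving an Int4 checkpoint.
        from auto_gptq import AutoGPTQForCausalLM
        model = AutoGPTQForCausalLM.from_quantized(
            checkpoint_path,
            device_map=device_map,
            trust_remote_code=True,
            resume_download=True,
            use_safetensors=True,
        ).eval()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            device_map=device_map,
            trust_remote_code=True,
            resume_download=True,
        ).eval()

    # Returned to the caller rather than written to model.generation_config,
    # which keeps the quantized and unquantized paths uniform.
    config = GenerationConfig.from_pretrained(
        checkpoint_path, trust_remote_code=True, resume_download=True,
    )
    return model, tokenizer, config
```

The lazy `from auto_gptq import ...` keeps `auto_gptq` an optional dependency: users running the unquantized demo never need the package installed.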
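
And a hypothetical end-to-end check of the three-tuple plumbing, outside Gradio. The checkpoint name, the `cpu_only` flag, and the `SimpleNamespace` stand-in for the parsed arguments are all illustrative assumptions; `chat_stream(..., generation_config=config)` is exactly the call the patched `predict` makes.

```python
# Hypothetical smoke test for the patched loader and streaming call; requires
# auto_gptq when pointing at a GPTQ checkpoint like the illustrative one below.
from types import SimpleNamespace

args = SimpleNamespace(checkpoint_path="Qwen/Qwen-7B-Chat-Int4", cpu_only=False)
model, tokenizer, config = _load_model_tokenizer(args)

history = []
partial = ""
# chat_stream yields progressively longer partial responses; predict() in
# web_demo.py overwrites the last chatbot entry with each one.
for response in model.chat_stream(
    tokenizer, "Hello!", history=history, generation_config=config
):
    partial = response
print(partial)
```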