Merge pull request #235 from QwenLM/update_demo_int4

update web demo
main
Junyang Lin 1 year ago committed by GitHub
commit 1b6758e157
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -34,7 +34,6 @@ Qwen-7Bは、アリババクラウドが提唱する大規模言語モデルシ
## ニュース ## ニュース
* 2023.8.21 Qwen-7B-Chat 用 Int4 量子化モデル(**Qwen-7B-Chat-Int4**)をリリースしました。メモリコストは低いが、推論速度は向上している。また、ベンチマーク評価において大きな性能劣化はありません。 * 2023.8.21 Qwen-7B-Chat 用 Int4 量子化モデル(**Qwen-7B-Chat-Int4**)をリリースしました。メモリコストは低いが、推論速度は向上している。また、ベンチマーク評価において大きな性能劣化はありません。
* 2023.8.3 Qwen-7B と Qwen-7B-Chat を ModelScope と Hugging Face で公開。また、トレーニングの詳細やモデルの性能など、モデルの詳細についてはテクニカルメモを提供しています。 * 2023.8.3 Qwen-7B と Qwen-7B-Chat を ModelScope と Hugging Face で公開。また、トレーニングの詳細やモデルの性能など、モデルの詳細についてはテクニカルメモを提供しています。
## パフォーマンス ## パフォーマンス

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
"""A simple web interactive chat demo based on gradio.""" """A simple web interactive chat demo based on gradio."""
import os
from argparse import ArgumentParser from argparse import ArgumentParser
import gradio as gr import gradio as gr
@ -44,17 +44,29 @@ def _load_model_tokenizer(args):
else: else:
device_map = "auto" device_map = "auto"
model = AutoModelForCausalLM.from_pretrained( qconfig_path = os.path.join(args.checkpoint_path, 'quantize_config.json')
args.checkpoint_path, if os.path.exists(qconfig_path):
device_map=device_map, from auto_gptq import AutoGPTQForCausalLM
trust_remote_code=True, model = AutoGPTQForCausalLM.from_quantized(
resume_download=True, args.checkpoint_path,
).eval() device_map=device_map,
model.generation_config = GenerationConfig.from_pretrained( trust_remote_code=True,
resume_download=True,
use_safetensors=True,
).eval()
else:
model = AutoModelForCausalLM.from_pretrained(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
resume_download=True,
).eval()
config = GenerationConfig.from_pretrained(
args.checkpoint_path, trust_remote_code=True, resume_download=True, args.checkpoint_path, trust_remote_code=True, resume_download=True,
) )
return model, tokenizer return model, tokenizer, config
def postprocess(self, y): def postprocess(self, y):
@ -103,14 +115,14 @@ def _parse_text(text):
return text return text
def _launch_demo(args, model, tokenizer): def _launch_demo(args, model, tokenizer, config):
def predict(_query, _chatbot, _task_history): def predict(_query, _chatbot, _task_history):
print(f"User: {_parse_text(_query)}") print(f"User: {_parse_text(_query)}")
_chatbot.append((_parse_text(_query), "")) _chatbot.append((_parse_text(_query), ""))
full_response = "" full_response = ""
for response in model.chat_stream(tokenizer, _query, history=_task_history): for response in model.chat_stream(tokenizer, _query, history=_task_history, generation_config=config):
_chatbot[-1] = (_parse_text(_query), _parse_text(response)) _chatbot[-1] = (_parse_text(_query), _parse_text(response))
yield _chatbot yield _chatbot
@ -183,9 +195,9 @@ including hate speech, violence, pornography, deception, etc. \
def main(): def main():
args = _get_args() args = _get_args()
model, tokenizer = _load_model_tokenizer(args) model, tokenizer, config = _load_model_tokenizer(args)
_launch_demo(args, model, tokenizer) _launch_demo(args, model, tokenizer, config)
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save