Merge pull request #235 from QwenLM/update_demo_int4

update web demo
main
Junyang Lin 1 year ago committed by GitHub
commit 1b6758e157
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -34,7 +34,6 @@ Qwen-7Bは、アリババクラウドが提唱する大規模言語モデルシ
## ニュース
* 2023.8.21 Qwen-7B-Chat 用 Int4 量子化モデル(**Qwen-7B-Chat-Int4**)をリリースしました。メモリコストは低いが、推論速度は向上している。また、ベンチマーク評価において大きな性能劣化はありません。
* 2023.8.3 Qwen-7B と Qwen-7B-Chat を ModelScope と Hugging Face で公開。また、トレーニングの詳細やモデルの性能など、モデルの詳細についてはテクニカルメモを提供しています。
## パフォーマンス

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
"""A simple web interactive chat demo based on gradio."""
import os
from argparse import ArgumentParser
import gradio as gr
@ -44,17 +44,29 @@ def _load_model_tokenizer(args):
else:
device_map = "auto"
qconfig_path = os.path.join(args.checkpoint_path, 'quantize_config.json')
if os.path.exists(qconfig_path):
from auto_gptq import AutoGPTQForCausalLM
model = AutoGPTQForCausalLM.from_quantized(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
resume_download=True,
use_safetensors=True,
).eval()
else:
model = AutoModelForCausalLM.from_pretrained(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
resume_download=True,
).eval()
model.generation_config = GenerationConfig.from_pretrained(
config = GenerationConfig.from_pretrained(
args.checkpoint_path, trust_remote_code=True, resume_download=True,
)
return model, tokenizer
return model, tokenizer, config
def postprocess(self, y):
@ -103,14 +115,14 @@ def _parse_text(text):
return text
def _launch_demo(args, model, tokenizer):
def _launch_demo(args, model, tokenizer, config):
def predict(_query, _chatbot, _task_history):
print(f"User: {_parse_text(_query)}")
_chatbot.append((_parse_text(_query), ""))
full_response = ""
for response in model.chat_stream(tokenizer, _query, history=_task_history):
for response in model.chat_stream(tokenizer, _query, history=_task_history, generation_config=config):
_chatbot[-1] = (_parse_text(_query), _parse_text(response))
yield _chatbot
@ -183,9 +195,9 @@ including hate speech, violence, pornography, deception, etc. \
def main():
args = _get_args()
model, tokenizer = _load_model_tokenizer(args)
model, tokenizer, config = _load_model_tokenizer(args)
_launch_demo(args, model, tokenizer)
_launch_demo(args, model, tokenizer, config)
if __name__ == '__main__':

Loading…
Cancel
Save