From c73a0658490dd61d44d32afa75f3497e9b64e6b1 Mon Sep 17 00:00:00 2001 From: Wang Peng <798960736@qq.com> Date: Thu, 12 Oct 2023 13:29:06 +0800 Subject: [PATCH] Update README.md, update batch infer --- README.md | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 2b9256d..3890979 100644 --- a/README.md +++ b/README.md @@ -376,15 +376,16 @@ import torch from tokenization_qwen import QWenTokenizer from modeling_qwen import QWenLMHeadModel from transformers import GenerationConfig -from qwen_generation_utils import make_context +from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids tokenizer = QWenTokenizer.from_pretrained('./', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left') model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval() model.generation_config = GenerationConfig.from_pretrained('./') +stop_words_ids = get_stop_words_ids(model.generation_config.chat_format, tokenizer) all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"] -batch_question = [] +batch_raw_text = [] for q in all_raw_text: raw_text, _ = make_context( tokenizer, @@ -393,17 +394,29 @@ for q in all_raw_text: max_window_size=model.generation_config.max_window_size, chat_format=model.generation_config.chat_format, ) - batch_question.append(raw_text) + batch_raw_text.append(raw_text) -batch_input_ids = tokenizer(batch_question, padding='longest') -print(batch_input_ids) - -batch_input_ids1 = torch.LongTensor(batch_input_ids['input_ids']).to(model.device) +batch_input_ids = tokenizer(batch_raw_text, padding='longest') +batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device) batch_out_ids = model.generate( - input_ids=batch_input_ids1 - ,return_dict_in_generate=False + batch_input_ids, + stop_words_ids=stop_words_ids, + return_dict_in_generate=False, + generation_config=model.generation_config ) -batch_response = [tokenizer.decode(o, skip_special_tokens=True) for o in batch_out_ids] +padding_lens = [batch_input_ids[i].eq(tokenizer.pad_token_id).sum().item() for i in range(batch_input_ids.size(0))] + +batch_response = [ + decode_tokens( + batch_out_ids[i][padding_lens[i]:], + tokenizer, + raw_text_len=len(batch_raw_text[i]), + context_length=batch_input_ids[i].size(0), + chat_format="chatml", + verbose=False, + errors='replace' + ) for i in range(len(all_raw_text)) +] print(batch_response) response, _ = model.chat(tokenizer, "我想听你说爱我。", history=None)