Update README.md: update batch inference example

main
Wang Peng authored 1 year ago; committed by GitHub
parent 96a5789238
commit c73a065849

@@ -376,15 +376,16 @@ import torch
 from tokenization_qwen import QWenTokenizer
 from modeling_qwen import QWenLMHeadModel
 from transformers import GenerationConfig
-from qwen_generation_utils import make_context
+from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids
 tokenizer = QWenTokenizer.from_pretrained('./', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left')
 model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval()
 model.generation_config = GenerationConfig.from_pretrained('./')
+stop_words_ids = get_stop_words_ids(model.generation_config.chat_format, tokenizer)
 all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
-batch_question = []
+batch_raw_text = []
 for q in all_raw_text:
     raw_text, _ = make_context(
         tokenizer,
@@ -393,17 +394,29 @@ for q in all_raw_text:
         max_window_size=model.generation_config.max_window_size,
         chat_format=model.generation_config.chat_format,
     )
-    batch_question.append(raw_text)
-batch_input_ids = tokenizer(batch_question, padding='longest')
-print(batch_input_ids)
-batch_input_ids1 = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
+    batch_raw_text.append(raw_text)
+batch_input_ids = tokenizer(batch_raw_text, padding='longest')
+batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
 batch_out_ids = model.generate(
-    input_ids=batch_input_ids1
-    ,return_dict_in_generate=False
+    batch_input_ids,
+    stop_words_ids=stop_words_ids,
+    return_dict_in_generate=False,
+    generation_config=model.generation_config
 )
-batch_response = [tokenizer.decode(o, skip_special_tokens=True) for o in batch_out_ids]
+padding_lens = [batch_input_ids[i].eq(tokenizer.pad_token_id).sum().item() for i in range(batch_input_ids.size(0))]
+batch_response = [
+    decode_tokens(
+        batch_out_ids[i][padding_lens[i]:],
+        tokenizer,
+        raw_text_len=len(batch_raw_text[i]),
+        context_length=(batch_input_ids[i].size(0) - padding_lens[i]),
+        chat_format="chatml",
+        verbose=False,
+        errors='replace'
+    ) for i in range(len(all_raw_text))
+]
 print(batch_response)
 response, _ = model.chat(tokenizer, "我想听你说爱我。", history=None)
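
For readers who want to run the result directly, the post-commit example assembles into the self-contained sketch below. It is not the verbatim README: it assumes the Qwen checkpoint (with its tokenization_qwen.py, modeling_qwen.py, qwen_generation_utils.py, weights and generation_config.json) sits in the working directory, and the make_context arguments elided between the two hunks are filled in here from that function's signature, with history and system prompt left at their defaults.

```python
import torch
from tokenization_qwen import QWenTokenizer
from modeling_qwen import QWenLMHeadModel
from transformers import GenerationConfig
from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids

# Left padding keeps every prompt flush against the start of generation,
# so each row's response begins immediately after its real context.
tokenizer = QWenTokenizer.from_pretrained(
    './', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left'
)
model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval()
model.generation_config = GenerationConfig.from_pretrained('./')
stop_words_ids = get_stop_words_ids(model.generation_config.chat_format, tokenizer)

all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
batch_raw_text = []
for q in all_raw_text:
    # Wrap each raw query in the ChatML prompt template. Passing q as the
    # query argument is an assumption: the diff elides this part of the call.
    raw_text, _ = make_context(
        tokenizer,
        q,
        max_window_size=model.generation_config.max_window_size,
        chat_format=model.generation_config.chat_format,
    )
    batch_raw_text.append(raw_text)

# Pad all prompts on the left to the longest one, then generate in one batch.
batch_input_ids = tokenizer(batch_raw_text, padding='longest')
batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
batch_out_ids = model.generate(
    batch_input_ids,
    stop_words_ids=stop_words_ids,
    return_dict_in_generate=False,
    generation_config=model.generation_config,
)

# Each output row is [pads][prompt][response]. Strip the pads, then let
# decode_tokens drop the prompt context and return only the response text.
padding_lens = [
    batch_input_ids[i].eq(tokenizer.pad_token_id).sum().item()
    for i in range(batch_input_ids.size(0))
]
batch_response = [
    decode_tokens(
        batch_out_ids[i][padding_lens[i]:],
        tokenizer,
        raw_text_len=len(batch_raw_text[i]),
        context_length=(batch_input_ids[i].size(0) - padding_lens[i]),
        chat_format="chatml",
        verbose=False,
        errors='replace',
    )
    for i in range(len(all_raw_text))
]
print(batch_response)
```

The padding_side='left' choice is what makes the per-row arithmetic work: because all pad tokens sit before the prompt, batch_out_ids[i][padding_lens[i]:] starts exactly at the real context, and context_length is then simply the unpadded prompt length.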
