diff --git a/README_CN.md b/README_CN.md
index de3a656..c98bca2 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -335,6 +335,53 @@ model = AutoModelForCausalLM.from_pretrained(
 ```
+## Batch Inference
+
+Qwen supports batched inference. With flash-attention enabled, batching requests yields roughly a 40% speedup. Example code:
+
+```python
+from transformers import GenerationConfig
+
+from tokenization_qwen import QWenTokenizer
+from modeling_qwen import QWenLMHeadModel
+from qwen_generation_utils import make_context
+
+# Use a dedicated pad token and left padding, so that the generated tokens
+# of every sample in the batch start at the same position.
+tokenizer = QWenTokenizer.from_pretrained(
+    './', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left'
+)
+model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval()
+model.generation_config = GenerationConfig.from_pretrained('./')
+
+all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
+
+# Wrap every query in the chat prompt format expected by the model.
+batch_raw_text = []
+for q in all_raw_text:
+    raw_text, _ = make_context(
+        tokenizer,
+        q,
+        system="You are a helpful assistant.",
+        max_window_size=model.generation_config.max_window_size,
+        chat_format=model.generation_config.chat_format,
+    )
+    batch_raw_text.append(raw_text)
+
+# Tokenize the whole batch at once, padding to the longest prompt,
+# and pass the attention mask so the padding tokens are ignored during generation.
+batch_inputs = tokenizer(batch_raw_text, padding='longest', return_tensors='pt').to(model.device)
+batch_out_ids = model.generate(
+    input_ids=batch_inputs.input_ids,
+    attention_mask=batch_inputs.attention_mask,
+    return_dict_in_generate=False,
+)
+
+# With left padding, every prompt occupies the same padded length,
+# so the generated tokens start right after it.
+prompt_length = batch_inputs.input_ids.shape[1]
+batch_response = [
+    tokenizer.decode(out[prompt_length:], skip_special_tokens=True)
+    for out in batch_out_ids
+]
+print(batch_response)
+
+# For comparison, run the same queries one by one through the chat interface.
+response, _ = model.chat(tokenizer, "我想听你说爱我。", history=None)
+print(response)
+
+response, _ = model.chat(tokenizer, "今天我想吃点啥,甜甜的,推荐下", history=None)
+print(response)
+
+response, _ = model.chat(tokenizer, "我马上迟到了,怎么做才能不迟到", history=None)
+print(response)
+```
+
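+If you prefer to load the checkpoint through the `transformers` Auto classes, as in the earlier examples of this README, the minimal sketch below shows an equivalent setup. It assumes the checkpoint directory (the placeholder path `'./'`) ships the Qwen remote code, so `trust_remote_code=True` resolves to the same classes used above:
+
+```python
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+
+# Same pad/eos/left-padding settings as above; './' is a placeholder path.
+tokenizer = AutoTokenizer.from_pretrained(
+    './',
+    pad_token='<|extra_0|>',
+    eos_token='<|endoftext|>',
+    padding_side='left',
+    trust_remote_code=True,
+)
+model = AutoModelForCausalLM.from_pretrained(
+    './', device_map="auto", trust_remote_code=True
+).eval()
+model.generation_config = GenerationConfig.from_pretrained('./')
+# The rest of the batch inference code above can then be reused as-is.
+```
+
 
 ## 微调
 
 ### 使用方法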