diff --git a/README.md b/README.md
index 3890979..11a7bfb 100644
--- a/README.md
+++ b/README.md
@@ -371,18 +371,26 @@ you can use the dequantization operation to convert the int8 key/value back to t
 ## Batch Inference
 Qwen supports batch inference. With flash-attention enabled, using batch inference can bring a 40% speedup. The example code is shown below:
-```
+```python
 import torch
-from tokenization_qwen import QWenTokenizer
-from modeling_qwen import QWenLMHeadModel
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import GenerationConfig
 from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids
-
-tokenizer = QWenTokenizer.from_pretrained('./', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left')
-model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval()
-model.generation_config = GenerationConfig.from_pretrained('./')
-stop_words_ids = get_stop_words_ids(model.generation_config.chat_format, tokenizer)
+tokenizer = AutoTokenizer.from_pretrained(
+    './',
+    pad_token='<|extra_0|>',
+    eos_token='<|endoftext|>',
+    padding_side='left',
+    trust_remote_code=True
+)
+model = AutoModelForCausalLM.from_pretrained(
+    './',
+    pad_token_id=tokenizer.pad_token_id,
+    device_map="auto",
+    trust_remote_code=True
+).eval()
+model.generation_config = GenerationConfig.from_pretrained('./', pad_token_id=tokenizer.pad_token_id)
 
 all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
 batch_raw_text = []
@@ -400,7 +408,6 @@ batch_input_ids = tokenizer(batch_raw_text, padding='longest')
 batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
 batch_out_ids = model.generate(
     batch_input_ids,
-    stop_words_ids=stop_words_ids,
     return_dict_in_generate=False,
     generation_config=model.generation_config
 )
@@ -411,7 +418,7 @@ batch_response = [
         batch_out_ids[i][padding_lens[i]:],
         tokenizer,
         raw_text_len=len(batch_raw_text[i]),
-        context_length=batch_input_ids[i].size(0),
+        context_length=(batch_input_ids[i].size(0)-padding_lens[i]),
         chat_format="chatml",
         verbose=False,
         errors='replace'
diff --git a/README_CN.md b/README_CN.md
index 4cfc404..dd37e02 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -359,18 +359,26 @@ model = AutoModelForCausalLM.from_pretrained(
 ## Batch推理
 千问支持batch批量推理。在开启flash-attention的状态下,使用batch推理可以约40%的提速。示例代码如下所示:
-```
+```python
 import torch
-from tokenization_qwen import QWenTokenizer
-from modeling_qwen import QWenLMHeadModel
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import GenerationConfig
 from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids
-
-tokenizer = QWenTokenizer.from_pretrained('./', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left')
-model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval()
-model.generation_config = GenerationConfig.from_pretrained('./')
-stop_words_ids = get_stop_words_ids(model.generation_config.chat_format, tokenizer)
+tokenizer = AutoTokenizer.from_pretrained(
+    './',
+    pad_token='<|extra_0|>',
+    eos_token='<|endoftext|>',
+    padding_side='left',
+    trust_remote_code=True
+)
+model = AutoModelForCausalLM.from_pretrained(
+    './',
+    pad_token_id=tokenizer.pad_token_id,
+    device_map="auto",
+    trust_remote_code=True
+).eval()
+model.generation_config = GenerationConfig.from_pretrained('./', pad_token_id=tokenizer.pad_token_id)
 
 all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
 batch_raw_text = []
@@ -388,7 +396,6 @@ batch_input_ids = tokenizer(batch_raw_text, padding='longest')
 batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
 batch_out_ids = model.generate(
     batch_input_ids,
-    stop_words_ids=stop_words_ids,
     return_dict_in_generate=False,
     generation_config=model.generation_config
 )
@@ -399,7 +406,7 @@ batch_response = [
         batch_out_ids[i][padding_lens[i]:],
         tokenizer,
         raw_text_len=len(batch_raw_text[i]),
-        context_length=batch_input_ids[i].size(0),
+        context_length=(batch_input_ids[i].size(0)-padding_lens[i]),
         chat_format="chatml",
         verbose=False,
         errors='replace'
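Reviewer note, not part of the patch: a minimal standalone sketch of the arithmetic behind the `context_length` change. With `padding_side='left'`, each prompt is preceded by `padding_lens[i]` pad tokens, so once `batch_out_ids[i][padding_lens[i]:]` strips that padding, the prompt occupies only `batch_input_ids[i].size(0) - padding_lens[i]` positions. The pad id and token ids below are invented purely for illustration, not Qwen vocabulary ids.

```python
import torch

# Toy values: pad id and token ids are made up for illustration only.
pad_token_id = 0

# Two left-padded prompts of length 5 (the first row has 2 pad tokens).
batch_input_ids = torch.tensor([
    [0, 0, 11, 12, 13],
    [21, 22, 23, 24, 25],
])

# Pretend generate() appended two new tokens to each row.
batch_out_ids = torch.tensor([
    [0, 0, 11, 12, 13, 91, 92],
    [21, 22, 23, 24, 25, 93, 94],
])

padding_lens = [(row == pad_token_id).sum().item() for row in batch_input_ids]

for i in range(batch_input_ids.size(0)):
    # Strip the left padding, mirroring batch_out_ids[i][padding_lens[i]:] in the README.
    unpadded = batch_out_ids[i][padding_lens[i]:]
    # After stripping, the prompt occupies (input length - padding) positions,
    # which is the context_length the patch now passes to decode_tokens.
    context_length = batch_input_ids[i].size(0) - padding_lens[i]
    print(unpadded[context_length:].tolist())  # -> [91, 92], then [93, 94]
```

With the old value `context_length=batch_input_ids[i].size(0)`, the response would be assumed to start `padding_lens[i]` positions too far into the already-unpadded sequence, so the first generated tokens of any padded row would be dropped from the decoded reply.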