update readme about batch inference

Branch: main
Author: yangapku, 1 year ago
Parent: 9d1d0be363
Commit: 78352b5a79

@@ -371,18 +371,26 @@ you can use the dequantization operation to convert the int8 key/value back to t
 ## Batch Inference
 Qwen supports batch inference. With flash-attention enabled, using batch inference can bring a 40% speedup. The example code is shown below:
-```
+```python
 import torch
-from tokenization_qwen import QWenTokenizer
-from modeling_qwen import QWenLMHeadModel
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import GenerationConfig
 from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids
-tokenizer = QWenTokenizer.from_pretrained('./', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left')
-model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval()
-model.generation_config = GenerationConfig.from_pretrained('./')
-stop_words_ids = get_stop_words_ids(model.generation_config.chat_format, tokenizer)
+tokenizer = AutoTokenizer.from_pretrained(
+    './',
+    pad_token='<|extra_0|>',
+    eos_token='<|endoftext|>',
+    padding_side='left',
+    trust_remote_code=True
+)
+model = AutoModelForCausalLM.from_pretrained(
+    './',
+    pad_token_id=tokenizer.pad_token_id,
+    device_map="auto",
+    trust_remote_code=True
+).eval()
+model.generation_config = GenerationConfig.from_pretrained('./', pad_token_id=tokenizer.pad_token_id)
 all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
 batch_raw_text = []

@@ -400,7 +408,6 @@ batch_input_ids = tokenizer(batch_raw_text, padding='longest')
 batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
 batch_out_ids = model.generate(
     batch_input_ids,
-    stop_words_ids=stop_words_ids,
     return_dict_in_generate=False,
     generation_config=model.generation_config
 )

@@ -411,7 +418,7 @@ batch_response = [
         batch_out_ids[i][padding_lens[i]:],
         tokenizer,
         raw_text_len=len(batch_raw_text[i]),
-        context_length=batch_input_ids[i].size(0),
+        context_length=(batch_input_ids[i].size(0)-padding_lens[i]),
         chat_format="chatml",
         verbose=False,
         errors='replace'
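The hunks above only show the changed regions: the loop that builds `batch_raw_text` with `make_context`, the tokenization with `padding='longest'`, and the `padding_lens` computation sit in unchanged parts of the README and are not visible here. Note also that the updated example drops the explicit `stop_words_ids=...` argument to `model.generate`, even though `get_stop_words_ids` is still imported. The point of the `context_length` change is that, with left padding, `batch_input_ids[i].size(0)` counts the pad tokens as well, so the true prompt length is the padded length minus `padding_lens[i]`. Below is a minimal end-to-end sketch of how the updated example fits together; the pieces not shown in the diff (the system prompt and the way `padding_lens` is counted from the pad token) are assumptions for illustration, not taken from the commit.

```python
# Hedged sketch of the batch-inference flow after this commit.
# Only the hunks above are authoritative; the make_context loop, system prompt,
# and padding_lens computation are filled in here as assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from qwen_generation_utils import make_context, decode_tokens

tokenizer = AutoTokenizer.from_pretrained(
    './', pad_token='<|extra_0|>', eos_token='<|endoftext|>',
    padding_side='left', trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    './', pad_token_id=tokenizer.pad_token_id, device_map="auto", trust_remote_code=True
).eval()
model.generation_config = GenerationConfig.from_pretrained('./', pad_token_id=tokenizer.pad_token_id)

all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]

# Wrap each query in the ChatML template (system prompt is an assumed default).
batch_raw_text = []
for q in all_raw_text:
    raw_text, _ = make_context(
        tokenizer, q,
        system="You are a helpful assistant.",
        chat_format=model.generation_config.chat_format,
    )
    batch_raw_text.append(raw_text)

# Left-pad every prompt to the longest one so generation starts aligned on the right.
batch_input_ids = tokenizer(batch_raw_text, padding='longest')
batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
batch_out_ids = model.generate(
    batch_input_ids,
    return_dict_in_generate=False,
    generation_config=model.generation_config
)

# Count the pad tokens in front of each prompt; with left padding they all sit at the start.
padding_lens = [
    batch_input_ids[i].eq(tokenizer.pad_token_id).sum().item()
    for i in range(batch_input_ids.size(0))
]

# Strip the padding before decoding and pass the unpadded prompt length as context_length,
# which is what the `context_length=(... - padding_lens[i])` change above accounts for.
batch_response = [
    decode_tokens(
        batch_out_ids[i][padding_lens[i]:],
        tokenizer,
        raw_text_len=len(batch_raw_text[i]),
        context_length=(batch_input_ids[i].size(0) - padding_lens[i]),
        chat_format="chatml",
        verbose=False,
        errors='replace'
    )
    for i in range(len(all_raw_text))
]
print(batch_response)
```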

@@ -359,18 +359,26 @@ model = AutoModelForCausalLM.from_pretrained(
 ## Batch推理
 千问支持batch批量推理。在开启flash-attention的状态下使用batch推理可以约40%的提速。示例代码如下所示:
-```
+```python
 import torch
-from tokenization_qwen import QWenTokenizer
-from modeling_qwen import QWenLMHeadModel
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import GenerationConfig
 from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids
-tokenizer = QWenTokenizer.from_pretrained('./', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left')
-model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval()
-model.generation_config = GenerationConfig.from_pretrained('./')
-stop_words_ids = get_stop_words_ids(model.generation_config.chat_format, tokenizer)
+tokenizer = AutoTokenizer.from_pretrained(
+    './',
+    pad_token='<|extra_0|>',
+    eos_token='<|endoftext|>',
+    padding_side='left',
+    trust_remote_code=True
+)
+model = AutoModelForCausalLM.from_pretrained(
+    './',
+    pad_token_id=tokenizer.pad_token_id,
+    device_map="auto",
+    trust_remote_code=True
+).eval()
+model.generation_config = GenerationConfig.from_pretrained('./', pad_token_id=tokenizer.pad_token_id)
 all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
 batch_raw_text = []

@@ -388,7 +396,6 @@ batch_input_ids = tokenizer(batch_raw_text, padding='longest')
 batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
 batch_out_ids = model.generate(
     batch_input_ids,
-    stop_words_ids=stop_words_ids,
     return_dict_in_generate=False,
     generation_config=model.generation_config
 )

@@ -399,7 +406,7 @@ batch_response = [
         batch_out_ids[i][padding_lens[i]:],
         tokenizer,
         raw_text_len=len(batch_raw_text[i]),
-        context_length=batch_input_ids[i].size(0),
+        context_length=(batch_input_ids[i].size(0)-padding_lens[i]),
         chat_format="chatml",
         verbose=False,
         errors='replace'
