Merge pull request #442 from QwenLM/logicwong-patch-2

Update README.md, add batch inference
main
Junyang Lin 1 year ago committed by GitHub
commit 4eee29e790
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -368,6 +368,54 @@ you can use the dequantization operation to convert the int8 key/value back to t
```
<br>
## Batch Inference
Qwen supports batch inference. With flash-attention enabled, using batch inference can bring a 40% speedup. The example code is shown below:
```
import torch
from tokenization_qwen import QWenTokenizer
from modeling_qwen import QWenLMHeadModel
from transformers import GenerationConfig
from qwen_generation_utils import make_context
tokenizer = QWenTokenizer.from_pretrained('./', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left')
model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval()
model.generation_config = GenerationConfig.from_pretrained('./')
all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
batch_question = []
for q in all_raw_text:
raw_text, _ = make_context(
tokenizer,
q,
system="You are a helpful assistant.",
max_window_size=model.generation_config.max_window_size,
chat_format=model.generation_config.chat_format,
)
batch_question.append(raw_text)
batch_input_ids = tokenizer(batch_question, padding='longest')
print(batch_input_ids)
batch_input_ids1 = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
batch_out_ids = model.generate(
input_ids=batch_input_ids1
,return_dict_in_generate=False
)
batch_response = [tokenizer.decode(o, skip_special_tokens=True) for o in batch_out_ids]
print(batch_response)
response, _ = model.chat(tokenizer, "我想听你说爱我。", history=None)
print(response)
response, _ = model.chat(tokenizer, "今天我想吃点啥,甜甜的,推荐下", history=None)
print(response)
response, _ = model.chat(tokenizer, "我马上迟到了,怎么做才能不迟到", history=None)
print(response)
```
## Finetuning
### Usage

Loading…
Cancel
Save