You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

240 lines
9.9 KiB
Python

from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from typing import Optional, Callable, List, Tuple, Union
import copy
import torch
from transformers import AutoTokenizer
from transformers.generation.logits_process import LogitsProcessorList
from packaging import version
# Message shown when generation_config.chat_format is not "chatml", which the
# text itself attributes to loading the base (pretrained) checkpoint instead of
# the chat checkpoint. The last two lines repeat the English message in Chinese.
_ERROR_BAD_CHAT_FORMAT = """\
We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
我们检测到您可能在使用预训练模型而非chat模型进行多轮chat因为您当前在generation_config指定的chat_format并未设置为我们在对话中所支持的"chatml"格式
如果您在直接使用我们从Huggingface提供的模型请确保您在调用model.chat()使用的是"Qwen/Qwen-7B-Chat"模型而非"Qwen/Qwen-7B"预训练模型
"""

# Special-token strings that may trail a decoded completion.
IMEND = "<|im_end|>"
ENDOFTEXT = "<|endoftext|>"

# Type aliases.
TokensType = List[int]                # one sequence of token ids
BatchTokensType = List[TokensType]    # several sequences of token ids
HistoryType = List[Tuple[str, str]]   # (query, response) pairs, oldest first
def get_stop_words_ids(chat_format, tokenizer):
    """Return the token-id sequences that should end generation for *chat_format*.

    * ``"chatml"``: the single-token sequences ``[im_end_id]`` and ``[im_start_id]``.
    * ``"raw"``: the encoding of ``"Human:"`` followed by ``[eod_id]``.

    Raises:
        NotImplementedError: for any other chat format.
    """
    if chat_format == "chatml":
        return [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    if chat_format == "raw":
        return [tokenizer.encode("Human:"), [tokenizer.eod_id]]
    raise NotImplementedError(f"Unknown chat format {chat_format!r}")
def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: List[Tuple[str, str]] = None,
    system: str = "",
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    """Build the prompt text and prompt token ids for a new chat turn.

    For ``chat_format == "chatml"`` every message is rendered as
    ``<|im_start|>{role}\\n{content}<|im_end|>``, with messages separated by a
    newline. The context is laid out as: the system message, then as many
    history turns as fit, then the new user query, then an open
    ``<|im_start|>assistant\\n`` header for the model to continue. History is
    taken newest-first; a turn is kept only while
    ``len(system) + len(turn) + len(kept turns) < max_window_size`` (in tokens).
    The final query is not counted against that budget.

    For ``chat_format == "raw"`` the query alone is the prompt.

    Args:
        tokenizer: tokenizer providing ``encode``; for chatml it must also
            expose ``im_start_id`` / ``im_end_id``.
        query: the new user message.
        history: previous (query, response) pairs, oldest first, or ``None``.
        system: system prompt text.
        max_window_size: token budget for system prompt plus kept history.
        chat_format: ``"chatml"`` or ``"raw"``.

    Returns:
        ``(raw_text, context_tokens)``: the prompt string and its token ids.

    Raises:
        NotImplementedError: for any other ``chat_format``.
    """
    if chat_format == "raw":
        # No template: the query is the whole prompt.
        return query, tokenizer.encode(query)
    if chat_format != "chatml":
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")

    turns = history if history is not None else []
    open_tag, close_tag = "<|im_start|>", "<|im_end|>"
    open_ids = [tokenizer.im_start_id]
    close_ids = [tokenizer.im_end_id]
    newline_ids = tokenizer.encode("\n")

    def wrap_message(role, content):
        """Return ``(text, ids)`` for one ``<|im_start|>role\\ncontent<|im_end|>`` message."""
        # allowed_special=set(): presumably keeps special-token text typed by
        # the user from being encoded as control tokens (tokenizer-specific;
        # confirm against the tokenizer in use).
        body_ids = (
            tokenizer.encode(role, allowed_special=set())
            + newline_ids
            + tokenizer.encode(content, allowed_special=set())
        )
        body_text = f"{role}\n{content}"
        return f"{open_tag}{body_text}{close_tag}", open_ids + body_ids + close_ids

    system_text, system_ids = wrap_message("system", system)

    # Walk history from the newest turn backwards, prepending each turn while
    # it still fits; stop at the first turn that would exceed the budget.
    history_text = ""
    history_ids = []
    for past_query, past_response in reversed(turns):
        user_text, user_ids = wrap_message("user", past_query)
        bot_text, bot_ids = wrap_message("assistant", past_response)
        turn_ids = newline_ids + user_ids + newline_ids + bot_ids
        fits = len(system_ids) + len(turn_ids) + len(history_ids) < max_window_size
        if not fits:
            break
        history_ids = turn_ids + history_ids
        history_text = f"\n{user_text}\n{bot_text}" + history_text

    query_text, query_ids = wrap_message("user", query)
    raw_text = f"{system_text}{history_text}\n{query_text}\n{open_tag}assistant\n"
    context_tokens = (
        system_ids
        + history_ids
        + newline_ids
        + query_ids
        + newline_ids
        + open_ids
        + tokenizer.encode("assistant")
        + newline_ids
    )
    return raw_text, context_tokens
class vLLMWrapper:
    """Qwen-style ``chat()`` front-end over a vLLM ``LLM`` engine.

    The constructor loads the model's ``GenerationConfig`` and tokenizer with
    ``transformers`` and builds a ``vllm.LLM``. ``chat()`` then renders a ChatML
    prompt with ``make_context`` and samples a single reply.
    """

    def __init__(self,
                 model_dir: str,
                 trust_remote_code: bool = True,
                 tensor_parallel_size: int = 1,
                 gpu_memory_utilization: float = 0.98,
                 dtype: str = "bfloat16",
                 **kwargs):
        """Load generation config, tokenizer and the vLLM engine from *model_dir*.

        Args:
            model_dir: path or hub id passed to the ``from_pretrained`` loaders
                and to ``vllm.LLM`` (as both model and tokenizer).
            trust_remote_code: forwarded to the generation-config loader, the
                tokenizer loader and ``vllm.LLM``.
            tensor_parallel_size: forwarded to ``vllm.LLM``.
            gpu_memory_utilization: forwarded to ``vllm.LLM``.
            dtype: one of ``"bfloat16"``, ``"float16"``, ``"float32"``.
            **kwargs: only ``quantization`` is read; it is forwarded to ``vllm.LLM``.

        Raises:
            ValueError: if *dtype* is not one of the supported values.
        """
        if dtype not in ("bfloat16", "float16", "float32"):
            raise ValueError(
                f"Unsupported dtype {dtype!r}; expected one of "
                "'bfloat16', 'float16', 'float32'."
            )
        self.generation_config = GenerationConfig.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code)
        # Honour the caller's trust_remote_code for the tokenizer as well.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code)
        self.tokenizer.eos_token_id = self.generation_config.eos_token_id
        self.stop_words_ids = []

        from vllm import LLM
        import vllm
        # SamplingParams only accepts repetition_penalty from vLLM 0.2.2 on.
        self.__vllm_support_repetition_penalty = (
            version.parse(vllm.__version__) >= version.parse("0.2.2"))

        # kwargs is a plain dict: read it with .get(); getattr() on a dict
        # never sees its keys and would always return the default.
        quantization = kwargs.get('quantization', None)
        self.model = LLM(model=model_dir,
                         tokenizer=model_dir,
                         tensor_parallel_size=tensor_parallel_size,
                         trust_remote_code=trust_remote_code,
                         quantization=quantization,
                         gpu_memory_utilization=gpu_memory_utilization,
                         dtype=dtype)

        # Flatten the per-format stop sequences into one list of token ids,
        # then add the EOS id.
        for stop_id in get_stop_words_ids(self.generation_config.chat_format, self.tokenizer):
            self.stop_words_ids.extend(stop_id)
        self.stop_words_ids.extend([self.generation_config.eos_token_id])

    def chat(self,
             query: str,
             history: Optional[HistoryType] = None,
             tokenizer: PreTrainedTokenizer = None,
             system: str = "You are a helpful assistant.",
             generation_config: Optional[GenerationConfig] = None,
             **kwargs):
        """Generate one assistant reply to *query*.

        Args:
            query: the new user message.
            history: previous (query, response) pairs, oldest first; ``None``
                starts a fresh conversation. The caller's list is deep-copied
                and never mutated.
            tokenizer: tokenizer used to build the prompt and to decode
                multi-token stop sequences. Defaults to ``self.tokenizer``.
            system: system prompt placed at the start of the ChatML context.
            generation_config: overrides ``self.generation_config``.
            **kwargs:
                ``stop_words_ids``: extra stop sequences, given as token ids or
                    lists of token ids. This assumes Qwen's
                    ``List[List[int]]`` convention; confirm against callers.
                ``max_window_size``: history token budget. Defaults to
                    ``generation_config.max_window_size``.

        Returns:
            ``(response, history)``, where *history* is the copied history with
            ``(query, response)`` appended.

        Raises:
            AssertionError: if the chat format is not ``"chatml"``, or if vLLM
                returned more than one sample.
            RuntimeError: if ``repetition_penalty != 1`` and the installed
                vLLM is older than 0.2.2.
        """
        generation_config = generation_config if generation_config is not None else self.generation_config
        tokenizer = self.tokenizer if tokenizer is None else tokenizer
        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
        if not self.__vllm_support_repetition_penalty and generation_config.repetition_penalty != 1:
            raise RuntimeError("The installed vLLM doesn't support repetition_penalty, please set ``model.generation_config.repetition_penalty = 1`` or install vllm>=0.2.2")

        # Work on a copy so the caller's history is left untouched.
        history = [] if history is None else copy.deepcopy(history)

        extra_stop_words_ids = kwargs.get('stop_words_ids', None)
        if extra_stop_words_ids is None:
            extra_stop_words_ids = []
        max_window_size = kwargs.get('max_window_size', None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size

        # Merge caller-supplied stop sequences into the defaults. vLLM's
        # stop_token_ids matches single tokens only, so a multi-token
        # sequence is matched on its decoded text via `stop` instead.
        stop_token_ids = list(self.stop_words_ids)
        stop_strings = []
        for stop in extra_stop_words_ids:
            if isinstance(stop, int):
                stop_token_ids.append(stop)
            elif len(stop) == 1:
                stop_token_ids.append(stop[0])
            elif len(stop) > 1:
                stop_strings.append(tokenizer.decode(stop))

        from vllm.sampling_params import SamplingParams
        sampling_kwargs = {
            "stop_token_ids": stop_token_ids,
            "early_stopping": False,
            "top_p": generation_config.top_p,
            # vLLM uses -1 (not 0) to mean "no top-k filtering".
            "top_k": -1 if generation_config.top_k == 0 else generation_config.top_k,
            "temperature": generation_config.temperature,
            "max_tokens": generation_config.max_new_tokens,
        }
        if self.__vllm_support_repetition_penalty:
            sampling_kwargs["repetition_penalty"] = generation_config.repetition_penalty
        if stop_strings:
            sampling_kwargs["stop"] = stop_strings
        sampling_params = SamplingParams(**sampling_kwargs)

        _, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=max_window_size,
            chat_format=generation_config.chat_format,
        )
        # The prompt string passed alongside is the bare query; generation is
        # driven by the explicit prompt_token_ids.
        req_outputs = self.model.generate([query],
                                          sampling_params=sampling_params,
                                          prompt_token_ids=[context_tokens])
        req_output = req_outputs[0]
        assert len(req_output.outputs) == 1

        response = req_output.outputs[0].text
        # Strip a trailing stop marker only when it really is a suffix; a
        # marker elsewhere in the text must not cause truncation.
        for marker in (IMEND, ENDOFTEXT):
            if response.endswith(marker):
                response = response[:-len(marker)]

        history.append((query, response))
        return response, history
if __name__ == '__main__':
    # Demo: a three-turn conversation in which each turn reuses the history
    # returned by the previous call.
    model_dir = 'Qwen/Qwen-72B-Chat'
    tensor_parallel_size = 2
    model = vLLMWrapper(model_dir, tensor_parallel_size=tensor_parallel_size)

    demo_queries = (
        "你好",  # "Hello"
        "给我讲一个年轻人奋斗创业最终取得成功的故事。",  # "Tell me a story of a young entrepreneur who finally succeeds."
        "给这个故事起一个标题",  # "Give this story a title"
    )
    history = None
    for demo_query in demo_queries:
        response, history = model.chat(query=demo_query, history=history)
        print(response)