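# Quantize a ChatML-style chat model with AutoGPTQ, calibrating on real
# conversations. Example invocation (script name, model, and paths are
# illustrative, not fixed by this file):
#
#   python run_gptq.py \
#       --model_name_or_path Qwen/Qwen-7B-Chat \
#       --data_path calib_data.json \
#       --out_path qwen-7b-chat-int4 \
#       --bits 4 --group-size 128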

import argparse
import json
import logging
from typing import Dict, List

import torch
import transformers
from transformers import AutoTokenizer
from transformers.trainer_pt_utils import LabelSmoother

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

# Label id that CrossEntropyLoss ignores (-100); used to mask prompt tokens.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int,
    system_message: str = "You are a helpful assistant.",
) -> List[Dict]:
    """Render ChatML conversations into GPTQ calibration examples.

    Each record in `sources` is expected to look like
    {"conversations": [{"from": "user", "value": "..."},
                       {"from": "assistant", "value": "..."}, ...]}.
    """
    roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
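
    # Special-token ids and role prefixes of the ChatML template. Every turn
    # is rendered as:
    #   <|im_start|>{role}\n{content}<|im_end|>\n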
    im_start = tokenizer.im_start_id
    im_end = tokenizer.im_end_id
    nl_tokens = tokenizer('\n').input_ids
    _system = tokenizer('system').input_ids + nl_tokens
    _user = tokenizer('user').input_ids + nl_tokens  # currently unused
    _assistant = tokenizer('assistant').input_ids + nl_tokens  # currently unused

    # Apply the prompt template to every conversation.
    data = []
    for i, source in enumerate(sources):
        source = source["conversations"]
        # Skip a leading turn that does not come from the user.
        if roles[source[0]["from"]] != roles["user"]:
            source = source[1:]

        input_id, target = [], []
        system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
        input_id += system
        # Mask the whole system turn: the "-3" excludes the im_start, im_end,
        # and trailing-newline tokens that are re-appended unmasked (this
        # assumes '\n' encodes to a single token).
        target += [im_start] + [IGNORE_TOKEN_ID] * (len(system) - 3) + [im_end] + nl_tokens
        assert len(input_id) == len(target)
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            _input_id = tokenizer(role).input_ids + nl_tokens + \
                tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
            input_id += _input_id
            if role == '<|im_start|>user':
                # User turns are fully masked out of the loss.
                _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id) - 3) + [im_end] + nl_tokens
            elif role == '<|im_start|>assistant':
                # Assistant turns: mask the role header and its newline, keep
                # the reply tokens themselves.
                _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \
                    _input_id[len(tokenizer(role).input_ids) + 1:-2] + [im_end] + nl_tokens
            else:
                raise NotImplementedError
            target += _target
        assert len(input_id) == len(target)
        # Truncate to max_len and build the calibration example. GPTQ only
        # consumes input_ids and attention_mask; `target` is kept for the
        # length checks above.
        input_id = torch.tensor(input_id[:max_len], dtype=torch.int)
        target = torch.tensor(target[:max_len], dtype=torch.int)
        data.append(dict(input_ids=input_id, attention_mask=input_id.ne(tokenizer.pad_token_id)))

    return data

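# Illustrative sketch of what `preprocess` consumes and produces (the record
# shape is inferred from the keys accessed above):
#
#   sample = [{"conversations": [
#       {"from": "user", "value": "Hello"},
#       {"from": "assistant", "value": "Hi, how can I help?"},
#   ]}]
#   examples = preprocess(sample, tokenizer, max_len=512)
#   examples[0]["input_ids"]       # torch.int tensor of token ids
#   examples[0]["attention_mask"]  # True except where ids equal pad_token_id
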
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Model quantization using AutoGPTQ")
    parser.add_argument("--model_name_or_path", type=str, help="model path")
    parser.add_argument("--data_path", type=str, help="calibration data path")
    parser.add_argument("--out_path", type=str, help="output path of the quantized model")
    parser.add_argument("--max_len", type=int, default=8192, help="max length of calibration data")
    parser.add_argument("--bits", type=int, default=4, help="bit width of the quantized model; 4 means int4")
    parser.add_argument("--group-size", type=int, default=128, help="group size of the quantized model")
    args = parser.parse_args()

    quantize_config = BaseQuantizeConfig(
        bits=args.bits,
        group_size=args.group_size,  # weights per shared set of quantization params
        damp_percent=0.01,
        # Setting desc_act=False significantly speeds up inference, at the
        # cost of slightly worse perplexity.
        desc_act=False,
        static_groups=False,
        sym=True,
        true_sequential=True,
        model_name_or_path=None,
        model_file_base_name="model",
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    # The remote-code tokenizer exposes eod_id; reuse it as the pad token.
    tokenizer.pad_token_id = tokenizer.eod_id
    with open(args.data_path) as f:
        data = preprocess(json.load(f), tokenizer, args.max_len)

    model = AutoGPTQForCausalLM.from_pretrained(
        args.model_name_or_path, quantize_config, device_map="auto", trust_remote_code=True
    )

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
    )
    # Run GPTQ calibration; keeping examples off the GPU reduces memory use.
    model.quantize(data, cache_examples_on_gpu=False)

    model.save_quantized(args.out_path, use_safetensors=True)
    tokenizer.save_pretrained(args.out_path)
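
# To load the result later (a sketch; the device string is illustrative):
#   from auto_gptq import AutoGPTQForCausalLM
#   model = AutoGPTQForCausalLM.from_quantized(
#       args.out_path, device="cuda:0", trust_remote_code=True)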