You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

97 lines
4.2 KiB
Python

import argparse
import json
import logging
from typing import Dict, List

import torch
import transformers
from transformers import AutoTokenizer
from transformers.trainer_pt_utils import LabelSmoother
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
# Label value that transformers' LabelSmoother skips (its ``ignore_index``);
# used to mark token positions that should not be learned from.
IGNORE_TOKEN_ID: int = LabelSmoother.ignore_index
def preprocess(
    sources,
    tokenizer: "transformers.PreTrainedTokenizer",
    max_len: int,
    system_message: str = "You are a helpful assistant."
) -> List[Dict]:
    """Tokenize ChatML-formatted conversations into AutoGPTQ calibration examples.

    Each conversation is rendered token by token as::

        <|im_start|>system\\n{system_message}<|im_end|>\\n
        <|im_start|>{role}\\n{value}<|im_end|>\\n      (once per turn)

    Only model inputs are produced. Calibration computes no loss, so no
    label/target sequence is built.

    Args:
        sources: Iterable of records, each holding a ``"conversations"`` list
            of turns shaped like ``{"from": "user" | "assistant", "value": str}``.
            If the first turn is not from ``"user"``, that single turn is dropped.
        tokenizer: Tokenizer that is callable on a string (returning an object
            with ``input_ids``) and exposes ``im_start_id``, ``im_end_id`` and
            ``pad_token_id``.
        max_len: Maximum number of tokens kept per example; longer sequences
            are truncated from the right.
        system_message: Text placed in the system turn of every example.

    Returns:
        One dict per conversation with a 1-D ``torch.int`` ``input_ids`` tensor
        and a boolean ``attention_mask`` that is False wherever the token
        equals ``tokenizer.pad_token_id``.

    Raises:
        KeyError: If a turn's ``"from"`` is neither ``"user"`` nor ``"assistant"``.
    """
    roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
    im_end = tokenizer.im_end_id
    nl_tokens = tokenizer('\n').input_ids

    # The system turn and the role headers are identical for every
    # conversation, so tokenize them once instead of per conversation/turn.
    system_ids = (
        [tokenizer.im_start_id]
        + tokenizer('system').input_ids + nl_tokens
        + tokenizer(system_message).input_ids
        + [im_end] + nl_tokens
    )
    header_ids = {
        role: tokenizer(role).input_ids + nl_tokens for role in roles.values()
    }

    data = []
    for source in sources:
        turns = source["conversations"]
        # Drop one leading non-user turn. The emptiness check keeps an empty
        # conversation from crashing; it then yields a system-only example.
        if turns and roles[turns[0]["from"]] != roles["user"]:
            turns = turns[1:]

        ids = list(system_ids)
        for turn in turns:
            # roles[...] raises KeyError for unsupported speakers.
            ids += header_ids[roles[turn["from"]]]
            ids += tokenizer(turn["value"]).input_ids + [im_end] + nl_tokens

        ids_tensor = torch.tensor(ids[:max_len], dtype=torch.int)
        data.append(dict(
            input_ids=ids_tensor,
            attention_mask=ids_tensor.ne(tokenizer.pad_token_id),
        ))
    return data
if __name__ == "__main__":
parser = argparse.ArgumentParser("Model Quantization using AutoGPTQ")
parser.add_argument("--model_name_or_path", type=str, help="model path")
parser.add_argument("--data_path", type=str, help="calibration data path")
parser.add_argument("--out_path", type=str, help="output path of the quantized model")
parser.add_argument("--max_len", type=int, default=8192, help="max length of calibration data")
parser.add_argument("--bits", type=int, default=4, help="the bits of quantized model. 4 indicates int4 models.")
parser.add_argument("--group-size", type=int, default=128, help="the group size of quantized model")
args = parser.parse_args()
quantize_config = BaseQuantizeConfig(
bits=args.bits,
group_size=args.group_size,
damp_percent=0.01,
desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad
static_groups=False,
sym=True,
true_sequential=True,
model_name_or_path=None,
model_file_base_name="model"
)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
tokenizer.pad_token_id = tokenizer.eod_id
data = preprocess(json.load(open(args.data_path)), tokenizer, args.max_len)
model = AutoGPTQForCausalLM.from_pretrained(args.model_name_or_path, quantize_config, device_map="auto", trust_remote_code=True)
logging.basicConfig(
format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
)
model.quantize(data, cache_examples_on_gpu=False)
model.save_quantized(args.out_path, use_safetensors=True)
tokenizer.save_pretrained(args.out_path)