From ea86f6136a0432ed82933fda88a5aea76757ef11 Mon Sep 17 00:00:00 2001
From: "feihu.hf"
Date: Mon, 25 Dec 2023 18:57:26 +0800
Subject: [PATCH] add run gptq

---
 README.md    | 48 ++++++++++++++++++++++++++
 README_CN.md | 49 +++++++++++++++++++++++++++
 run_gptq.py  | 96 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 193 insertions(+)
 create mode 100644 run_gptq.py

diff --git a/README.md b/README.md
index ec6c019..8e825c4 100644
--- a/README.md
+++ b/README.md
@@ -723,6 +723,54 @@ tokenizer.save_pretrained(new_model_directory)
 
 Note: For multi-GPU training, you need to specify the proper hyperparameters for distributed training based on your machine. Besides, we advise you to specify your maximum sequence length with the argument `--model_max_length`, based on your consideration of data, memory footprint, and training speed.
 
+### Quantize Fine-tuned Models
+
+This section applies to full-parameter and LoRA fine-tuned models. (Note: you do not need to quantize a Q-LoRA fine-tuned model, because it is already quantized.)
+If you use LoRA, please follow the instructions above to merge your model before quantization.
+
+We recommend using [auto_gptq](https://github.com/PanQiWei/AutoGPTQ) to quantize the fine-tuned model.
+
+```bash
+pip install auto-gptq optimum
+```
+
+Note: AutoGPTQ currently has a bug, described in [this issue](https://github.com/PanQiWei/AutoGPTQ/issues/370). There is a [workaround PR](https://github.com/PanQiWei/AutoGPTQ/pull/495); you can check out that branch and install from source.
+
+First, prepare the calibration data. You can reuse the fine-tuning data, or use other data in the same format (a minimal example is sketched right after this diff).
+
+Second, run the following script:
+
+```bash
+python run_gptq.py \
+    --model_name_or_path $YOUR_LORA_MODEL_PATH \
+    --data_path $DATA \
+    --out_path $OUTPUT_PATH \
+    --bits 4 # 4 for int4; 8 for int8
+```
+
+This step requires GPUs and may take a few hours, depending on your data size and model size.
+
+Then, copy all `*.py`, `*.cu`, and `*.cpp` files as well as `generation_config.json` to the output path. We also recommend overwriting `config.json` with the file from the corresponding official quantized model
+(for example, if you fine-tuned `Qwen-7B-Chat` and used `--bits 4`, you can take `config.json` from [Qwen-7B-Chat-Int4](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4/blob/main/config.json)).
+You should also rename ``gptq.safetensors`` to ``model.safetensors``.
+
+Finally, test the model in the same way you would load the official quantized model. For example:
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+tokenizer = AutoTokenizer.from_pretrained("/path/to/your/model", trust_remote_code=True)
+
+model = AutoModelForCausalLM.from_pretrained(
+    "/path/to/your/model",
+    device_map="auto",
+    trust_remote_code=True
+).eval()
+
+response, history = model.chat(tokenizer, "你好", history=None)
+print(response)
+```
 
 ### Profiling of Memory and Speed
 We profile the GPU memory and training speed of both LoRA (LoRA (emb) refers to training the embedding and output layer, while LoRA has no trainable embedding and output layer) and Q-LoRA in the setup of single-GPU training. In this test, we experiment on a single A100-SXM4-80G GPU, and we use CUDA 11.8 and Pytorch 2.0. Flash attention 2 is applied. We uniformly use a batch size of 1 and gradient accumulation of 8.
 We profile the memory (GB) and speed (s/iter) of inputs of different lengths, namely 256, 512, 1024, 2048, 4096, and 8192. We also report the statistics of full-parameter finetuning with Qwen-7B on 2 A100 GPUs. We only report the statistics of 256, 512, and 1024 tokens due to the limitation of GPU memory.
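For illustration, a minimal calibration file in the format that `run_gptq.py` (added below) expects can be produced as follows. The expected layout is a JSON list of samples, each carrying a `conversations` list of user/assistant turns, the same layout as the fine-tuning data; the file name and the sample contents here are placeholders.

```python
# Illustrative sketch: write a tiny calibration file in the layout read by
# run_gptq.py, i.e. a JSON list of {"conversations": [{"from": ..., "value": ...}]}.
# The file name and the sample texts are placeholders.
import json

calibration_samples = [
    {
        "id": "identity_0",
        "conversations": [
            {"from": "user", "value": "What does GPTQ quantization do?"},
            {"from": "assistant", "value": "It compresses model weights to low-bit integers while trying to preserve accuracy."},
        ],
    },
    # ... add enough samples to cover your target domain
]

with open("calibration_data.json", "w", encoding="utf-8") as f:
    json.dump(calibration_samples, f, ensure_ascii=False, indent=2)
```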
diff --git a/README_CN.md b/README_CN.md
index 7ca7bfc..f53585d 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -713,6 +713,55 @@ tokenizer.save_pretrained(new_model_directory)
 
 Note: Distributed training requires you to specify the proper distributed-training hyperparameters according to your needs and your machine. In addition, set your sequence length with `--model_max_length` based on your data, GPU memory, and expected training speed.
 
+### Quantize Fine-tuned Models
+
+This subsection covers quantizing full-parameter/LoRA fine-tuned models. (Note: you do not need to quantize a Q-LoRA model, because it is already quantized.)
+If you want to quantize a LoRA fine-tuned model, please merge your model weights first, following the instructions above.
+
+We recommend using [auto_gptq](https://github.com/PanQiWei/AutoGPTQ) to quantize your model.
+
+```bash
+pip install auto-gptq optimum
+```
+
+Note: AutoGPTQ currently has a bug, described in [this issue](https://github.com/PanQiWei/AutoGPTQ/issues/370). There is a [workaround PR](https://github.com/PanQiWei/AutoGPTQ/pull/495); you can install from source using that branch.
+
+First, prepare the calibration data. You can reuse your fine-tuning data, or prepare other data in the same format as for fine-tuning.
+
+Second, run the following command:
+
+```bash
+python run_gptq.py \
+    --model_name_or_path $YOUR_LORA_MODEL_PATH \
+    --data_path $DATA \
+    --out_path $OUTPUT_PATH \
+    --bits 4 # 4 for int4; 8 for int8
+```
+
+This step requires GPUs and may take several hours, depending on the size of your calibration data and of your model.
+
+Next, copy all `*.py`, `*.cu`, and `*.cpp` files as well as `generation_config.json` from the original model into the output directory. Also, overwrite `config.json` in the output directory with the one from the corresponding official quantized model
+(for example, if you fine-tuned `Qwen-7B-Chat` with `--bits 4`, you can find the matching `config.json` in the [Qwen-7B-Chat-Int4](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4/blob/main/config.json) repository).
+You also need to rename ``gptq.safetensors`` to ``model.safetensors``. (A minimal sketch of this copy-and-rename step appears right after this diff.)
+
+Finally, test your model in the same way as the official quantized models. For example:
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+tokenizer = AutoTokenizer.from_pretrained("/path/to/your/model", trust_remote_code=True)
+
+model = AutoModelForCausalLM.from_pretrained(
+    "/path/to/your/model",
+    device_map="auto",
+    trust_remote_code=True
+).eval()
+
+response, history = model.chat(tokenizer, "你好", history=None)
+print(response)
+```
+
 ### Memory Usage and Training Speed
 Below we report the GPU memory usage (GB) and training speed (s/iter) of the 7B and 14B models on a single GPU with LoRA (LoRA (emb) means the embedding and output layers are trained, while plain LoRA does not optimize these parameters) and Q-LoRA, for inputs of different lengths. The benchmark was run on a single A100-SXM4-80G GPU with CUDA 11.8 and PyTorch 2.0, with flash attention 2 enabled. We uniformly use a batch size of 1 and gradient accumulation of 8, and record the memory usage (GB) and training speed (s/iter) for input lengths of 256, 512, 1024, 2048, 4096, and 8192. We also measured full-parameter fine-tuning of Qwen-7B on 2 A100 GPUs; limited by GPU memory, we only tested 256, 512, and 1024 tokens.
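The copy-and-rename step described in prose in both README sections above might be sketched like this; the directory paths and the location of the official `config.json` are placeholders.

```python
# Rough sketch of the manual post-processing step: copy the remote-code files and
# generation_config.json from the fine-tuned model into the quantized output
# directory, then make sure the weights file is named model.safetensors.
# All paths below are placeholders.
import glob
import os
import shutil

src_dir = "/path/to/your/finetuned/model"  # the model passed to run_gptq.py
out_dir = "/path/to/quantized/output"      # the --out_path given to run_gptq.py

for pattern in ("*.py", "*.cu", "*.cpp", "generation_config.json"):
    for path in glob.glob(os.path.join(src_dir, pattern)):
        shutil.copy(path, out_dir)

# Overwrite config.json with the one from the matching official quantized model
# (e.g. Qwen-7B-Chat-Int4), downloaded separately to a placeholder path:
# shutil.copy("/path/to/official/config.json", os.path.join(out_dir, "config.json"))

# Rename the checkpoint if it was not already saved as model.safetensors.
old_name = os.path.join(out_dir, "gptq.safetensors")
if os.path.exists(old_name):
    os.rename(old_name, os.path.join(out_dir, "model.safetensors"))
```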
diff --git a/run_gptq.py b/run_gptq.py
new file mode 100644
index 0000000..5609fe8
--- /dev/null
+++ b/run_gptq.py
@@ -0,0 +1,96 @@
+import argparse
+import json
+import logging
+from typing import Dict, List
+
+import torch
+import transformers
+from transformers import AutoTokenizer
+from transformers.trainer_pt_utils import LabelSmoother
+from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
+
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+
+
+def preprocess(
+    sources,
+    tokenizer: transformers.PreTrainedTokenizer,
+    max_len: int,
+    system_message: str = "You are a helpful assistant."
+) -> List[Dict]:
+    """Tokenize ChatML-style conversations into calibration examples."""
+    roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
+
+    im_start = tokenizer.im_start_id
+    im_end = tokenizer.im_end_id
+    nl_tokens = tokenizer('\n').input_ids
+    _system = tokenizer('system').input_ids + nl_tokens
+    _user = tokenizer('user').input_ids + nl_tokens
+    _assistant = tokenizer('assistant').input_ids + nl_tokens
+
+    # Apply the ChatML prompt template to every conversation.
+    data = []
+    for i, source in enumerate(sources):
+        source = source["conversations"]
+        # Drop a leading non-user turn so every conversation starts with the user.
+        if roles[source[0]["from"]] != roles["user"]:
+            source = source[1:]
+
+        input_id, target = [], []
+        system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
+        input_id += system
+        target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens
+        assert len(input_id) == len(target)
+        for j, sentence in enumerate(source):
+            role = roles[sentence["from"]]
+            _input_id = tokenizer(role).input_ids + nl_tokens + \
+                tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
+            input_id += _input_id
+            if role == '<|im_start|>user':
+                _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens
+            elif role == '<|im_start|>assistant':
+                _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \
+                    _input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens
+            else:
+                raise NotImplementedError
+            target += _target
+        assert len(input_id) == len(target)
+        # The targets are built exactly as in fine-tuning but are not needed for GPTQ
+        # calibration; only input_ids and attention_mask are returned.
+        input_id = torch.tensor(input_id[:max_len], dtype=torch.int)
+        target = torch.tensor(target[:max_len], dtype=torch.int)
+        data.append(dict(input_ids=input_id, attention_mask=input_id.ne(tokenizer.pad_token_id)))
+
+    return data
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Model Quantization using AutoGPTQ")
+    parser.add_argument("--model_name_or_path", type=str, help="model path")
+    parser.add_argument("--data_path", type=str, help="calibration data path")
+    parser.add_argument("--out_path", type=str, help="output path of the quantized model")
+    parser.add_argument("--max_len", type=int, default=8192, help="max length of calibration data")
+    parser.add_argument("--bits", type=int, default=4, help="number of bits of the quantized model; 4 for Int4, 8 for Int8")
+    parser.add_argument("--group-size", type=int, default=128, help="group size of the quantized model")
+    args = parser.parse_args()
+
+    quantize_config = BaseQuantizeConfig(
+        bits=args.bits,
+        group_size=args.group_size,
+        damp_percent=0.01,
+        desc_act=False,  # setting this to False significantly speeds up inference, at a possible slight cost in perplexity
+        static_groups=False,
+        sym=True,
+        true_sequential=True,
+        model_name_or_path=None,
+        model_file_base_name="model"  # save the quantized weights as model.safetensors
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
+    # The Qwen tokenizer has no pad token by default; reuse the end-of-document token.
+    tokenizer.pad_token_id = tokenizer.eod_id
+    data = preprocess(json.load(open(args.data_path)), tokenizer, args.max_len)
+
+    model = AutoGPTQForCausalLM.from_pretrained(args.model_name_or_path, quantize_config, device_map="auto", trust_remote_code=True)
+
+    logging.basicConfig(
+        format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+    )
+    model.quantize(data, cache_examples_on_gpu=False)
+
+    model.save_quantized(args.out_path, use_safetensors=True)
+    tokenizer.save_pretrained(args.out_path)
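As a quick sanity check of the calibration preprocessing, the `preprocess` helper above can be exercised on a tiny in-memory sample before running a full quantization. This sketch assumes `run_gptq.py` is importable from the working directory and uses a placeholder model path:

```python
# Illustrative sanity check: inspect the structure that preprocess() hands to
# model.quantize(). Assumes run_gptq.py is on the Python path; the model path
# is a placeholder.
from transformers import AutoTokenizer

from run_gptq import preprocess

tokenizer = AutoTokenizer.from_pretrained("/path/to/your/finetuned/model", trust_remote_code=True)
tokenizer.pad_token_id = tokenizer.eod_id  # same padding choice as in run_gptq.py

sample = [{
    "conversations": [
        {"from": "user", "value": "Hello"},
        {"from": "assistant", "value": "Hi! How can I help you today?"},
    ]
}]

examples = preprocess(sample, tokenizer, max_len=512)
print(examples[0]["input_ids"].shape, examples[0]["attention_mask"].shape)
```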