add run gptq

main
feihu.hf 1 year ago committed by Ren Xuancheng
parent 65c73034c3
commit ea86f6136a

@@ -723,6 +723,54 @@ tokenizer.save_pretrained(new_model_directory)
Note: For multi-GPU training, you need to specify the proper hyperparameters for distributed training based on your machine. In addition, we advise you to specify your maximum sequence length with the argument `--model_max_length`, based on your data, memory footprint, and training speed.
### Quantize Fine-tuned Models
This section applies to full-parameter/LoRA fine-tuned models. (Note: You do not need to quantize the Q-LoRA fine-tuned model because it is already quantized.)
If you use LoRA, please follow the above instructions to merge your model before quantization.
We recommend using [auto_gptq](https://github.com/PanQiWei/AutoGPTQ) to quantize the fine-tuned model.
```bash
pip install auto-gptq optimum
```
Note: AutoGPTQ currently has a bug, which is tracked in [this issue](https://github.com/PanQiWei/AutoGPTQ/issues/370). A [workaround PR](https://github.com/PanQiWei/AutoGPTQ/pull/495) is available; you can check out that branch and install from source.
First, prepare the calibration data. You can reuse the fine-tuning data, or use other data following the same format.
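If you are preparing new data, the sketch below writes a tiny calibration file in the layout that `run_gptq.py` (added in this commit, shown below) reads: a JSON list of samples, each carrying a `conversations` list of user/assistant turns. The file name and the sample contents here are placeholders.

```python
import json

# Hypothetical file name and contents; point --data_path at whatever file you write.
samples = [
    {
        "conversations": [
            {"from": "user", "value": "What is the capital of France?"},
            {"from": "assistant", "value": "The capital of France is Paris."},
        ]
    },
    # ... add enough samples to cover your target domain
]

with open("calibration_data.json", "w", encoding="utf-8") as f:
    json.dump(samples, f, ensure_ascii=False, indent=2)
```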
Second, run the following script:
```bash
python run_gptq.py \
--model_name_or_path $YOUR_LORA_MODEL_PATH \
--data_path $DATA \
--out_path $OUTPUT_PATH \
--bits 4 # 4 for int4; 8 for int8
```
This step requires GPUs and may take a few hours, depending on your data size and model size. (The script also accepts `--max_len` and `--group-size`; see `run_gptq.py` below.)
Then, copy all `*.py`, `*.cu`, and `*.cpp` files, as well as `generation_config.json`, to the output path. We also recommend overwriting `config.json` with the file from the corresponding official quantized model
(for example, if you are fine-tuning `Qwen-7B-Chat` with `--bits 4`, you can find the `config.json` in [Qwen-7B-Chat-Int4](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4/blob/main/config.json)).
You should also rename ``gptq.safetensors`` to ``model.safetensors``.
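For reference, the following sketch performs the copy-and-rename steps above with Python's standard library. The directory paths are placeholders, and the rename only applies if the quantized weights were saved as ``gptq.safetensors``.

```python
import glob
import os
import shutil

# Placeholder paths: the merged/fine-tuned model and the --out_path used for run_gptq.py.
src = "/path/to/your/finetuned/model"
out = "/path/to/quantized/output"

# Copy the remote-code files and the generation config into the output directory.
for pattern in ("*.py", "*.cu", "*.cpp", "generation_config.json"):
    for path in glob.glob(os.path.join(src, pattern)):
        shutil.copy(path, out)

# Overwrite config.json with the one from the corresponding official quantized
# model (e.g. Qwen-7B-Chat-Int4), downloaded beforehand to a placeholder path.
shutil.copy("/path/to/official/int4/config.json", os.path.join(out, "config.json"))

# Rename the quantized weights if they were saved as gptq.safetensors.
gptq_path = os.path.join(out, "gptq.safetensors")
if os.path.exists(gptq_path):
    os.rename(gptq_path, os.path.join(out, "model.safetensors"))
```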
Finally, test the model with the same method you would use to load the official quantized model. For example:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
tokenizer = AutoTokenizer.from_pretrained("/path/to/your/model", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
"/path/to/your/model",
device_map="auto",
trust_remote_code=True
).eval()
response, history = model.chat(tokenizer, "你好", history=None)
print(response)
```
### Profiling of Memory and Speed
We profile the GPU memory and training speed of both LoRA ("LoRA (emb)" refers to training the embedding and output layers, while plain LoRA has no trainable embedding or output layer) and Q-LoRA in the single-GPU training setup. In this test, we experiment on a single A100-SXM4-80G GPU, with CUDA 11.8 and PyTorch 2.0, and flash attention 2 is applied. We uniformly use a batch size of 1 and gradient accumulation of 8. We profile the memory (GB) and speed (s/iter) for inputs of different lengths, namely 256, 512, 1024, 2048, 4096, and 8192. We also report the statistics of full-parameter fine-tuning with Qwen-7B on 2 A100 GPUs; due to the limitation of GPU memory, we only report the statistics for 256, 512, and 1024 tokens.

@@ -713,6 +713,55 @@ tokenizer.save_pretrained(new_model_directory)
Note: For distributed training, you need to specify the proper distributed-training hyperparameters based on your needs and your machine. In addition, set your maximum sequence length with `--model_max_length`, according to your data, GPU memory, and expected training speed.
### Quantize Fine-tuned Models
This subsection is for quantizing full-parameter/LoRA fine-tuned models. Note that you do not need to quantize a Q-LoRA model, since it is already quantized.
If you want to quantize a LoRA fine-tuned model, please first merge your model weights following the instructions above.
We recommend using [auto_gptq](https://github.com/PanQiWei/AutoGPTQ) to quantize your model.
```bash
pip install auto-gptq optimum
```
Note: AutoGPTQ currently has a bug, which you can check in this [issue](https://github.com/PanQiWei/AutoGPTQ/issues/370). A [workaround PR](https://github.com/PanQiWei/AutoGPTQ/pull/495) is available; you can check out that branch and install from source.
First, prepare the calibration data. You can reuse your fine-tuning data, or prepare other data in the same way as for fine-tuning.
Second, run the following command:
```bash
python run_gptq.py \
--model_name_or_path $YOUR_LORA_MODEL_PATH \
--data_path $DATA \
--out_path $OUTPUT_PATH \
--bits 4 # 4 for int4; 8 for int8
```
This step requires GPUs and may take several hours, depending on the size of your calibration data and your model.
Next, copy all `*.py`, `*.cu`, and `*.cpp` files, as well as `generation_config.json`, from the original model directory to the output model directory. Also, overwrite `config.json` in the output directory with the one from the corresponding official quantized model
(for example, if you fine-tuned `Qwen-7B-Chat` with `--bits 4`, you can find the matching `config.json` in the [Qwen-7B-Chat-Int4](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4/blob/main/config.json) repository).
You also need to rename ``gptq.safetensors`` to ``model.safetensors``.
Finally, test your model with the same method you would use to load the official quantized model. For example:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
tokenizer = AutoTokenizer.from_pretrained("/path/to/your/model", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
"/path/to/your/model",
device_map="auto",
trust_remote_code=True
).eval()
response, history = model.chat(tokenizer, "你好", history=None)
print(response)
```
### GPU Memory Usage and Training Speed
Below we record the GPU memory usage and training speed of the 7B and 14B models when using LoRA ("LoRA (emb)" means the embedding and output layers are trained, while plain LoRA does not optimize these parameters) and Q-LoRA on a single GPU, for inputs of different lengths. The evaluation runs on a single A100-SXM4-80G GPU with CUDA 11.8 and PyTorch 2.0, and uses flash attention 2. We uniformly use a batch size of 1 and gradient accumulation of 8, and record the memory usage (GB) and training speed (s/iter) for input lengths of 256, 512, 1024, 2048, 4096, and 8192. We also measured full-parameter fine-tuning of Qwen-7B on 2 A100 GPUs; limited by GPU memory, we only tested input lengths of 256, 512, and 1024 tokens.

@@ -0,0 +1,96 @@
import argparse
import json
from typing import Dict
import logging
import torch
import transformers
from transformers import AutoTokenizer
from transformers.trainer_pt_utils import LabelSmoother
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
IGNORE_TOKEN_ID = LabelSmoother.ignore_index

def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int,
    system_message: str = "You are a helpful assistant."
) -> Dict:
    roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}

    im_start = tokenizer.im_start_id
    im_end = tokenizer.im_end_id
    nl_tokens = tokenizer('\n').input_ids
    _system = tokenizer('system').input_ids + nl_tokens
    _user = tokenizer('user').input_ids + nl_tokens
    _assistant = tokenizer('assistant').input_ids + nl_tokens

    # Apply prompt templates
    data = []
    # input_ids, targets = [], []
    for i, source in enumerate(sources):
        source = source["conversations"]
        if roles[source[0]["from"]] != roles["user"]:
            source = source[1:]

        input_id, target = [], []
        system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
        input_id += system
        target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens
        assert len(input_id) == len(target)
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            _input_id = tokenizer(role).input_ids + nl_tokens + \
                tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
            input_id += _input_id
            if role == '<|im_start|>user':
                _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens
            elif role == '<|im_start|>assistant':
                _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \
                    _input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens
            else:
                raise NotImplementedError
            target += _target
        assert len(input_id) == len(target)
        input_id = torch.tensor(input_id[:max_len], dtype=torch.int)
        target = torch.tensor(target[:max_len], dtype=torch.int)
        data.append(dict(input_ids=input_id, attention_mask=input_id.ne(tokenizer.pad_token_id)))

    return data

if __name__ == "__main__":
    parser = argparse.ArgumentParser("Model Quantization using AutoGPTQ")
    parser.add_argument("--model_name_or_path", type=str, help="model path")
    parser.add_argument("--data_path", type=str, help="calibration data path")
    parser.add_argument("--out_path", type=str, help="output path of the quantized model")
    parser.add_argument("--max_len", type=int, default=8192, help="max length of calibration data")
    parser.add_argument("--bits", type=int, default=4, help="the bits of the quantized model. 4 indicates int4 models.")
    parser.add_argument("--group-size", type=int, default=128, help="the group size of the quantized model")
    args = parser.parse_args()

    quantize_config = BaseQuantizeConfig(
        bits=args.bits,
        group_size=args.group_size,
        damp_percent=0.01,
        desc_act=False,  # setting this to False can significantly speed up inference, but perplexity may be slightly worse
        static_groups=False,
        sym=True,
        true_sequential=True,
        model_name_or_path=None,
        model_file_base_name="model"
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    tokenizer.pad_token_id = tokenizer.eod_id
    data = preprocess(json.load(open(args.data_path)), tokenizer, args.max_len)

    model = AutoGPTQForCausalLM.from_pretrained(args.model_name_or_path, quantize_config, device_map="auto", trust_remote_code=True)

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
    )
    model.quantize(data, cache_examples_on_gpu=False)

    model.save_quantized(args.out_path, use_safetensors=True)
    tokenizer.save_pretrained(args.out_path)