add finetuning

main
JustinLin610 1 year ago
parent 7ba36aad16
commit af22d5e0ce

@ -38,18 +38,18 @@ The following sections include information that you might find it helpful. Speci
In general, Qwen-7B outperforms the baseline models of a similar model size, and even outperforms larger models of around 13B parameters, on a series of benchmark datasets, e.g., MMLU, C-Eval, GSM8K, HumanEval, and WMT22, CMMLU, etc., which evaluate the models' capabilities on natural language understanding, mathematic problem solving, coding, etc. See the results below.
| Model | MMLU | C-Eval | GSM8K | HumanEval | WMT22 (en-zh) | CMMLU |
| :------------- | :--------: | :--------: | :--------: | :---------: | :-------------: | :--------: |
| LLaMA-7B | 35.1 | - | 11.0 | 10.5 | 8.7 | - |
| LLaMA 2-7B | 45.3 | - | 14.6 | 12.8 | 17.9 | - |
| Baichuan-7B | 42.3 | 42.8 | 9.7 | 9.2 | 26.6 | 44.4 |
| ChatGLM2-6B | 47.9 | 51.7 | 32.4 | 9.2 | - | 48.8 |
| InternLM-7B | 51.0 | 52.8 | 31.2 | 10.4 | 14.8 | - |
| Baichuan-13B | 51.6 | 53.6 | 26.6 | 12.8 | 30.0 | 55.8 |
| LLaMA-13B | 46.9 | 35.5 | 17.8 | 15.8 | 12.0 | - |
| LLaMA 2-13B | 54.8 | - | 28.7 | 18.3 | 24.2 | - |
| ChatGLM2-12B | 56.2 | **61.6** | 40.9 | - | - | - |
| **Qwen-7B** | **56.7** | 59.6 | **51.6** | **24.4** | **30.6** | **58.8** |
Ω| Model | MMLU | C-Eval | GSM8K | HumanEval | WMT22 (en-zh) | CMMLU |
|:------------------|:--------:|:--------:|:--------:|:---------:|:-------------:|:--------:|
| LLaMA-7B | 35.1 | - | 11.0 | 10.5 | 8.7 | - |
| LLaMA 2-7B | 45.3 | - | 14.6 | 12.8 | 17.9 | - |
| Baichuan-7B | 42.3 | 42.8 | 9.7 | 9.2 | 26.6 | 44.4 |
| ChatGLM2-6B | 47.9 | 51.7 | 32.4 | 9.2 | - | 48.8 |
| InternLM-7B | 51.0 | 52.8 | 31.2 | 10.4 | 14.8 | - |
| Baichuan-13B | 51.6 | 53.6 | 26.6 | 12.8 | 30.0 | 55.8 |
| LLaMA-13B | 46.9 | 35.5 | 17.8 | 15.8 | 12.0 | - |
| LLaMA 2-13B | 54.8 | - | 28.7 | 18.3 | 24.2 | - |
| ChatGLM2-12B | 56.2 | **61.6** | 40.9 | - | - | - |
| **Qwen-7B** | **56.7** | 59.6 | **51.6** | **24.4** | **30.6** | **58.8** |
<p align="center">
<img src="assets/performance.png" width="1000"/>
@ -65,6 +65,7 @@ For more experimental results (detailed model performance on more benchmark data
* python 3.8 and above
* pytorch 1.12 and above, 2.0 and above are recommended
* transformers 4.32 and above
* CUDA 11.4 and above are recommended (this is for GPU users, flash-attention users, etc.)
<br>
@ -237,16 +238,16 @@ response, history = model.chat(tokenizer, "Hi", history=None)
We illustrate the model performance of both BF16 and Int4 models on the benchmark, and we find that the quantized model does not suffer from significant performance degradation. Results are shown below:
| Quantization | MMLU | CEval (val) | GSM8K | Humaneval |
| -------------- | :----: | :-----------: | :-----: | :---------: |
| BF16 | 53.9 | 54.2 | 41.1 | 24.4 |
| Int4 | 52.6 | 52.9 | 38.1 | 23.8 |
|--------------|:----:|:-----------:|:-----:|:---------:|
| BF16 | 53.9 | 54.2 | 41.1 | 24.4 |
| Int4 | 52.6 | 52.9 | 38.1 | 23.8 |
### Inference Speed
We measured the average inference speed (tokens/s) of generating 2048 and 8192 tokens under BF16 precision and Int4 quantization, respectively.
| Quantization | Speed (2048 tokens) | Speed (8192 tokens) |
| -------------- | :-------------------: | :-------------------: |
|--------------|:-------------------:|:-------------------:|
| BF16 | 30.34 | 29.32 |
| Int4 | 43.56 | 33.92 |
@ -257,13 +258,87 @@ In detail, the setting of profiling is generating 8192 new tokens with 1 context
We also profile the peak GPU memory usage for encoding 2048 tokens as context (and generating single token) and generating 8192 tokens (with single token as context) under BF16 or Int4 quantization level, respectively. The results are shown below.
| Quantization | Peak Usage for Encoding 2048 Tokens | Peak Usage for Generating 8192 Tokens |
| -------------- | :-----------------------------------: | :-------------------------------------: |
|--------------|:-----------------------------------:|:-------------------------------------:|
| BF16 | 17.66GB | 22.58GB |
| Int4 | 8.21GB | 13.62GB |
The above speed and memory profiling are conducted using [this script](https://qianwen-res.oss-cn-beijing.aliyuncs.com/profile.py).
<br>
## Finetuning
Now we provide the official training script, `finetune.py`, for users to finetune the pretrained model for downstream applications in a simple fashion. Additionally, we provide shell scripts to launch finetuning with no worries. This script supports the training with [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [FSDP](https://engineering.fb.com/2021/07/15/open-source/fsdp/). The shell scripts that we provide use DeepSpeed, and thus we advise you to install DeepSpeed before you start.
To prepare your training data, you need to put all the samples into a list and save it to a json file. Each sample is a dictionary consisting of an id and a list for conversation. Below is a simple example list with 1 sample:
```json
[
{
"id": "identity_0",
"conversations": [
{
"from": "user",
"value": "你好",
},
{
"from": "assistant",
"value": "我是一个语言模型,我叫通义千问。"
}
]
}
]
```
After data preparation, you can use the provided shell scripts to run finetuning. Remember to specify the path to the data file, `$DATA`.
The finetuning scripts allow you to perform:
- Full-parameter finetuning
- LoRA
- Q-LoRA
Full-parameter parameter finetuning requires updating all parameters in the whole training process. To launch your training, run the following script:
```bash
# Distributed training. We do not provide single-GPU training script as the insufficient GPU memory will break down the training.
sh finetune/finetune_ds.sh
```
Remember to specify the correct model name or path, the data path, as well as the output directory in the shell scripts. Another thing to notice is that we use DeepSpeed ZeRO 3 in this script. If you want to make changes, just remove the argument `--deepspeed` or make changes in the DeepSpeed configuration json file based on your requirements. Additionally, this script supports mixed-precision training, and thus you can use `--bf16 True` or `--fp16 True`. Empirically we advise you to use bf16 to make your training consistent with our pretraining and alignment if your machine supports bf16, and thus we use it by default.
Similarly, to run LoRA, use another script to run as shown below. Before you start, make sure that you have installed `peft`. Also, you need to specify your paths to your model, data, and output. We advise you to use absolute path for your pretrained model. This is because LoRA only saves the adapter and the absolute path in the adapter configuration json file is used for finding out the pretrained model to load. Also, this script support both bf16 and fp16.
```bash
# Single GPU training
sh finetune/finetune_lora_single_gpu.sh
# Distributed training
sh finetune/finetune_lora_ds.sh
```
In comparison with full-parameter finetuning, LoRA ([paper](https://arxiv.org/abs/2106.09685)) only updates the parameters of adapter layers but keeps the original large language model layers frozen. This allows much fewer memory costs and thus fewer computation costs. However, if you still suffer from insufficient memory, you can consider Q-LoRA ([paper](https://arxiv.org/abs/2305.14314)), which uses the quantized large language model and other techniques such as paged attention to allow even fewer memory costs. To run Q-LoRA, directly run the following script:
```bash
# Single GPU training
sh finetune/finetune_qlora_single_gpu.sh
# Distributed training
sh finetune/finetune_qlora_ds.sh
```
For Q-LoRA, we advise you to load our provided quantized model, e.g., Qwen-7B-Chat-Int4. However, different from full-parameter finetuning and LoRA, only fp16 is supported for Q-LoRA.
Different from full-parameter finetuning, the training of both LoRA and Q-LoRA only saves the adapter parameters. Suppose your training starts from Qwen-7B, you can load the finetuned model for inference as shown below:
```python
from peft import AutoPeftModelForCausalLM
model = AutoPeftModelForCausalLM.from_pretrained(
path_to_adapter, # path to the output directory
device_map="auto",
trust_remote_code=True
).eval()
```
The shell scripts uses `torchrun` to run single-GPU or multi-GPU training. For multi-GPU training, you need to specify the proper hyperparameters for distributed training based on your machine.
## Demo
### Web UI
@ -379,21 +454,21 @@ Then you can run the 7B chat model on 2 GPUs using the above scripts.
Qwen-7B-Chat is specifically optimized for tool usage, including API, database, models, etc., so that users can build their own Qwen-7B-based LangChain, Agent, and Code Interpreter. In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool usage capabilities, we find that Qwen-7B reaches stable performance.
| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
|:-----------------| :-----------------------: | :----------------------: | :----------------------: |
| GPT-4 | 95% | **0.90** | 15% |
| GPT-3.5 | 85% | 0.88 | 75% |
| **Qwen-7B-Chat** | **99%** | 0.89 | **9.7%** |
|:-----------------|:----------------------:|:---------------------:|:---------------------:|
| GPT-4 | 95% | **0.90** | 15% |
| GPT-3.5 | 85% | 0.88 | 75% |
| **Qwen-7B-Chat** | **99%** | 0.89 | **9.7%** |
For how to write and use prompts for ReAct Prompting, please refer to [the ReAct examples](examples/react_prompt.md). The use of tools can enable the model to better perform tasks.
Additionally, we provide experimental results to show its capabilities of playing as an agent. See [Hugging Face Agent](https://huggingface.co/docs/transformers/transformers_agents) for more information. Its performance on the run-mode benchmark provided by Hugging Face is as follows:
| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
|:-----------------| :----------------: | :-----------: | :---------: |
| GPT-4 | **100** | **100** | **97.41** |
| GPT-3.5 | 95.37 | 96.30 | 87.04 |
| StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
| **Qwen-7B-Chat** | 90.74 | 92.59 | 74.07 |
| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
|:-----------------|:---------------:|:----------:|:---------:|
| GPT-4 | **100** | **100** | **97.41** |
| GPT-3.5 | 95.37 | 96.30 | 87.04 |
| StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
| **Qwen-7B-Chat** | 90.74 | 92.59 | 74.07 |
<br>

@ -259,19 +259,94 @@ response, history = model.chat(tokenizer, "Hi", history=None)
上述性能测算使用[此脚本](https://qianwen-res.oss-cn-beijing.aliyuncs.com/profile.py)完成。
<br>
## 微调
我们提供了`finetune.py`这个脚本供用户实现在自己的数据上进行微调的功能以接入下游任务。此外我们还提供了shell脚本减少用户的工作量。这个脚本支持 [DeepSpeed](https://github.com/microsoft/DeepSpeed) 和 [FSDP](https://engineering.fb.com/2021/07/15/open-source/fsdp/) 。我们提供的shell脚本使用了DeepSpeed因此建议您确保已经安装DeepSpeed。
首先你需要准备你的训练数据。你需要将所有样本放到一个列表中并存入json文件中。每个样本对应一个字典包含id和conversation其中后者为一个列表。示例如下所示
```json
[
{
"id": "identity_0",
"conversations": [
{
"from": "user",
"value": "你好",
},
{
"from": "assistant",
"value": "我是一个语言模型,我叫通义千问。"
}
]
}
]
```
准备好数据后你可以使用我们提供的shell脚本实现微调。注意你需要在脚本中指定你的数据的路径。
微调脚本能够帮你实现:
- 全参数微调
- LoRA
- Q-LoRA
全参数微调在训练过程中更新所有参数。你可以运行这个脚本开始训练:
```bash
# 分布式训练。由于显存限制将导致单卡训练失败,我们不提供单卡训练脚本。
sh finetune/finetune_ds.sh
```
尤其注意你需要在脚本中指定正确的模型名称或路径、数据路径、以及模型输出的文件夹路径。在这个脚本中我们使用了DeepSpeed ZeRO 3。如果你想修改这个配置可以删除掉`--deepspeed`这个输入或者自行根据需求修改DeepSpeed配置json文件。此外我们支持混合精度训练因此你可以设置`--bf16 True`或者`--fp16 True`。经验上如果你的机器支持bf16我们建议使用bf16这样可以和我们的预训练和对齐训练保持一致这也是为什么我们把默认配置设为它的原因。
运行LoRA的方法类似全参数微调。但在开始前请确保已经安装`peft`代码库。另外记住要设置正确的模型、数据和输出路径。我们建议你为模型路径使用绝对路径。这是因为LoRA仅存储adapter部分参数而adapter配置json文件记录了预训练模型的路径用于读取预训练模型权重。同样你可以设置bf16或者fp16。
```bash
# 单卡训练
sh finetune/finetune_lora_single_gpu.sh
# 分布式训练
sh finetune/finetune_lora_ds.sh
```
与全参数微调不同LoRA ([论文](https://arxiv.org/abs/2106.09685)) 只更新adapter层的参数而无需更新原有语言模型的参数。这种方法允许用户用更低的显存开销来训练模型也意味着更小的计算开销。然而如果你依然遇到显存不足的问题可以考虑使用Q-LoRA ([论文](https://arxiv.org/abs/2305.14314))。该方法使用4比特量化模型以及paged attention等技术实现更小的显存开销。运行Q-LoRA你只需运行如下脚本
```bash
# 单卡训练
sh finetune/finetune_qlora_single_gpu.sh
# 分布式训练
sh finetune/finetune_qlora_ds.sh
```
我们建议你使用我们提供的Int4量化模型进行训练即Qwen-7B-Chat-Int4。然而与全参数微调以及LoRA不同Q-LoRA仅支持fp16。
与全参数微调不同LoRA和Q-LoRA的训练只需存储adapter部分的参数。假如你需要使用LoRA训练后的模型你需要使用如下方法。假设你使用Qwen-7B训练模型你可以用如下代码读取模型
```python
from peft import AutoPeftModelForCausalLM
model = AutoPeftModelForCausalLM.from_pretrained(
path_to_adapter, # path to the output directory
device_map="auto",
trust_remote_code=True
).eval()
```
上述shell脚本使用`torchrun`来运行单GPU和多GPU训练。分布式训练需要根据你的需求和机器指定正确的分布式训练超参数。
## Demo
### Web UI
我们提供了Web UI的demo供用户使用 (感谢 @wysaid 支持)。在开始前,确保已经安装如下代码库:
```
```bash
pip install -r requirements_web_demo.txt
```
随后运行如下命令,并点击生成链接:
```
```bash
python web_demo.py
```

@ -263,6 +263,80 @@ BF16 の精度と Int4 の量子化レベルの下で、それぞれ 2048 個と
上記のスピードとメモリーのプロファイリングは、[このスクリプト](https://qianwen-res.oss-cn-beijing.aliyuncs.com/profile.py)を使用しています。
<br>
## ファインチューニング
現在、公式のトレーニングスクリプト `finetune.py` を提供しています。さらに、finetune.pyのシェルスクリプトを提供し、finetune.pyを実行することで、finetune.pyを起動することができる。さらに、安心してファインチューニングを開始するためのシェルスクリプトも提供しています。このスクリプトは、[DeepSpeed](https://github.com/microsoft/DeepSpeed) および [FSDP](https://engineering.fb.com/2021/07/15/open-source/fsdp/) を使用したトレーニングをサポートします。弊社が提供するシェル・スクリプトは DeepSpeed を使用するため、事前に DeepSpeed をインストールすることをお勧めします:
学習データを準備するには、すべてのサンプルをリストにまとめ、jsonファイルに保存する必要があります。各サンプルはidと会話リストで構成される辞書です。以下は1つのサンプルを含む単純なリストの例です
```json
[
{
"id": "identity_0",
"conversations": [
{
"from": "user",
"value": "你好",
},
{
"from": "assistant",
"value": "我是一个语言模型,我叫通义千问。"
}
]
}
]
```
データ準備の後、提供されているシェルスクリプトを使って微調整を実行することができる。データファイルのパス `$DATA` を忘れずに指定してください。
ファインチューニングのスクリプトを使用することで、以下のことが可能になる:
- フルパラメーター・ファインチューニング
- LoRA
- Q-LoRA
フルパラメータパラメータのファインチューニングを行うには、トレーニングプロセス全体ですべてのパラメータを更新する必要があります。トレーニングを開始するには、以下のスクリプトを実行します:
```bash
# 分散トレーニング。GPUメモリが不足するとトレーニングが破綻するため、シングルGPUのトレーニングスクリプトは提供していません。
sh finetune/finetune_ds.sh
```
シェルスクリプトでは、正しいモデル名またはパス、データパス、出力ディレクトリを指定することを忘れないでください。このスクリプトでは DeepSpeed ZeRO 3 を使用しています。変更したい場合は、引数 `--deepspeed` を削除するか、要件に基づいて DeepSpeed 設定 json ファイルを変更してください。さらに、このスクリプトは混合精度のトレーニングに対応しており、`--bf16 True` または `--fp16 True` を使用することができます。経験的に、あなたのマシンがbf16をサポートしている場合、私たちのプリトレーニングとアライメントを整合させるためにbf16を使用することをお勧めします。
同様に、LoRAを実行するには、以下のように別のスクリプトを使って実行する。始める前に、`peft`がインストールされていることを確認してください。また、モデル、データ、出力へのパスを指定する必要があります。学習済みモデルには絶対パスを使用することをお勧めします。なぜなら、LoRAはアダプタのみを保存し、アダプタ設定jsonファイルの絶対パスは、ロードする事前学習済みモデルを見つけるために使用されるからです。また、このスクリプトはbf16とfp16の両方をサポートしている。
```bash
# シングルGPUトレーニング
sh finetune/finetune_lora_single_gpu.sh
# 分散トレーニング
sh finetune/finetune_lora_ds.sh
```
LoRA ([論文](https://arxiv.org/abs/2106.09685)) は、フルパラメーターによるファインチューニングと比較して、adapterのパラメーターを更新するだけで、元の大きな言語モデル層は凍結されたままである。そのため、メモリコストが大幅に削減でき、計算コストも削減できる。しかし、それでもメモリ不足に悩む場合は、Q-LoRA[論文](https://arxiv.org/abs/2305.14314)を検討することができます。これは、量子化されたラージ言語モデルと、ページド・アテンションなどの他のテクニックを使用し、さらに少ないメモリコストで実行することができます。Q-LoRAを実行するには、以下のスクリプトを直接実行してください
```bash
# シングルGPUトレーニング
sh finetune/finetune_qlora_single_gpu.sh
# 分散トレーニング
sh finetune/finetune_qlora_ds.sh
```
Q-LoRAについては、弊社が提供する量子化モデル、例えばQwen-7B-Chat-Int4をロードすることをお勧めします。ただし、フルパラメータ・ファインチューニングやLoRAとは異なり、Q-LoRAではfp16のみがサポートされる。
LoRAとQ-LoRAの学習は、フルパラメータによるファインチューニングとは異なり、アダプターパラメータのみを保存する。仮にQwen-7Bから学習を開始したとすると、以下のようにファインチューニングされたモデルを読み込んで推論を行うことができる
```python
from peft import AutoPeftModelForCausalLM
model = AutoPeftModelForCausalLM.from_pretrained(
path_to_adapter, # path to the output directory
device_map="auto",
trust_remote_code=True
).eval()
```
シェルスクリプトは`torchrun`を使用してシングルGPUまたはマルチGPUトレーニングを実行します。そのため、分散トレーニングのための適切なハイパーパラメータをマシンに応じて指定する必要があります。
## デモ
### ウェブ UI

@ -0,0 +1,353 @@
# This code is based on the revised code from fastchat based on tatsu-lab/stanford_alpaca.
from dataclasses import dataclass, field
import json
import math
import logging
import os
from typing import Dict, Optional, List
import torch
from torch.utils.data import Dataset
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
import transformers
from transformers import Trainer, GPTQConfig, deepspeed
from transformers.trainer_pt_utils import LabelSmoother
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="Qwen/Qwen-7B")
@dataclass
class DataArguments:
data_path: str = field(
default=None, metadata={"help": "Path to the training data."}
)
eval_data_path: str = field(
default=None, metadata={"help": "Path to the evaluation data."}
)
lazy_preprocess: bool = False
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=8192,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
use_lora: bool = False
@dataclass
class LoraArguments:
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_target_modules: List[str] = field(
default_factory=lambda: ["c_attn", "c_proj", "w1", "w2"]
)
lora_weight_path: str = ""
lora_bias: str = "none"
q_lora: bool = False
def maybe_zero_3(param):
if hasattr(param, "ds_id"):
assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
for k, t in maybe_lora_bias:
if bias_name in lora_bias_names:
to_return[bias_name] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
return to_return
local_rank = None
def rank0_print(*args):
if local_rank == 0:
print(*args)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"):
"""Collects the state dict and dump to disk."""
# check if zero3 mode enabled
if deepspeed.is_deepspeed_zero3_enabled():
state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
else:
if trainer.args.use_lora:
state_dict = get_peft_state_maybe_zero_3(
trainer.model.named_parameters(), bias
)
else:
state_dict = trainer.model.state_dict()
if trainer.args.should_save and trainer.args.local_rank == 0:
trainer._save(output_dir, state_dict=state_dict)
def preprocess(
sources,
tokenizer: transformers.PreTrainedTokenizer,
max_len: int,
system_message: str = "You are a helpful assistant."
) -> Dict:
roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
im_start = tokenizer.im_start_id
im_end = tokenizer.im_end_id
nl_tokens = tokenizer('\n').input_ids
_system = tokenizer('system').input_ids + nl_tokens
_user = tokenizer('user').input_ids + nl_tokens
_assistant = tokenizer('assistant').input_ids + nl_tokens
# Apply prompt templates
input_ids, targets = [], []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != roles["user"]:
source = source[1:]
input_id, target = [], []
system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
input_id += system
target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens
assert len(input_id) == len(target)
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
_input_id = tokenizer(role).input_ids + nl_tokens + \
tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
input_id += _input_id
if role == '<|im_start|>user':
_target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens
elif role == '<|im_start|>assistant':
_target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \
_input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens
else:
raise NotImplementedError
target += _target
assert len(input_id) == len(target)
input_id += [tokenizer.pad_token_id] * (max_len - len(input_id))
target += [IGNORE_TOKEN_ID] * (max_len - len(target))
input_ids.append(input_id[:max_len])
targets.append(target[:max_len])
input_ids = torch.tensor(input_ids, dtype=torch.int)
targets = torch.tensor(targets, dtype=torch.int)
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int):
super(SupervisedDataset, self).__init__()
rank0_print("Formatting inputs...")
sources = [example["conversations"] for example in raw_data]
data_dict = preprocess(sources, tokenizer, max_len)
self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
self.attention_mask = data_dict["attention_mask"]
def __len__(self):
return len(self.input_ids)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(
input_ids=self.input_ids[i],
labels=self.labels[i],
attention_mask=self.attention_mask[i],
)
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int):
super(LazySupervisedDataset, self).__init__()
self.tokenizer = tokenizer
self.max_len = max_len
rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.raw_data = raw_data
self.cached_data_dict = {}
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
if i in self.cached_data_dict:
return self.cached_data_dict[i]
ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len)
ret = dict(
input_ids=ret["input_ids"][0],
labels=ret["labels"][0],
attention_mask=ret["attention_mask"][0],
)
self.cached_data_dict[i] = ret
return ret
def make_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer, data_args, max_len,
) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
dataset_cls = (
LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
)
rank0_print("Loading data...")
train_json = json.load(open(data_args.data_path, "r"))
train_dataset = dataset_cls(train_json, tokenizer=tokenizer, max_len=max_len)
if data_args.eval_data_path:
eval_json = json.load(open(data_args.eval_data_path, "r"))
eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer, max_len=max_len)
else:
eval_dataset = None
return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
def train():
global local_rank
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments, LoraArguments)
)
(
model_args,
data_args,
training_args,
lora_args,
) = parser.parse_args_into_dataclasses()
compute_dtype = (
torch.float16
if training_args.fp16
else (torch.bfloat16 if training_args.bf16 else torch.float32)
)
local_rank = training_args.local_rank
device_map = None
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if lora_args.q_lora:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
logging.warning(
"FSDP or ZeRO3 are not incompatible with QLoRA."
)
# Set RoPE scaling factor
config = transformers.AutoConfig.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
trust_remote_code=True,
)
config.use_cache = False
# Load model and tokenizer
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
cache_dir=training_args.cache_dir,
device_map=device_map,
trust_remote_code=True,
quantization_config=GPTQConfig(
bits=4, disable_exllama=True
)
if training_args.use_lora and lora_args.q_lora
else None,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
trust_remote_code=True,
)
tokenizer.pad_token_id = tokenizer.eod_id
if training_args.use_lora:
lora_config = LoraConfig(
r=lora_args.lora_r,
lora_alpha=lora_args.lora_alpha,
target_modules=lora_args.lora_target_modules,
lora_dropout=lora_args.lora_dropout,
bias=lora_args.lora_bias,
task_type="CAUSAL_LM",
modules_to_save=["wte", "lm_head"] # This argument serves for adding new tokens.
)
if lora_args.q_lora:
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=training_args.gradient_checkpointing
)
model = get_peft_model(model, lora_config)
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
# Load data
data_module = make_supervised_data_module(
tokenizer=tokenizer, data_args=data_args, max_len=training_args.model_max_length
)
# Start trainner
trainer = Trainer(
model=model, tokenizer=tokenizer, args=training_args, **data_module
)
trainer.train()
trainer.save_state()
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias)
if __name__ == "__main__":
train()

@ -0,0 +1,52 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

@ -0,0 +1,59 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"offload_param": {
"device": "none",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

@ -0,0 +1,47 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
DIR=`pwd`
GPUS_PER_NODE=8
NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001
MODEL="Qwen/Qwen-7B" # Set the path if you do not want to load from huggingface directly
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="path_to_data"
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
torchrun $DISTRIBUTED_ARGS finetune.py \
--model_name_or_path $MODEL \
--data_path $DATA \
--bf16 True \
--output_dir output_qwen \
--num_train_epochs 5 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-5 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--report_to "none" \
--model_max_length 2048 \
--gradient_checkpointing True \
--lazy_preprocess True \
--deepspeed finetune/ds_config_zero3.json

@ -0,0 +1,47 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
DIR=`pwd`
GPUS_PER_NODE=8
NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001
MODEL="Qwen/Qwen-7B" # Set the path if you do not want to load from huggingface directly
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="path_to_data"
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
torchrun $DISTRIBUTED_ARGS finetune.py \
--model_name_or_path $MODEL \
--data_path $DATA \
--bf16 True \
--output_dir output_qwen \
--num_train_epochs 5 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-5 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--report_to "none" \
--model_max_length 2048 \
--lazy_preprocess True \
--use_lora \
--deepspeed finetune/ds_config_zero2.json

@ -0,0 +1,35 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
DIR=`pwd`
MODEL="Qwen/Qwen-7B" # Set the path if you do not want to load from huggingface directly
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="path_to_data"
export CUDA_VISIBLE_DEVICES=0
python finetune.py \
--model_name_or_path $MODEL \
--data_path $DATA \
--bf16 True \
--output_dir output_qwen \
--num_train_epochs 5 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-5 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--report_to "none" \
--model_max_length 2048 \
--lazy_preprocess True \
--use_lora

@ -0,0 +1,49 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
DIR=`pwd`
GPUS_PER_NODE=8
NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001
MODEL="Qwen/Qwen-7B-Chat-Int4" # Set the path if you do not want to load from huggingface directly
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="path_to_data"
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
# Remember to use --fp16 instead of --bf16 due to autogptq
torchrun $DISTRIBUTED_ARGS finetune.py \
--model_name_or_path $MODEL \
--data_path $DATA \
--fp16 True \
--output_dir output_qwen \
--num_train_epochs 5 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-5 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--report_to "none" \
--model_max_length 2048 \
--lazy_preprocess True \
--use_lora \
--q_lora \
--deepspeed finetune/ds_config_zero2.json

@ -0,0 +1,36 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
DIR=`pwd`
MODEL="Qwen/Qwen-7B-Chat-Int4" # Set the path if you do not want to load from huggingface directly
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="path_to_data"
export CUDA_VISIBLE_DEVICES=0
# Remember to use --fp16 instead of --bf16 due to autogptq
python finetune.py \
--model_name_or_path $MODEL \
--data_path $DATA \
--fp16 True \
--output_dir output_qwen \
--num_train_epochs 5 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-5 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--report_to "none" \
--model_max_length 2048 \
--lazy_preprocess True \
--use_lora \
--q_lora

@ -3,4 +3,6 @@ accelerate
tiktoken
einops
transformers_stream_generator==0.0.4
scipy
scipy
peft
deepspeed
Loading…
Cancel
Save