You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

434 lines
19 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "0e7993c3-3999-4ac5-b1dc-77875d80e4c8",
"metadata": {},
"source": [
"# Fine-tuning Qwen-7B-Chat on Your Own Domain-Specific Data\n",
"\n",
"This notebook uses Qwen-7B-Chat as an example to introduce how to LoRA fine-tune the Qwen model on a specific domain.\n",
"\n",
"Qwen is a pretrained conversational model that supports English and Chinese. It is suitable for universal scenarios, but may lack some specialized knowledge in certain specific domain. If you would like to fine-tune it for a specific domain, or on your own private dataset, you can refer to this tutorial.\n",
"\n",
"Here is an example showing the differences before and after fine-tuning.\n",
"\n",
"#### Model Input\n",
">怎样治疗青光眼?\n",
"\n",
"#### Before Fine-tuning\n",
"\n",
">青光眼的治疗方法取决于病情严重程度。轻度青光眼可以采取药物治疗,比如激素类药物、抗生素等,以减轻炎症和改善血管紧张。\n",
">\n",
">中度青光眼可能需要激光手术来降低眼压,并使用药物以保持控制。\n",
">\n",
">重度青光眼则需要实施传统的眼部手术,如小梁切除术或巩膜外减压术,以防止视神经损伤。\n",
"\n",
"#### After Fine-tuning\n",
"\n",
">青光眼是一种眼睛疾病,可以引起视神经的损伤和视力下降。治疗青光眼的主要方法是药物治疗、激光治疗和手术治疗。\n",
">1. 药物治疗:对于早期青光眼患者,通常采用抗青光眼药物治疗,通过降低眼压来预防视神经进一步损伤。目前常用的抗青光眼药物有拉坦前列素滴眼液、阿托品滴眼液等。\n",
">2. 激光治疗对于中晚期青光眼患者如果药物治疗效果不佳可能需要进行激光治疗。目前常用的激光治疗方法包括小梁切除术、YAG激光手术等。\n",
">3. 手术治疗:对于非常严重的青光眼患者,可能需要进行手术治疗。目前常用的手术治疗方法包括前房引流术、虹膜周切术、小梁切除术、巩膜滤过术等。\n",
">\n",
">需要注意的是,以上治疗方法的具体选择应根据患者的病情和医生的专业建议来进行。同时,青光眼患者在日常生活中也需要保持良好的生活习惯,避免过度疲劳、保持情绪稳定、定期检查眼睛等情况的发生。"
]
},
{
"cell_type": "markdown",
"id": "bdea7e21-fec8-49fe-b7ea-afde3f02738f",
"metadata": {},
"source": [
"## Environment Requirements\n",
"\n",
"Please refer to **requirements.txt** to install the required dependencies.\n",
"\n",
"Run the following command line in the main directory of the Qwen repo.\n",
"```bash\n",
"pip install -r requirements.txt\n",
"```\n",
"\n",
"\n",
"## Preparation\n",
"\n",
"### Download Qwen-7B-Chat\n",
"\n",
"First, download the model files. You can choose to download directly from ModelScope."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "248488f9-4a86-4f35-9d56-50f8e91a8f11",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"from modelscope.hub.snapshot_download import snapshot_download\n",
"model_dir = snapshot_download('Qwen/Qwen-7B-chat', cache_dir='.')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "7b2a92b1-f08e-4413-9f92-8f23761e6e1f",
"metadata": {},
"source": [
"### Download Medical Training Data\n",
"\n",
"Download the data required for training; here, we provide a medical conversation dataset for training. It is sampled from [MedicalGPT repo](https://github.com/shibing624/MedicalGPT/) and we have converted this dataset into a format that can be used for fine-tuning.\n",
"\n",
"Disclaimer: the dataset can be only used for the research purpose."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce195f08-fbb2-470e-b6c0-9a03457458c7",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"!wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/qwen_recipes/medical_sft.json"
]
},
{
"cell_type": "markdown",
"id": "7226bed0-171b-4d45-a3f9-b3d81ec2bb9f",
"metadata": {},
"source": [
"You can prepare your dataset in JSON format following the format below, and then modify the `--data_path` parameter in the training command to point to your JSON file.\n",
"\n",
"These data instances can be conversations in the real world or include domain knowledge QA pairs. Besides, fine-tuning allows Qwen-chat to play like some specific roles. As Qwen-chat is a dialogue model for general scenarios, your fine-tuning can customize a chatbot to meet your requirements.\n",
"\n",
"We recommend that you prepare 50~ data instances if you want to fine-tune Qwen-chat as a roleplay model.\n",
"\n",
"You may prepare much more data instances if you want to infuse the domain knowledge of your field into the model.\n",
"\n",
"In this tutorial, we have prepared a medical domain fine-tuning dataset consisting of 1000 data instancess as an example. You can refer to our example to fine-tune on your own domain-specific dataset.\n",
"\n",
"Below is a simple example list with 1 sample:\n",
"\n",
"```python\n",
"[\n",
" {\n",
" \"id\": \"1\",\n",
" \"conversations\": [\n",
" {\n",
" \"from\": \"user\",\n",
" \"value\": \"直肠腺瘤的早期症状?\"\n",
" },\n",
" {\n",
" \"from\": \"assistant\",\n",
" \"value\": \"结直肠腺瘤是起源于结直肠黏膜腺上皮的良性肿瘤包括结肠腺瘤与直肠腺瘤是常见的肠道良性肿瘤。因与大肠癌的发生关系密切被认为是一种癌前病变。不同地区、不同年龄的发病率差别很大40岁以下的发病率低60岁以上较高男女无明显差别。此类疾病的病因及发病机制不明确。部分患者有遗传因素存在另外高脂肪饮食与食物纤维不足、肠道菌群紊乱也与此病的发生有相关性。多数患者无症状。症状与腺瘤的大小、部位、病理性质等有关。有症状者可有以下表现直肠腺瘤的临床常见症状有便血多为无疼性便血常呈鲜红色量多可引起贫血。粪便性状改变多为粘液便便意不尽或里急重感有时只排出粘液称为假性腹泻。1.腹部不适、腹部疼痛多为大的腺瘤伴发肠套叠、肠梗阻引起。2.排便习惯改变包括便秘、腹泻、里急后重等。3.粪便带血最常见为间歇性便血。4.部分位于直肠的较大的、带蒂腺瘤可在排便时脱落或脱出肛。可以采取内镜下高频电凝、激光、微波凝固等方法切除也可以选择外科手术切除并定期随访。有恶变者根据情况选择其他治疗如放疗、化疗、手术等管状腺瘤切除术后复发者少见但绒毛状腺瘤及绒毛管状腺瘤切除术后常可复发尤其是绒毛状腺瘤且多发的腺瘤复发率高于单发者。对于经内镜治疗或局部手术切除的结直肠腺瘤患者尤其是绒毛状腺瘤或广基的绒毛管状腺瘤患者建议腺瘤切除后的第一年内36个月进行一次肠镜检查第二年开始每年一次。\"\n",
" }\n",
" ]\n",
" }\n",
"]\n",
"```\n",
"\n",
"You can also use multi-turn conversations as the training set. Here is a simple example:\n",
"\n",
"```python\n",
"[\n",
" {\n",
" \"id\": \"2\",\n",
" \"conversations\": [\n",
" {\n",
" \"from\": \"user\",\n",
" \"value\": \"你好,我最近经常感觉胸口疼痛,这是怎么回事?\"\n",
" },\n",
" {\n",
" \"from\": \"assistant\",\n",
" \"value\": \"胸痛可能由多种原因引起,包括心脏问题、消化系统疾病、呼吸系统问题等。您能描述一下疼痛的性质和持续时间吗?\"\n",
" },\n",
" {\n",
" \"from\": \"user\",\n",
" \"value\": \"疼痛感觉像是挤压,大概持续了几分钟。\"\n",
" },\n",
" {\n",
" \"from\": \"assistant\",\n",
" \"value\": \"挤压感的胸痛可能与心脏问题有关,特别是如果它伴随着呼吸困难、出汗或恶心。我建议您尽快去看医生并进行适当的检查,如心电图和血液检测,以确定具体原因。\"\n",
" },\n",
" {\n",
" \"from\": \"user\",\n",
" \"value\": \"我应该去急诊室吗,还是预约我的家庭医生?\"\n",
" },\n",
" {\n",
" \"from\": \"assistant\",\n",
" \"value\": \"如果您的疼痛是突然发生的,并且还有其他症状,建议您立即去急诊室。如果疼痛不是很严重且没有其他严重症状,您可以预约家庭医生进一步评估。但请不要忽视疼痛,尤其是如果这种情况是第一次出现。\"\n",
" }\n",
" ]\n",
" }\n",
"]\n",
"```\n",
"\n",
"## Fine-Tune the Model\n",
"\n",
"You can directly run the prepared training script to fine-tune the model. \n",
"\n",
"For parameter settings, you can modify `--model_name_or_path` to the location of the model you want to fine-tune, and set `--data_path` to the location of the dataset.\n",
"\n",
"You should remove the `--bf16` parameter if you are using a non-Ampere architecture GPU, such as a V100. \n",
"\n",
"For `--model_max_length` and `--per_device_train_batch_size`, we recommend the following configurations, ,you can refer to [this document](../../finetune/deepspeed/readme.md) for more details:\n",
"\n",
"| --model_max_length | --per_device_train_batch_size | GPU Memory |\n",
"|-----------------|------------|--------------------|\n",
"| 512 | 4 | 24g |\n",
"| 1024 | 3 | 24g |\n",
"| 512 | 8 | 32g |\n",
"| 1024 | 6 | 32g |\n",
"\n",
"You can use our recommended saving parameters, or you can save by epoch by just setting `--save_strategy \"epoch\"` if you prefer to save at each epoch stage. `--save_total_limit` means the limit on the number of saved checkpoints.\n",
"\n",
"For other parameters, such as `--weight_decay` and `--adam_beta2`, we recommend using the values we provided blow.\n",
"\n",
"Setting the parameters `--gradient_checkpointing` and `--lazy_preprocess` is to save GPU memory.\n",
"\n",
"The parameters for the trained Lora module will be saved in the **output_qwen** folder."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ab0581e-be85-45e6-a5b7-af9c42ea697b",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"!python ../../../finetune/finetune.py \\\n",
" --model_name_or_path \"Qwen/Qwen-7B-Chat/\"\\\n",
" --data_path \"medical_sft.json\"\\\n",
" --bf16 \\\n",
" --output_dir \"output_qwen\" \\\n",
" --num_train_epochs 4\\\n",
" --per_device_train_batch_size 4 \\\n",
" --per_device_eval_batch_size 3 \\\n",
" --gradient_accumulation_steps 16 \\\n",
" --evaluation_strategy \"no\" \\\n",
" --save_strategy \"epoch\" \\\n",
" --save_steps 3000 \\\n",
" --save_total_limit 10 \\\n",
" --learning_rate 1e-5 \\\n",
" --weight_decay 0.1 \\\n",
" --adam_beta2 0.95 \\\n",
" --warmup_ratio 0.01 \\\n",
" --lr_scheduler_type \"cosine\" \\\n",
" --logging_steps 10 \\\n",
" --model_max_length 512 \\\n",
" --gradient_checkpointing \\\n",
" --lazy_preprocess \\\n",
" --use_lora"
]
},
{
"cell_type": "markdown",
"id": "5e6f28aa-1772-48ce-aa15-8cf29e7d67b5",
"metadata": {},
"source": [
"## Merge Weights\n",
"\n",
"The LoRA training only saves the adapter parameters. You can load the fine-tuned model and merge weights as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "4fd5ef2a-34f9-4909-bebe-7b3b086fd16a",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"execution": {
"iopub.execute_input": "2024-01-26T02:46:14.585746Z",
"iopub.status.busy": "2024-01-26T02:46:14.585089Z",
"iopub.status.idle": "2024-01-26T02:47:08.095464Z",
"shell.execute_reply": "2024-01-26T02:47:08.094715Z",
"shell.execute_reply.started": "2024-01-26T02:46:14.585720Z"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\".\n",
"Try importing flash-attention for faster inference...\n",
"Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm\n",
"Loading checkpoint shards: 100%|██████████| 8/8 [00:06<00:00, 1.14it/s]\n"
]
}
],
"source": [
"from transformers import AutoModelForCausalLM\n",
"from peft import PeftModel\n",
"import torch\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen-7B-chat/\", torch_dtype=torch.float16, device_map=\"auto\", trust_remote_code=True)\n",
"model = PeftModel.from_pretrained(model, \"output_qwen/\")\n",
"merged_model = model.merge_and_unload()\n",
"merged_model.save_pretrained(\"output_qwen_merged\", max_shard_size=\"2048MB\", safe_serialization=True)"
]
},
{
"cell_type": "markdown",
"id": "2e3f5b9f-63a1-4599-8d9b-a8d8f764838f",
"metadata": {},
"source": [
"The tokenizer files are not saved in the new directory in this step. You can copy the tokenizer files or use the following code:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "10fa5ea3-dd55-4901-86af-c045d4c56533",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"execution": {
"iopub.execute_input": "2024-01-26T02:47:08.097051Z",
"iopub.status.busy": "2024-01-26T02:47:08.096744Z",
"iopub.status.idle": "2024-01-26T02:47:08.591289Z",
"shell.execute_reply": "2024-01-26T02:47:08.590665Z",
"shell.execute_reply.started": "2024-01-26T02:47:08.097029Z"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"('output_qwen_merged/tokenizer_config.json',\n",
" 'output_qwen_merged/special_tokens_map.json',\n",
" 'output_qwen_merged/qwen.tiktoken',\n",
" 'output_qwen_merged/added_tokens.json')"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\n",
" \"Qwen/Qwen-7B-chat/\",\n",
" trust_remote_code=True\n",
")\n",
"\n",
"tokenizer.save_pretrained(\"output_qwen_merged\")"
]
},
{
"cell_type": "markdown",
"id": "804b84d8",
"metadata": {},
"source": [
"## Test the Model\n",
"\n",
"After merging the weights, we can test the model as follows:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "dbae310c",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"execution": {
"iopub.execute_input": "2024-01-26T02:48:29.995040Z",
"iopub.status.busy": "2024-01-26T02:48:29.994448Z",
"iopub.status.idle": "2024-01-26T02:48:41.677104Z",
"shell.execute_reply": "2024-01-26T02:48:41.676591Z",
"shell.execute_reply.started": "2024-01-26T02:48:29.995019Z"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm\n",
"Loading checkpoint shards: 100%|██████████| 8/8 [00:04<00:00, 1.71it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"VDAC1电压依赖性钙通道是一种位于细胞膜上的钙离子通道负责将细胞内的钙离子释放到细胞外。它在神经信号传导、肌肉收缩和血管舒张中发挥着重要作用。\n",
"\n",
"VDAC1通常由4个亚基组成每个亚基都有不同的功能。其中一个亚基是内腔部分它与钙离子的结合有关另一个亚基是外腔部分它与离子通道的打开和关闭有关第三个亚基是一层跨膜蛋白它负责调节通道的开放程度最后一个亚基是一个膜骨架连接器它帮助维持通道的结构稳定性。\n",
"\n",
"除了钙离子外VDAC1还能够接收钾离子和氯离子等其他离子并将其从细胞内释放到细胞外。此外VDAC1还参与了许多细胞代谢反应例如脂肪酸合成和糖原分解等。\n",
"\n",
"总的来说VDAC1是细胞膜上的一种重要离子通道其作用涉及到许多重要的生物学过程。\n"
]
}
],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from transformers.generation import GenerationConfig\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"output_qwen_merged\", trust_remote_code=True)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" \"output_qwen_merged\",\n",
" device_map=\"auto\",\n",
" trust_remote_code=True\n",
").eval()\n",
"\n",
"response, history = model.chat(tokenizer, \"什么是VDAC1\", history=None)\n",
"print(response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "987f524d-6918-48ae-a730-f285cf6f8416",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}