From 851ccfcea77ef868ed82e18f475b9a206f2d2f74 Mon Sep 17 00:00:00 2001 From: hyperzlib Date: Thu, 25 Apr 2024 15:43:51 +0800 Subject: [PATCH] =?UTF-8?q?finetune=E6=B7=BB=E5=8A=A0system=5Fprompt?= =?UTF-8?q?=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cli_demo.py | 78 +++++++++++++++++++-- examples/function_call_finetune_examples.py | 33 ++++----- finetune.py | 4 +- finetune/finetune_lora_single_gpu.sh | 2 +- 4 files changed, 95 insertions(+), 22 deletions(-) diff --git a/cli_demo.py b/cli_demo.py index 4ff5b41..f312ca0 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -9,10 +9,13 @@ import argparse import os import platform import shutil +import json +import json5 from copy import deepcopy import torch from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import AutoPeftModelForCausalLM from transformers.generation import GenerationConfig from transformers.trainer_utils import set_seed @@ -40,8 +43,52 @@ Commands: :reset-conf Reset generation config 重置生成配置 ''' +TOOL_DESC = """{name_for_model}: 你可以调用该工具与 {name_for_human} API 进行交互。{name_for_human} API 有什么作用?{description_for_model} 参数列表:{parameters}""" + +REACT_INSTRUCTION = """请尽可能回答下列问题。您可以访问以下 API: + +{tools_text} + +使用以下格式回答问题: + +Question: 你需要回答的问题 +Thought: 你的思考过程 +Action: 要使用的操作,必须是 [{tools_name_text}] 其中之一 +Action Input: 操作的输入参数 +Observation: 操作的结果 +... (这些 Thought/Action/Action Input/Observation 可以是零次或重复多次) +Thought: 你的最终思考过程 +Final Answer: 你的最终回答""" + + +def build_react_instruction(functions: list[dict]): + tools_text = [] + tools_name_text = [] + for func_info in functions: + name = func_info.get("name", "") + name_m = func_info.get("name_for_model", name) + name_h = func_info.get("name_for_human", name) + desc = func_info.get("description", "") + desc_m = func_info.get("description_for_model", desc) + tool = TOOL_DESC.format( + name_for_model=name_m, + name_for_human=name_h, + description_for_model=desc_m, + parameters=json.dumps(func_info["parameters"], ensure_ascii=False), + ) + tools_text.append(tool) + tools_name_text.append(name_m) + tools_text = "\n\n".join(tools_text) + tools_name_text = ", ".join(tools_name_text) + instruction = REACT_INSTRUCTION.format( + tools_text=tools_text, + tools_name_text=tools_name_text, + ) + return instruction + def _load_model_tokenizer(args): + model_path = args.model_path or args.checkpoint_path tokenizer = AutoTokenizer.from_pretrained( args.checkpoint_path, trust_remote_code=True, resume_download=True, ) @@ -49,9 +96,9 @@ def _load_model_tokenizer(args): if args.cpu_only: device_map = "cpu" else: - device_map = "auto" + device_map = "cuda" - model = AutoModelForCausalLM.from_pretrained( + model = AutoPeftModelForCausalLM.from_pretrained( args.checkpoint_path, device_map=device_map, trust_remote_code=True, @@ -59,7 +106,7 @@ def _load_model_tokenizer(args): ).eval() config = GenerationConfig.from_pretrained( - args.checkpoint_path, trust_remote_code=True, resume_download=True, + model_path, trust_remote_code=True, resume_download=True, ) return model, tokenizer, config @@ -107,7 +154,13 @@ def main(): description='QWen-Chat command-line interactive chat demo.') parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH, help="Checkpoint name or path, default to %(default)r") + parser.add_argument("-m", "--model-path", type=str, default=None, + help="Model name or path, default to None") parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed") + parser.add_argument("-sf", "--system-prompt-file", type=str, default=None, + help="System prompt file, default to None") + parser.add_argument("-fd", "--function-definition", type=str, default=None, + help="Function definition file, should be json or json5, default to None") parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only") args = parser.parse_args() @@ -116,11 +169,24 @@ def main(): model, tokenizer, config = _load_model_tokenizer(args) orig_gen_config = deepcopy(model.generation_config) + system_prompt = "You are a helpful assistant." + if args.system_prompt_file: + with open(args.system_prompt_file, 'r', encoding="utf-8") as f: + system_prompt = f.read() + + function_prompt = None + if args.function_definition: + with open(args.function_definition, 'r', encoding="utf-8") as f: + functions = json5.load(f) + function_prompt = build_react_instruction(functions) + _clear_screen() print(_WELCOME_MSG) seed = args.seed + is_first_msg = True + while True: query = _get_input() @@ -195,7 +261,11 @@ def main(): # Run chat. set_seed(seed) try: - for response in model.chat_stream(tokenizer, query, history=history, generation_config=config): + prompt = query + if function_prompt: + prompt = f"{function_prompt}\n\nQuestion: {query}" + + for response in model.chat_stream(tokenizer, prompt, history=history, system=system_prompt, generation_config=config): _clear_screen() print(f"\nUser: {query}") print(f"\nQwen-Chat: {response}") diff --git a/examples/function_call_finetune_examples.py b/examples/function_call_finetune_examples.py index 79e3509..c660a5c 100644 --- a/examples/function_call_finetune_examples.py +++ b/examples/function_call_finetune_examples.py @@ -88,36 +88,37 @@ def main(): example_functions = [ { - "name_for_human": "Google Search", - "name_for_model": "google_search", - "description_for_model": "Google Search is a general search engine that can be used to access the internet," - + " query encyclopedia knowledge, and stay informed about current events." - + " Format the arguments as a JSON object.", # If you expect `Action Input` to be a JSON. + "name_for_human": "在线搜索", + "name_for_model": "search", + "description_for_model": "使用此工具可以搜索群内成员的信息,或者在互联网上搜索信息。Format the arguments as a JSON object.", # If you expect `Action Input` to be a JSON. "parameters": [ { - "name": "search_query", - "description": "Search keywords or phrases", - "required": True, # Set to False if it is an optional parameter. + "name": "keywords", + "description": "需要搜索的关键词。", + "required": True, "schema": {"type": "string"}, }, - # You can add more parameters to this `parameters` list if you wish. ], }, { - "name_for_human": "Code Interpreter", - "name_for_model": "code_interpreter", - "description_for_model": "Code interpreter that can execute Python code." - + "Enclose the code within triple backticks (`)" - + " at the beginning and end of the code.", # If you expect `Action Input` to be a Markdown code block. + "name_for_human": "天气信息", + "name_for_model": "get_weather", + "description_for_model": "查询某个地点的天气信息。Format the arguments as a JSON object.", "parameters": [ { - "name": "code", - "description": "Code to be executed", + "name": "position", + "description": "需要查询天气的地点。", "required": True, "schema": {"type": "string"}, }, ], }, + { + "name_for_human": "封禁用户", + "name_for_model": "ban_user", + "description_for_model": "在用户多次触犯道德规范时,使用此工具可以封禁用户。", + "parameters": [], + } ] example_instruction = build_react_instruction(example_functions) diff --git a/finetune.py b/finetune.py index 4a4e334..a6815d6 100644 --- a/finetune.py +++ b/finetune.py @@ -220,8 +220,10 @@ class LazySupervisedDataset(Dataset): def __getitem__(self, i) -> Dict[str, torch.Tensor]: if i in self.cached_data_dict: return self.cached_data_dict[i] + + system_prompt: str = self.raw_data[i].get("system") or "You are a helpful assistant." - ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len) + ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len, system_prompt) ret = dict( input_ids=ret["input_ids"][0], labels=ret["labels"][0], diff --git a/finetune/finetune_lora_single_gpu.sh b/finetune/finetune_lora_single_gpu.sh index 74d9d36..972bcf0 100644 --- a/finetune/finetune_lora_single_gpu.sh +++ b/finetune/finetune_lora_single_gpu.sh @@ -4,7 +4,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 MODEL="Qwen/Qwen-7B" # Set the path if you do not want to load from huggingface directly # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. # See the section for finetuning in README for more information. -DATA="path_to_data" +DATA="../dataset/kurita.json" function usage() { echo '