From b7767361cf892c2af5f666409d01eb89a34e0852 Mon Sep 17 00:00:00 2001
From: "tujianhong.tjh"
Date: Mon, 21 Aug 2023 13:33:15 +0800
Subject: [PATCH] add example: auto_comments

---
 examples/auto_comments.md |  59 ++++++++++++
 examples/auto_comments.py | 189 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 248 insertions(+)
 create mode 100644 examples/auto_comments.md
 create mode 100644 examples/auto_comments.py

diff --git a/examples/auto_comments.md b/examples/auto_comments.md
new file mode 100644
index 0000000..4aadd90
--- /dev/null
+++ b/examples/auto_comments.md
@@ -0,0 +1,59 @@
+# Auto Comments
+This document introduces Auto Comments, an example that uses the Qwen model to automatically generate comments for code files.
+
+# Usage
+Run the following command to generate comments for the given code:
+```
+python auto_comments.py --path 'path of file or folder'
+```
+
+Arguments:
+- path: path to a file (currently only Python source files are supported) or to a folder (all Python files under the folder are scanned).
+- regenerate: whether to regenerate the comments. Defaults to False; pass --regenerate to regenerate comments for a file that has already been processed (see the example below).
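+
+For example, to regenerate comments for every Python file under a folder, you could run something like the following (the folder name is only a placeholder):
+```
+python auto_comments.py --path ./my_project --regenerate
+```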
+
+# Example
+- Run: python auto_comments.py --path test_file.py
+- Content of test_file.py:
+```
+import numpy as np
+import pandas as pd
+import seaborn as sns
+sns.set_theme(style="whitegrid")
+
+rs = np.random.RandomState(365)
+values = rs.randn(365, 4).cumsum(axis=0)
+dates = pd.date_range("1 1 2016", periods=365, freq="D")
+data = pd.DataFrame(values, dates, columns=["A", "B", "C", "D"])
+data = data.rolling(7).mean()
+
+sns.lineplot(data=data, palette="tab10", linewidth=2.5)
+```
+
+- Output: test_file_comments.py (the code file with generated comments). The script asks the model for Chinese comments, so the generated annotations are in Chinese:
+```
+# 导入需要的库
+import numpy as np
+import pandas as pd
+import seaborn as sns
+
+# 设置 Seaborn 的主题风格为白色网格
+sns.set_theme(style="whitegrid")
+
+# 生成随机数
+rs = np.random.RandomState(365)
+
+# 生成 365 行 4 列的随机数,并按行累加
+values = rs.randn(365, 4).cumsum(axis=0)
+
+# 生成日期
+dates = pd.date_range("1 1 2016", periods=365, freq="D")
+
+# 将随机数和日期组合成 DataFrame
+data = pd.DataFrame(values, dates, columns=["A", "B", "C", "D"])
+
+# 对 DataFrame 进行 7 天滑动平均
+data = data.rolling(7).mean()
+
+# 使用 Seaborn 绘制折线图
+sns.lineplot(data=data, palette="tab10", linewidth=2.5)
+```
diff --git a/examples/auto_comments.py b/examples/auto_comments.py
new file mode 100644
index 0000000..dcda959
--- /dev/null
+++ b/examples/auto_comments.py
@@ -0,0 +1,189 @@
+# Usage: python auto_comments.py --path 'path of file or folder'
+# Purpose: use Qwen-7B-Chat to automatically generate comments for the given code files. (See auto_comments.md for details.)
+
+
+import argparse
+import os
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+MaxLine = 50  # maximum number of code lines handled in a single model call
+SplitKey = ["\ndef "]  # markers used to split long code into blocks
+CodeFileType = ["py"]  # comment generation has only been tested on Python files so far
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--path', type=str, default='Qwen-7B/eval/evaluate_ceval.py')
+    parser.add_argument('--regenerate', action='store_true', default=False)  # by default, comments are not regenerated for files that already have them
+    args = parser.parse_args()
+    return args
+
+class QWenChat():
+    def __init__(self):
+        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
+
+        # use bf16
+        # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
+        # use fp16
+        # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
+        # use cpu only
+        # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True).eval()
+        # use auto mode, automatically select precision based on the device.
+        self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
+
+        # Specify hyperparameters for generation
+        self.model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
+        self.history = None
+
+    def chat(self, query, system=""):
+
+        # use history
+        # response, history = self.model.chat(self.tokenizer, query, history=self.history)
+
+        # History is not used by default
+        response, history = self.model.chat(self.tokenizer, query, history=None)
+        self.history = history
+
+        return response
+
+# Ask the model to generate comments for a block of code
+def gen_code_comments(context, model=None, **kwargs):
+    # The prompt is kept in Chinese because the tool generates Chinese comments. It asks the model to
+    # add detailed comments, start every function with a function-level comment, keep the original
+    # code unchanged, and return nothing but the code and its comments.
+    prompt = "\n为以上代码生成细致的中文注释,注意使用合适的语法。要求必须在每个函数开头生成一段统一的函数功能注释。\n除了注释,请保证原始代码内容不变。不要返回除了注释和代码以外的其余信息,不要生成额外代码。\n"
+    return model.chat(context + prompt)
+
+def read_file(path):
+    f = open(path, "r", encoding='utf-8')
+    lines = f.readlines()
+    return "".join(lines)
+
+def write_file(path, context):
+    with open(path, 'w', encoding='utf-8') as f:
+        f.write(context)
+
+# If the code file is too long, simply split it into chunks of at most MaxLine lines
+def split_context_by_maxline(text):
+    lines = text.split("\n")
+    lines_len = len(lines)
+    res = []
+    i = 0  # guard: if the loop below never runs (lines_len <= MaxLine), fall through to the final chunk
+    for i in range(MaxLine, lines_len, MaxLine):
+        res.append("\n".join(lines[i-MaxLine:i]))
+
+    if i < lines_len:
+        res.append("\n".join(lines[i:]))
+    return res
+
+# If the code file is too long, split it at function definitions instead
+def split_context_by_splitkey(text):
+    blocks = text.split(SplitKey[0])
+    return [blocks[0]] + [SplitKey[0]+x for x in blocks[1:]]
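+
+# Illustration (not executed), to make the two splitting strategies above concrete.
+# Assuming text = "import os\ndef f():\n    pass\ndef g():\n    pass":
+#   split_context_by_splitkey(text)
+#     -> ["import os", "\ndef f():\n    pass", "\ndef g():\n    pass"]
+#   split_context_by_maxline(text)
+#     -> the whole snippet as a single chunk, since it is shorter than MaxLine lines;
+#        longer files come back as consecutive chunks of at most MaxLine lines each.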
+
+# Merge the original code with the generated comments, so that the original code itself is guaranteed
+# to remain unchanged. This step could be handled with many different strategies.
+def merge_code_and_comments(original_file, comments_path):
+    res = []
+    ori_f = open(original_file, "r", encoding='utf-8')
+    ori_lines = ori_f.readlines()
+
+    com_f = open(comments_path, "r", encoding='utf-8')
+    com_lines = com_f.readlines()
+    len_com_lines = len(com_lines)
+    p = 0
+    j = 0
+    for i, line in enumerate(ori_lines):
+        if line.isspace():
+            continue
+        if line.strip()[0] == '#':
+            res.append(line)
+            continue
+        # advance j until the generated file reaches the current original line
+        while j < len_com_lines and line[:-1] not in com_lines[j]:
+            j += 1
+        if j < len_com_lines:
+            # walk backwards from the matched line and collect the comment block above it
+            p = j - 1
+            up_comments = []
+            triple_dot_flag = 0
+            while p < j:
+                if p < 0 or (res and res[-1] and com_lines[p] == res[-1]):
+                    break
+                if com_lines[p].strip() and (len(com_lines[p].strip())>3 and com_lines[p].strip()[-3:] == '"""' and com_lines[p].strip()[:3] == '"""') or (len(com_lines[p].strip())>3 and com_lines[p].strip()[-3:] == "'''" and com_lines[p].strip()[:3] == "'''"):
+                    up_comments.append(com_lines[p])
+                    p -= 1
+                    continue
+                if com_lines[p].strip() and (com_lines[p].strip()[-3:] == '"""' or com_lines[p].strip()[:3] == '"""' or com_lines[p].strip()[-3:] == "'''" or com_lines[p].strip()[:3] == "'''"):
+                    triple_dot_flag = (triple_dot_flag + 1)%2
+                    up_comments.append(com_lines[p])
+                    p -= 1
+                    continue
+                if triple_dot_flag:
+                    up_comments.append(com_lines[p])
+                    p -= 1
+                    continue
+                # keep blank lines and '#' comments, but drop placeholders such as "省略部分内容" ("content omitted")
+                if (com_lines[p].strip()=="") or (com_lines[p].strip() and com_lines[p].strip()[0] == '#' and "省略部分内容" not in com_lines[p]):
+                    up_comments.append(com_lines[p])
+                else:
+                    break
+                p -= 1
+            if up_comments:
+                res.extend(reversed(up_comments))
+            # carry over an inline comment when the generated line has one and the original line does not
+            if "#" in com_lines[j] and "#" not in line:
+                in_line_comments = " #" + com_lines[j].split("#")[-1]
+                res.append(line[:-1]+in_line_comments)
+            else:
+                res.append(line)
+            p = j+1
+        else:
+            res.append(line)
+        j = p
+
+    write_file(comments_path, "".join(res))
+
+# Process a single code file
+def deal_one_file(model, path, args):
+    context = read_file(path)
+
+    fname = path.split("/")[-1]
+    fpath = "/".join(path.split("/")[:-1])
+    outfname = fname.split(".")[0]+"_comments."+fname.split(".")[-1]
+
+    comments_path = os.path.join(fpath, outfname)
+    if (not args.regenerate) and os.path.exists(comments_path):
+        print("use cache: ", comments_path)
+        return
+
+    context_line = len(context.split("\n"))
+    if context_line < MaxLine:
+        res = gen_code_comments(context, model=model)
+    elif SplitKey[0] not in context:
+        context_list = split_context_by_maxline(context)
+        res = "\n".join([gen_code_comments(context_block, model=model) for context_block in context_list])
+    else:
+        context_list = split_context_by_splitkey(context)
+        res = "\n".join([gen_code_comments(context_block, model=model) for context_block in context_list])
+
+    write_file(comments_path, res)
+    merge_code_and_comments(path, comments_path)
+
+# Recursively process a folder
+def deal_folder(model, path, args):
+    for fl in os.listdir(path):
+        now_path = os.path.join(path, fl)
+        if os.path.isfile(now_path):
+            if (now_path.split(".")[-1] in CodeFileType) and ("_comments" not in now_path):
+                deal_one_file(model, now_path, args)
+        elif os.path.isdir(now_path):
+            deal_folder(model, now_path, args)
+        else:
+            print("Please specify a correct path!")
+
+def transfer(args):
+    model = QWenChat()
+
+    if os.path.isfile(args.path):
+        if (args.path.split(".")[-1] in CodeFileType) and ("_comments" not in args.path):
+            deal_one_file(model, args.path, args)
+    elif os.path.isdir(args.path):
+        deal_folder(model, args.path, args)
+    else:
+        print("Please specify a correct path!")
+
+if __name__ == '__main__':
+    args = parse_args()
+    print(args)
+    transfer(args)
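+
+# Example of the expected output layout (paths are hypothetical): running
+#   python auto_comments.py --path ./my_project
+# writes a commented copy next to every Python file it processes, e.g.
+#   ./my_project/utils.py  ->  ./my_project/utils_comments.py
+# Files whose names already contain "_comments" are skipped, and existing
+# *_comments.py files are reused unless --regenerate is passed.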