# Usage: python auto_comments.py --path 'path of file or folder'
# Purpose: automatically generate comments for the given code file(s) with Qwen-7B-Chat. (See auto_comments.md for details.)
import argparse
import os

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

MaxLine = 50  # Maximum number of code lines handled in a single model call
SplitKey = ["\ndef "]  # Custom markers for splitting code into blocks
CodeFileType = ["py"]  # Comment generation has only been tested on Python files so far


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, default='Qwen-7B/eval/evaluate_ceval.py')
    # By default, files whose comments were already generated are skipped.
    parser.add_argument('--regenerate', action='store_true', default=False)
    args = parser.parse_args()
    return args


class QWenChat():
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
        # use bf16
        # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
        # use fp16
        # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
        # use cpu only
        # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True).eval()
        # use auto mode, automatically select precision based on the device.
        self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()

        # Specify hyperparameters for generation
        self.model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
        self.history = None

    def chat(self, query, system=""):
        # To keep conversation history across calls, use:
        # response, history = self.model.chat(self.tokenizer, query, history=self.history)
        # History is not used by default.
        response, history = self.model.chat(self.tokenizer, query, history=None)
        self.history = history
        return response


# Ask the model to comment a block of code. The prompt is kept in Chinese
# because the generated comments are meant to be Chinese; it asks for detailed
# comments, a summary comment at the start of every function, and the original
# code returned unchanged with no extra output.
def gen_code_comments(context, model=None, **kwargs):
    prompt = "\n为以上代码生成细致的中文注释,注意使用合适的语法。要求必须在每个函数开头生成一段统一的函数功能注释。\n除了注释,请保证原始代码内容不变。不要返回除了注释和代码以外的其余信息,不要生成额外代码。\n"
    return model.chat(context + prompt)


def read_file(path):
    with open(path, "r", encoding='utf-8') as f:
        return f.read()


def write_file(path, context):
    with open(path, 'w', encoding='utf-8') as f:
        f.write(context)


# If the code file is too long, simply split it into fixed-size chunks of at
# most MaxLine lines each.
def split_context_by_maxline(text):
    lines = text.split("\n")
    return ["\n".join(lines[i:i + MaxLine]) for i in range(0, len(lines), MaxLine)]


# If the code file is too long, split it at function definitions instead.
def split_context_by_splitkey(text):
    blocks = text.split(SplitKey[0])
    return [blocks[0]] + [SplitKey[0] + x for x in blocks[1:]]


# Merge the original code with the generated comments so that the original
# code is guaranteed to stay unchanged. Many different strategies could be
# used here; this one matches each original line against the generated output
# and copies over only the surrounding comments.
def merge_code_and_comments(original_file, comments_path):
    res = []
    with open(original_file, "r", encoding='utf-8') as ori_f:
        ori_lines = ori_f.readlines()
    with open(comments_path, "r", encoding='utf-8') as com_f:
        com_lines = com_f.readlines()
    len_com_lines = len(com_lines)
    p = 0
    j = 0
    for line in ori_lines:
        if line.isspace():
            continue
        if line.strip()[0] == '#':
            res.append(line)
            continue
        # Advance j until the generated output contains the current original line.
        while j < len_com_lines and line[:-1] not in com_lines[j]:
            j += 1
        if j < len_com_lines:
            # Walk backwards from the matched line, collecting the comments
            # and docstrings the model inserted above it.
            p = j - 1
            up_comments = []
            triple_dot_flag = 0  # 1 while inside a multi-line docstring
            while p < j:
                if p < 0 or (res and res[-1] and com_lines[p] == res[-1]):
                    break
                stripped = com_lines[p].strip()
                # Single-line docstring: opens and closes on the same line.
                if stripped and len(stripped) > 3 and (
                        (stripped[:3] == '"""' and stripped[-3:] == '"""')
                        or (stripped[:3] == "'''" and stripped[-3:] == "'''")):
                    up_comments.append(com_lines[p])
                    p -= 1
                    continue
                # Opening or closing line of a multi-line docstring.
                if stripped and (stripped[:3] == '"""' or stripped[-3:] == '"""'
                                 or stripped[:3] == "'''" or stripped[-3:] == "'''"):
                    triple_dot_flag = (triple_dot_flag + 1) % 2
                    up_comments.append(com_lines[p])
                    p -= 1
                    continue
                # Any line inside a multi-line docstring.
                if triple_dot_flag:
                    up_comments.append(com_lines[p])
                    p -= 1
                    continue
                # Keep blank lines and '#' comments, but drop model filler
                # such as "省略部分内容" ("content omitted").
                if stripped == "" or (stripped[0] == '#' and "省略部分内容" not in com_lines[p]):
                    up_comments.append(com_lines[p])
                else:
                    break
                p -= 1
            if up_comments:
                res.extend(reversed(up_comments))
            # Carry over an inline comment if the model added one.
            if "#" in com_lines[j] and "#" not in line:
                in_line_comments = " #" + com_lines[j].split("#")[-1]
                res.append(line[:-1] + in_line_comments)
            else:
                res.append(line)
            p = j + 1
        else:
            res.append(line)
        j = p
    write_file(comments_path, "".join(res))


# Process a single file.
def deal_one_file(model, path, args):
    context = read_file(path)
    fname = os.path.basename(path)
    fpath = os.path.dirname(path)
    outfname = fname.split(".")[0] + "_comments." + fname.split(".")[-1]
    comments_path = os.path.join(fpath, outfname)
    if (not args.regenerate) and os.path.exists(comments_path):
        print("use cache: ", comments_path)
        return
    context_line = len(context.split("\n"))
    if context_line < MaxLine:
        res = gen_code_comments(context, model=model)
    elif SplitKey[0] not in context:
        # No function definitions to split on: fall back to fixed-size chunks.
        context_list = split_context_by_maxline(context)
        res = "\n".join([gen_code_comments(context_block, model=model) for context_block in context_list])
    else:
        context_list = split_context_by_splitkey(context)
        res = "\n".join([gen_code_comments(context_block, model=model) for context_block in context_list])
    write_file(comments_path, res)
    merge_code_and_comments(path, comments_path)


# Recursively process every code file in a folder.
def deal_folder(model, path, args):
    for fl in os.listdir(path):
        now_path = os.path.join(path, fl)
        if os.path.isfile(now_path):
            if (now_path.split(".")[-1] in CodeFileType) and ("_comments" not in now_path):
                deal_one_file(model, now_path, args)
        elif os.path.isdir(now_path):
            deal_folder(model, now_path, args)
        else:
            print("Please specify a correct path!")


def transfer(args):
    model = QWenChat()
    if os.path.isfile(args.path):
        if (args.path.split(".")[-1] in CodeFileType) and ("_comments" not in args.path):
            deal_one_file(model, args.path, args)
    elif os.path.isdir(args.path):
        deal_folder(model, args.path, args)
    else:
        print("Please specify a correct path!")


if __name__ == '__main__':
    args = parse_args()
    print(args)
    transfer(args)
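

# ---------------------------------------------------------------------------
# Illustrative sketch (added for clarity; not part of the original workflow):
# a minimal, hypothetical demonstration of how the two splitting helpers chunk
# a long file before it is sent to the model. Nothing in the CLI flow above
# calls it; invoke _demo_split_helpers() manually if you want to try it.
# ---------------------------------------------------------------------------
def _demo_split_helpers():
    # 120 generated lines -> 3 chunks of 50 + 50 + 20 lines with MaxLine = 50.
    sample = "\n".join("x = %d" % i for i in range(120))
    assert len(split_context_by_maxline(sample)) == 3

    # Splitting on "\ndef " keeps the module header as the first block and
    # re-attaches the marker to each function block.
    code = "import os\n\ndef a():\n    pass\n\ndef b():\n    pass\n"
    parts = split_context_by_splitkey(code)
    assert parts[0] == "import os\n"
    assert parts[1].startswith("\ndef a")
    assert parts[2].startswith("\ndef b")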