You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

190 lines
7.6 KiB
Python

# Usage: python auto_comments.py --path 'path of file or folder'
# Purpose: use QWen-7B-Chat to automatically generate comments for the given code files (see auto_comments.md for details).
import argparse
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
MaxLine = 50 # Maximum number of code lines handled in a single model request
SplitKey = ["\ndef "] # Custom markers used to split the code into chunks
CodeFileType = ["py"] # Comment generation has only been tested on Python files so far
def parse_args():
    """Parse the command-line options of the script.

    Returns:
        argparse.Namespace with two attributes:
        ``path`` - the file or folder to process, and
        ``regenerate`` - True when ``--regenerate`` was passed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--path', type=str, default='Qwen-7B/eval/evaluate_ceval.py')
    # Files whose comments were already generated are not regenerated by default.
    cli.add_argument('--regenerate', action='store_true', default=False)
    return cli.parse_args()
class QWenChat:
    """Wrapper that loads the Qwen-7B-Chat tokenizer/model and sends it queries."""

    def __init__(self):
        checkpoint = "Qwen/Qwen-7B-Chat"
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
        # Auto mode: precision is selected automatically based on the device.
        # Other loading options: add bf16=True or fp16=True to pick a
        # precision explicitly, or use device_map="cpu" for CPU-only.
        self.model = AutoModelForCausalLM.from_pretrained(
            checkpoint, device_map="auto", trust_remote_code=True
        ).eval()
        # Specify hyperparameters for generation.
        self.model.generation_config = GenerationConfig.from_pretrained(
            checkpoint, trust_remote_code=True
        )
        self.history = None

    def chat(self, query, system = ""):
        """Send *query* to the model and return its response.

        The returned history is stored on ``self.history`` but is not fed
        back: every call passes ``history=None``. To use history, pass
        ``history=self.history`` instead. ``system`` is accepted but unused.
        """
        response, self.history = self.model.chat(self.tokenizer, query, history=None)
        return response
def gen_code_comments(context, model = None, **kwargs):
    """Generate comments for a piece of code.

    The code in *context* is followed by a fixed (Chinese) instruction
    asking for detailed Chinese comments with the code itself unchanged,
    and the combined text is sent through ``model.chat``.

    Args:
        context: source code to annotate.
        model: object exposing ``chat(query)``; must be supplied by the caller.
        **kwargs: accepted but unused.

    Returns:
        Whatever ``model.chat`` returns for the combined query.
    """
    instruction = "\n为以上代码生成细致的中文注释,注意使用合适的语法。要求必须在每个函数开头生成一段统一的函数功能注释。\n除了注释,请保证原始代码内容不变。不要返回除了注释和代码以外的其余信息,不要生成额外代码。\n"
    query = context + instruction
    return model.chat(query)
def read_file(path):
    """Return the full content of the UTF-8 text file at *path*.

    The file is opened in a ``with`` block so the handle is closed even
    if reading fails.
    """
    with open(path, "r", encoding='utf-8') as f:
        return f.read()
def write_file(path, context):
    """Write *context* to *path* as UTF-8, replacing any existing content.

    The encoding is explicit so the output matches how this script reads
    files; the platform default encoding may not be able to represent the
    generated (Chinese) comments.
    """
    with open(path, 'w', encoding='utf-8') as f:
        f.write(context)
def split_context_by_maxline(text, max_lines=None):
    """Split *text* into consecutive chunks of at most *max_lines* lines.

    A simple fallback for code files that are too long to process at once.

    Args:
        text: the text to split; lines are separated by ``"\\n"``.
        max_lines: chunk size in lines; defaults to the module-level
            ``MaxLine`` when omitted.

    Returns:
        List of chunks, each re-joined with ``"\\n"``. Text that already
        fits in one chunk (including the empty string) yields a
        single-element list.

    Raises:
        ValueError: if *max_lines* is smaller than 1.
    """
    if max_lines is None:
        max_lines = MaxLine
    if max_lines < 1:
        raise ValueError("max_lines must be a positive integer")
    lines = text.split("\n")
    # Starting the range at 0 covers short inputs as well, so no loop
    # variable is needed after the loop to pick up the remainder.
    return [
        "\n".join(lines[start:start + max_lines])
        for start in range(0, len(lines), max_lines)
    ]
def split_context_by_splitkey(text):
    """Split *text* into per-function chunks.

    A simple fallback for code files that are too long to process at once.
    The text is cut at every occurrence of the first ``SplitKey`` marker.
    The first chunk is whatever precedes the first marker; every later
    chunk starts with the marker itself, so concatenating the chunks gives
    back the original text.
    """
    marker = SplitKey[0]
    head, *tail = text.split(marker)
    return [head] + [marker + piece for piece in tail]
def _comments_above(com_lines, j, last_emitted):
    """Collect, bottom-up, the comment lines sitting directly above com_lines[j].

    Walks upwards from ``j - 1`` and keeps blank lines, ``#`` comments and
    triple-quoted docstrings. Stops at the start of the file, at the line
    equal to *last_emitted* (already written to the output), or at the
    first line that is neither comment nor blank.
    """
    triple_quotes = ('"""', "'''")
    collected = []
    inside_docstring = False
    p = j - 1
    while p >= 0 and com_lines[p] != last_emitted:
        raw = com_lines[p]
        text = raw.strip()
        one_line_docstring = len(text) > 3 and any(
            text.startswith(q) and text.endswith(q) for q in triple_quotes
        )
        if one_line_docstring:
            pass  # complete docstring on a single line: keep it
        elif text.startswith(triple_quotes) or text.endswith(triple_quotes):
            # Opening or closing line of a multi-line docstring.
            inside_docstring = not inside_docstring
        elif inside_docstring:
            pass  # body line of a multi-line docstring: keep it
        elif text and not (text.startswith('#') and "省略部分内容" not in raw):
            # Neither blank nor an acceptable '#' comment; comments containing
            # the marker text are rejected (presumably "content omitted"
            # placeholders emitted by the model - confirm).
            break
        collected.append(raw)
        p -= 1
    return collected


def merge_code_and_comments(original_file, comments_path):
    """Merge the original code with the generated comments.

    The output is rebuilt from the lines of *original_file*, so the
    original code is guaranteed not to be altered. Only comments found in
    *comments_path* (the raw model output) are inserted. The result
    overwrites *comments_path*. This is one possible strategy; others
    could be used instead.

    For every non-blank, non-comment original line, the same text is
    searched for in the generated file, starting after the previous match.
    On a match, the comment lines directly above it are copied in, and a
    trailing ``#`` comment on the generated line is appended when the
    original line contains no ``#``. Unmatched lines are copied as-is.
    Blank original lines are dropped; original ``#`` lines are kept.

    Args:
        original_file: path of the untouched source file (UTF-8).
        comments_path: path of the generated file (UTF-8); read, then
            overwritten with the merged result.
    """
    with open(original_file, "r", encoding='utf-8') as ori_f:
        ori_lines = ori_f.readlines()
    with open(comments_path, "r", encoding='utf-8') as com_f:
        com_lines = com_f.readlines()
    len_com_lines = len(com_lines)

    res = []
    p = 0  # index in com_lines where the next search starts
    j = 0
    for line in ori_lines:
        if line.isspace():
            continue
        if line.strip()[0] == '#':
            res.append(line)
            continue
        # Strip only a real line terminator: the last line of a file may have
        # none, and blindly dropping its final character would damage the code.
        code = line[:-1] if line.endswith("\n") else line
        while j < len_com_lines and code not in com_lines[j]:
            j += 1
        if j < len_com_lines:
            last_emitted = res[-1] if res else None
            res.extend(reversed(_comments_above(com_lines, j, last_emitted)))
            if "#" in com_lines[j] and "#" not in line:
                in_line_comments = " #" + com_lines[j].split("#")[-1]
                res.append(code + in_line_comments)
            else:
                res.append(line)
            p = j + 1
        else:
            # No match: keep the original line and rewind the search position.
            res.append(line)
        j = p

    # Written as UTF-8 explicitly, matching the encoding used for reading.
    with open(comments_path, "w", encoding='utf-8') as out_f:
        out_f.write("".join(res))
def deal_one_file(model, path, args):
    """Generate a commented copy of a single code file.

    The result is written next to the input as ``<name>_comments<ext>``.
    If that file already exists and ``args.regenerate`` is false, it is
    reused and nothing is generated.

    Args:
        model: chat wrapper forwarded to ``gen_code_comments``.
        path: path of the code file to process.
        args: parsed CLI options; only ``args.regenerate`` is read.
    """
    # os.path keeps multi-dot names intact ("a.b.py" -> "a.b_comments.py")
    # and handles the platform's path separators.
    folder, fname = os.path.split(path)
    stem, ext = os.path.splitext(fname)
    comments_path = os.path.join(folder, stem + "_comments" + ext)
    # Check the cache first so a cached file is not read needlessly.
    if (not args.regenerate) and os.path.exists(comments_path):
        print("use cache: ", comments_path)
        return

    context = read_file(path)
    context_line = len(context.split("\n"))
    if context_line < MaxLine:
        res = gen_code_comments(context, model = model)
    else:
        # Too long for one request: split per function when the marker is
        # present, otherwise by line count.
        if SplitKey[0] not in context:
            context_list = split_context_by_maxline(context)
        else:
            context_list = split_context_by_splitkey(context)
        res = "\n".join(
            gen_code_comments(context_block, model = model)
            for context_block in context_list
        )
    write_file(comments_path, res)
    merge_code_and_comments(path, comments_path)
def deal_folder(model, path, args):
    """Recursively process every eligible code file below the folder *path*.

    A file is eligible when the text after its last ``.`` is listed in
    ``CodeFileType`` and its path does not contain ``_comments``.
    Sub-folders are handled recursively; entries that are neither file nor
    folder produce a message on stdout.
    """
    for entry in os.listdir(path):
        now_path = os.path.join(path, entry)
        if os.path.isfile(now_path):
            extension = now_path.split(".")[-1]
            already_generated = "_comments" in now_path
            if extension in CodeFileType and not already_generated:
                deal_one_file(model, now_path, args)
            continue
        if os.path.isdir(now_path):
            deal_folder(model, now_path, args)
            continue
        print("Please specify a correct path!")
def transfer(args):
    """Load the chat model and generate comments for ``args.path``.

    ``args.path`` may be a single file (processed only when the text after
    its last ``.`` is in ``CodeFileType`` and the path does not contain
    ``_comments``) or a folder (processed recursively). Any other path
    prints a message on stdout.
    """
    model = QWenChat()
    target = args.path
    if os.path.isfile(target):
        extension = target.split(".")[-1]
        if extension in CodeFileType and "_comments" not in target:
            deal_one_file(model, target, args)
    elif os.path.isdir(target):
        deal_folder(model, target, args)
    else:
        print("Please specify a correct path!")
if __name__ == '__main__':
    # Script entry point: parse the CLI options, echo them, then run the pipeline.
    args = parse_args()
    print(args)
    transfer(args)