# 运行方式:python auto_comments.py --path 'path of file or folder'
# 脚本功能:使用QWen-7B-Chat为提供的代码文件自动生成注释。(详见auto_comments.md)


import argparse
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

MaxLine = 50 # Maximum number of code lines handled in a single model request
SplitKey = ["\ndef "] # Custom markers used to split code into chunks (only the first entry is used)
CodeFileType = ["py"] # File extensions to process; only Python files have been tested so far

def parse_args():
    """Parse the command-line options of the script.

    Returns:
        argparse.Namespace: ``path`` is the file or folder to process and
        ``regenerate`` is True when ``--regenerate`` was given.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--path', type=str, default='Qwen-7B/eval/evaluate_ceval.py')
    # Comment files that already exist are reused unless --regenerate is passed.
    cli.add_argument('--regenerate', action='store_true', default=False)
    return cli.parse_args()

class QWenChat():
    """Small wrapper holding the Qwen-7B-Chat tokenizer and model."""

    def __init__(self):
        """Load the tokenizer, the model and its generation config."""
        checkpoint = "Qwen/Qwen-7B-Chat"
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

        # Loading alternatives (pass these to from_pretrained instead):
        #   bf16:      device_map="auto", trust_remote_code=True, bf16=True
        #   fp16:      device_map="auto", trust_remote_code=True, fp16=True
        #   cpu only:  device_map="cpu", trust_remote_code=True
        # The call below uses auto mode: precision is selected based on the device.
        causal_lm = AutoModelForCausalLM.from_pretrained(
            checkpoint, device_map="auto", trust_remote_code=True
        )
        self.model = causal_lm.eval()

        # Specify hyperparameters for generation
        self.model.generation_config = GenerationConfig.from_pretrained(
            checkpoint, trust_remote_code=True
        )
        self.history = None

    def chat(self, query, system = ""):
        """Send ``query`` to the model and return its response.

        No history is passed to the model (``history=None``), so every call
        is independent of earlier ones. The history returned by the model is
        stored in ``self.history``; pass it as ``history=self.history`` to
        make the conversation multi-turn.

        Args:
            query: Text sent to the model.
            system: Accepted for interface compatibility; currently unused.

        Returns:
            The response produced by ``self.model.chat``.
        """
        reply, self.history = self.model.chat(self.tokenizer, query, history=None)
        return reply
# Generate comments
def gen_code_comments(context, model = None, **kwargs):
    """Ask the model to annotate ``context`` with comments.

    The instruction (in Chinese) is appended after the code and asks for
    detailed Chinese comments, a summary comment at the start of every
    function, unchanged original code, and no extra text or code.

    Args:
        context: Source code to annotate.
        model: Object exposing ``chat(query)``; must not be None.
        **kwargs: Accepted but unused.

    Returns:
        Whatever ``model.chat`` returns for the combined query.
    """
    instruction = (
        "\n为以上代码生成细致的中文注释,注意使用合适的语法。要求必须在每个函数开头生成一段统一的函数功能注释。"
        "\n除了注释,请保证原始代码内容不变。不要返回除了注释和代码以外的其余信息,不要生成额外代码。\n"
    )
    query = context + instruction
    return model.chat(query)

def read_file(path):
    """Return the full content of the UTF-8 text file at ``path``.

    The file is opened in a ``with`` block so the handle is closed even if
    reading fails.

    Args:
        path: Path of the file to read.

    Returns:
        The file content as a single string.
    """
    with open(path, "r", encoding='utf-8') as f:
        return f.read()

def write_file(path, context):
    """Write ``context`` to ``path`` as UTF-8, replacing any existing file.

    The encoding is set explicitly because the files written here are read
    back as UTF-8 elsewhere in this script; relying on the platform default
    breaks non-ASCII comments on systems whose default is not UTF-8.

    Args:
        path: Destination file path.
        context: Text to write.
    """
    with open(path, 'w', encoding='utf-8') as f:
        f.write(context)

# If the code file is too long, it can simply be split by a maximum number of lines
def split_context_by_maxline(text, max_line=None):
    """Split ``text`` into consecutive chunks of at most ``max_line`` lines.

    Lines are obtained with ``text.split("\\n")`` and each chunk is joined
    back with ``"\\n"``. Texts with no more lines than the limit come back
    as a single chunk.

    Args:
        text: Text to split.
        max_line: Maximum number of lines per chunk; defaults to the
            module-level ``MaxLine``.

    Returns:
        List of chunk strings, in order.

    Raises:
        ValueError: If ``max_line`` is smaller than 1.
    """
    if max_line is None:
        max_line = MaxLine
    if max_line < 1:
        raise ValueError("max_line must be a positive integer")

    lines = text.split("\n")
    # Starting at 0 also yields the short trailing chunk, so no separate
    # "remainder" step (and no reliance on a leftover loop variable) is needed.
    return [
        "\n".join(lines[start:start + max_line])
        for start in range(0, len(lines), max_line)
    ]

# If the code file is too long, it can simply be split at function definitions
def split_context_by_splitkey(text, split_key=None):
    """Split ``text`` at every occurrence of ``split_key``.

    The marker is kept at the start of each chunk that follows it, so
    concatenating the returned chunks reproduces ``text``.

    Args:
        text: Text to split.
        split_key: Marker to split on; defaults to the first entry of the
            module-level ``SplitKey``.

    Returns:
        List whose first item is the text before the first marker, followed
        by one marker-prefixed chunk per occurrence.
    """
    if split_key is None:
        split_key = SplitKey[0]
    head, *rest = text.split(split_key)
    return [head] + [split_key + piece for piece in rest]

def _collect_leading_comments(com_lines, match_idx, last_emitted):
    """Collect the comment lines sitting directly above ``com_lines[match_idx]``.

    Walks upward from ``match_idx - 1`` and gathers blank lines, ``#``
    comments and triple-quoted docstring lines until it reaches the start of
    the file, a line equal to ``last_emitted``, or any other kind of line.
    Returns the gathered lines in file order.
    """
    triple_quotes = ('"""', "'''")
    collected = []
    inside_docstring = False
    p = match_idx - 1
    while p >= 0:
        raw = com_lines[p]
        # Stop at the line that was emitted last, so it is not duplicated.
        if last_emitted and raw == last_emitted:
            break
        text = raw.strip()
        one_line_docstring = len(text) > 3 and any(
            text.startswith(q) and text.endswith(q) for q in triple_quotes
        )
        if one_line_docstring:
            collected.append(raw)
        elif text.startswith(triple_quotes) or text.endswith(triple_quotes):
            # Opening or closing delimiter of a multi-line docstring.
            inside_docstring = not inside_docstring
            collected.append(raw)
        elif inside_docstring:
            collected.append(raw)
        elif text == "" or (text[0] == '#' and "省略部分内容" not in raw):
            # Blank line or ordinary comment; "省略部分内容" placeholder
            # comments produced by the model are rejected.
            collected.append(raw)
        else:
            break
        p -= 1
    collected.reverse()
    return collected


# Merge the original code with the generated comments so that the original code is
# never altered. Other merge strategies could be used for this step.
def merge_code_and_comments(original_file, comments_path):
    """Rebuild ``comments_path`` from the original code plus generated comments.

    Every non-blank line of ``original_file`` is emitted. For each code line
    the generated file is searched forward for a line containing it; on a
    match, the comment/docstring/blank lines directly above the match are
    emitted first, and a trailing ``#`` comment on the matched line is
    appended when the original line has no ``#`` of its own. Lines without a
    match are emitted unchanged. Blank lines of the original are not copied;
    they only reappear when the generated file has them above a match.

    Both files are read as UTF-8 and the result is written back to
    ``comments_path`` as UTF-8.

    Args:
        original_file: Path of the untouched source file.
        comments_path: Path of the model output; overwritten with the merge.
    """
    with open(original_file, "r", encoding='utf-8') as ori_f:
        ori_lines = ori_f.readlines()
    with open(comments_path, "r", encoding='utf-8') as com_f:
        com_lines = com_f.readlines()

    res = []
    len_com_lines = len(com_lines)
    p = 0  # where the search restarts after a line that could not be matched
    j = 0  # current search position in com_lines
    for line in ori_lines:
        if line.isspace():
            continue
        if line.strip()[0] == '#':
            res.append(line)
            continue

        # Strip only the newline: the last line of a file may not have one,
        # and slicing off the final character would then cut real code.
        code = line.rstrip("\n")
        while j < len_com_lines and code not in com_lines[j]:
            j += 1

        if j >= len_com_lines:
            # No counterpart in the generated file: keep the line as is and
            # rewind the search to just after the last successful match.
            res.append(line)
            j = p
            continue

        res.extend(_collect_leading_comments(com_lines, j, res[-1] if res else None))
        if "#" in com_lines[j] and "#" not in line:
            res.append(code + "  #" + com_lines[j].split("#")[-1])
        else:
            res.append(line)
        p = j + 1

    # Written as UTF-8 to match how this file is read above.
    with open(comments_path, 'w', encoding='utf-8') as out_f:
        out_f.write("".join(res))

# Process a single file
def deal_one_file(model, path, args):
    """Generate a commented copy of the file at ``path``.

    The result is written next to the source as ``<name>_comments<ext>``.
    When that file already exists and ``args.regenerate`` is false, it is
    reused and nothing is generated.

    Files with at most ``MaxLine`` lines are sent to the model in one
    request. Longer files are split at ``SplitKey[0]`` when the marker is
    present, otherwise into chunks of ``MaxLine`` lines; the responses are
    joined with newlines. Finally the response is merged with the original
    code by ``merge_code_and_comments``.

    Args:
        model: Chat model passed through to ``gen_code_comments``.
        path: Path of the source file.
        args: Parsed options; only ``args.regenerate`` is read.
    """
    # os.path handles any platform separator and keeps dots inside the name
    # (e.g. "a.b.py" becomes "a.b_comments.py").
    folder, fname = os.path.split(path)
    stem, ext = os.path.splitext(fname)
    comments_path = os.path.join(folder, stem + "_comments" + ext)
    if (not args.regenerate) and os.path.exists(comments_path):
        print("use cache: ", comments_path)
        return

    context = read_file(path)
    # "<=" so that a file of exactly MaxLine lines is handled in one request
    # rather than being handed to the splitter with nothing to split.
    if len(context.split("\n")) <= MaxLine:
        res = gen_code_comments(context, model = model)
    else:
        if SplitKey[0] in context:
            context_list = split_context_by_splitkey(context)
        else:
            context_list = split_context_by_maxline(context)
        res = "\n".join(gen_code_comments(block, model = model) for block in context_list)

    write_file(comments_path, res)
    merge_code_and_comments(path, comments_path)

# Process a folder
def deal_folder(model, path, args):
    """Recursively process every supported code file below ``path``.

    A file is handled when its extension is listed in ``CodeFileType`` and
    its path does not contain ``"_comments"``. Sub-folders are visited
    recursively; entries that are neither a file nor a folder trigger a
    printed message.

    Args:
        model: Chat model passed through to ``deal_one_file``.
        path: Folder to scan.
        args: Parsed options passed through to ``deal_one_file``.
    """
    for entry in os.listdir(path):
        now_path = os.path.join(path, entry)

        if os.path.isfile(now_path):
            extension = now_path.split(".")[-1]
            if extension in CodeFileType and "_comments" not in now_path:
                deal_one_file(model, now_path, args)
            continue

        if os.path.isdir(now_path):
            deal_folder(model, now_path, args)
            continue

        print("Please specify a correct path!")

def transfer(args):
    """Run comment generation for the file or folder given by ``args.path``.

    The path is validated before the model is created, so an invalid path or
    an unsupported file returns immediately instead of first loading the
    model.

    A single file is processed only when its extension is listed in
    ``CodeFileType`` and its path does not contain ``"_comments"``; other
    files are skipped silently. A folder is processed recursively. Any other
    path prints an error message.

    Args:
        args: Parsed options; ``args.path`` is read here and the whole
            object is passed on to the per-file and per-folder handlers.
    """
    path = args.path

    if os.path.isfile(path):
        if path.split(".")[-1] not in CodeFileType or "_comments" in path:
            return
        deal_one_file(QWenChat(), path, args)
    elif os.path.isdir(path):
        deal_folder(QWenChat(), path, args)
    else:
        print("Please specify a correct path!")

# Script entry point: parse the command-line options, echo them, then run the generation.
if __name__ == '__main__':
    args = parse_args()
    print(args)
    transfer(args)