import random
import tqdm
import os
import re
import sys
import torch
import numpy as np
import jsonlines
import argparse
import jsonlines
import datasets
from datasets import load_from_disk,load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


# Sentinel value returned by the answer extractors when no number can be used.
INVALID_ANS = "[invalid]"
# Finds the "#### <number>" line of a reference answer and captures the number:
# an optional leading minus followed by digits, dots and commas.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

def doc_to_text(doc, prompt=None):
    """Build the model input for one dataset example.

    Args:
        doc: mapping with a ``"question"`` string.
        prompt: few-shot prefix to prepend. When omitted, the module-level
            ``fewshot_prompt`` global (set by the ``__main__`` section) is used,
            preserving the original call signature ``doc_to_text(doc)``.

    Returns:
        ``prompt + "\\nQuestion: " + question + "\\nLet's think step by step\\n"``.
    """
    if prompt is None:
        # Only defined when this file runs as a script; pass `prompt`
        # explicitly when importing this function elsewhere.
        prompt = fewshot_prompt
    return prompt + "\nQuestion: " + doc["question"] + "\nLet's think step by step\n"

def decode(tokens_list, tokenizer, raw_text_len):
    """Turn generated token sequences into trimmed completion strings.

    Args:
        tokens_list: iterable of per-sample token sequences, each supporting
            ``.cpu().numpy().tolist()`` (e.g. rows of a ``model.generate`` result).
        tokenizer: object whose ``.tokenizer`` attribute provides
            ``decode(list_of_ids) -> str`` (presumably the inner tiktoken
            encoder of a Qwen tokenizer -- confirm for other models).
        raw_text_len: number of leading prompt tokens to drop from each sequence.

    Returns:
        A list with one string per sequence, cut at the first occurrence of
        each stop marker, applied in order.
    """
    # Everything from a marker onward is discarded; applied left to right.
    stop_markers = ('<|endoftext|>', '\n\n\n', "\n\n", "Question:")

    def _trim(text):
        for marker in stop_markers:
            text = text.split(marker)[0]
        return text

    return [
        _trim(tokenizer.tokenizer.decode(seq.cpu().numpy().tolist()[raw_text_len:]))
        for seq in tokens_list
    ]

def generate_sample(model, tokenizer, input_txt):
    """Generate one completion for ``input_txt`` and return its trimmed text.

    The prompt is tokenized with ``tokenizer.tokenizer.encode``, sent to
    ``model.device`` as a batch of one, and passed to ``model.generate`` with
    whatever generation config the model currently carries. The prompt and the
    decoded completion are echoed to stdout.

    Returns:
        The first (only) string produced by ``decode`` for the generated batch.
    """
    prompt_ids = tokenizer.tokenizer.encode(input_txt)
    prompt_len = len(prompt_ids)
    # Shape (1, prompt_len): a single-example batch.
    batch = torch.tensor([prompt_ids], device=model.device)
    print(f"Input text: {input_txt}\n")
    generated = model.generate(batch)
    # decode() strips the first `prompt_len` tokens, leaving only new text.
    completion = decode(generated, tokenizer, prompt_len)[0]
    print(f"\nOutput text: {completion}\n")
    return completion


def extract_answer_hf(completion):
    """Extract the reference answer from a "#### <number>" line.

    Args:
        completion: reference solution text (the dataset's ``answer`` field).

    Returns:
        An ``int`` for integer answers, a ``float`` when the number has a
        decimal point, or ``INVALID_ANS`` when there is no ``####`` line or the
        captured text is not a number (e.g. ``"-"`` or ``"1.2.3"``).
    """
    match = ANS_RE.search(completion)
    if match is None:
        return INVALID_ANS
    # Drop thousands separators: "1,234" -> "1234".
    match_str = match.group(1).strip().replace(",", "")
    # Parse explicitly instead of eval(): eval runs arbitrary expressions, rejects
    # leading zeros ("007" is a SyntaxError) and crashes on captures like "1.2.3".
    try:
        return int(match_str)
    except ValueError:
        pass
    try:
        return float(match_str)
    except ValueError:
        return INVALID_ANS

# A number in free-form model output:
#   * an optional minus sign, taken only when not glued to a preceding
#     digit/letter (so "3-5" yields 5, not -5);
#   * either a comma-grouped integer ("1,234,567") or a plain digit run;
#   * an optional decimal part (a trailing sentence period is not consumed).
# ASCII-only on purpose: `\d` would also match non-ASCII digits.
_COMPLETION_NUMBER_RE = re.compile(
    r"(?:(?<![0-9A-Za-z])-)?(?:[0-9]{1,3}(?:,[0-9]{3})+(?![0-9])|[0-9]+)(?:\.[0-9]+)?"
)


def extract_answer(completion):
    """Return the last number mentioned in a model completion.

    Thousands separators are removed ("1,234" -> 1234). Numbers with a decimal
    part are returned as ``float``; everything else as ``int``. This matches the
    types produced for reference answers, so ``==`` comparisons line up.

    Args:
        completion: generated text.

    Returns:
        The last number as ``int``/``float``, or ``INVALID_ANS`` when the text
        contains no number.
    """
    numbers = _COMPLETION_NUMBER_RE.findall(completion)
    if not numbers:
        return INVALID_ANS
    last = numbers[-1].replace(",", "")
    # int()/float() instead of eval(): the text is model output, and eval
    # rejected leading zeros ("05") as a SyntaxError.
    return float(last) if "." in last else int(last)

def is_correct(completion, answer):
    """Check whether a completion's final number equals the reference answer.

    Args:
        completion: generated text; its last number is taken as the prediction.
        answer: reference solution containing a "#### <number>" line.

    Returns:
        ``True`` when the extracted prediction equals the reference value.

    Raises:
        ValueError: if no reference number can be extracted from ``answer``.
    """
    gold = extract_answer_hf(answer)
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    # The comparison below would then be "[invalid]" == "[invalid]", scoring a
    # number-free completion as correct.
    if gold == INVALID_ANS:
        raise ValueError("No ground truth answer found in the document.")
    return extract_answer(completion) == gold

if __name__ == '__main__':
    # Script entry point: evaluate a causal-LM checkpoint on the GSM8K test
    # split, write one JSON line per example (original fields plus
    # "completion" and "acc"), and print the mean accuracy.
    parser = argparse.ArgumentParser(description='Test HF checkpoint.')
    parser.add_argument("-c", "--checkpoint-path", type=str, help="Checkpoint path", default="Qwen/Qwen-7B")
    parser.add_argument("-f", "--sample-input-file", type=str, default=None)
    parser.add_argument("-o", "--sample-output-file", type=str, default="gsm8k_res.jsonl")

    args = parser.parse_args()

    # Assigned at module level on purpose: doc_to_text() reads this global.
    # `with` closes the handle promptly. The explicit encoding avoids depending
    # on the platform's locale default.
    with open("gsm8k_prompt.txt", encoding="utf-8") as prompt_file:
        fewshot_prompt = prompt_file.read()

    if args.sample_input_file is not None:
        # Presumably a dataset saved with save_to_disk; it must have a "test" split.
        dataset = load_from_disk(args.sample_input_file)
    else:
        config = datasets.DownloadConfig(resume_download=True, max_retries=100)
        dataset = load_dataset("gsm8k", 'main', download_config=config)

    test = dataset["test"]

    print('Loading tokenizer ...')
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)

    print('Loading model ...')
    model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
    model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    # Disable sampling so generation is deterministic across runs.
    model.generation_config.do_sample = False

    acc_res = []
    # jsonlines.Writer only closes file objects it opened itself (via
    # jsonlines.open), so the output file gets its own `with`. That way it is
    # flushed and closed even if generation fails part-way, keeping rows
    # already written.
    with open(args.sample_output_file, 'w', encoding='utf-8') as out_fp:
        with jsonlines.Writer(out_fp) as f_output:
            for doc in test:
                context = doc_to_text(doc)
                completion = generate_sample(model, tokenizer, context)
                acc = is_correct(completion, doc["answer"])
                doc["completion"] = completion
                doc["acc"] = acc
                f_output.write(doc)
                acc_res.append(acc)

    # np.mean([]) would print nan with a RuntimeWarning on an empty split.
    if acc_res:
        print("Acc: ", np.mean(acc_res))
    else:
        print("Acc: no examples evaluated")