import argparse
import json
import os
import pprint

import json5
import jsonlines
from rouge_score import rouge_scorer
from tqdm import tqdm
from transformers import Agent, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.tools.evaluate_agent import evaluate_agent
from transformers.trainer_utils import set_seed

# Absolute path of the ``data`` directory that sits next to this script;
# every evaluation / cache file name is resolved relative to it.
data_root_path = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'data')


def is_callable(response, golden):
    """Return True when the predicted action name equals the golden one.

    Both names are compared after stripping surrounding whitespace and
    lower-casing, so ``' Search '`` matches ``'search'``.
    """
    def _normalize(name):
        return name.strip().lower()

    predicted = _normalize(response['action'])
    expected = _normalize(golden['action'])
    return predicted == expected


def process_res(response):
    """Split a ReAct-style completion into its named sections.

    The text is expected to look roughly like::

        Thought: ...
        Action: <tool name>
        Action Input: <arguments, usually JSON>
        Observation: ...
        Thought: ...
        Final Answer: ...

    Each field is the stripped text between its marker and the following
    marker. Missing markers are not detected explicitly: ``str.find`` then
    returns -1 and the slice bounds degrade accordingly, so a field may come
    back empty or contain neighbouring text. ``thought`` is everything before
    ``Action:`` (typically including the leading ``Thought:`` marker).

    If ``Action Input`` parses as JSON5 it is re-serialized as canonical JSON
    (sorted keys, non-ASCII kept) so that equivalent arguments compare equal;
    otherwise the raw text is kept.

    Args:
        response: the completion text to parse.

    Returns:
        dict with keys ``thought``, ``action``, ``action_input``,
        ``observation``, ``thought_last`` and ``final_answer`` (all str).
    """
    # When a marker is missing, find() returns -1 and a slice end of -1 drops
    # the last character; appending a newline makes that dropped character
    # this padding instead of real content.
    response += '\n'
    thought = response[:response.find('Action:')].strip()
    action = response[response.find('Action:') +
                      len('Action:'):response.find('Action Input:')].strip()
    action_input = response[response.find('Action Input:') +
                            len('Action Input:'):response.find('Observation:'
                                                               )].strip()
    # TODO: This parsing result is incorrect if the response contains
    # multiple Actions. To be fixed in the future.
    observation = response[response.find('Observation:') +
                           len('Observation:'):response.rfind('Thought:'
                                                              )].strip()
    thought_last = response[response.rfind('Thought:') +
                            len('Thought:'):response.find('Final Answer:'
                                                          )].strip()
    final_answer = response[response.find('Final Answer:') +
                            len('Final Answer:'):].strip()
    try:
        action_input = json.dumps(json5.loads(action_input),
                                  ensure_ascii=False,
                                  sort_keys=True)
    except Exception:
        # Not parseable as JSON5 (e.g. free-form arguments): keep the raw
        # text. Catching Exception (not a bare except) lets Ctrl-C through.
        pass
    return {
        'thought': thought,
        'action': action,
        'action_input': action_input,
        'observation': observation,
        'thought_last': thought_last,
        'final_answer': final_answer,
    }


class _DummyTokenizer:
    def tokenize(self, text: str):
        return text.split()


def _get_tokenized_string(tokenizer, text_list):
    token_ids_list, tokenized_string_list = [], []
    for text in text_list:
        assert tokenizer is not None
        token_ids = tokenizer.encode(text)
        tokens_bytes = tokenizer.convert_ids_to_tokens(token_ids)
        tokens = [
            token.decode('utf-8', errors='replace') for token in tokens_bytes
        ]
        tokenized_string = ' '.join(tokens)
        token_ids_list.append(token_ids)
        tokenized_string_list.append(tokenized_string)
    return token_ids_list, tokenized_string_list


def eval_action(job):
    """Return True if the generated response picks the golden action.

    ``job['gen'][0]`` is the model's completion and ``job['response']`` the
    reference. A completion without any ``Action:`` marker never counts as
    correct; otherwise both texts are parsed with ``process_res`` and their
    action names compared with ``is_callable``.
    """
    generated = job['gen'][0]
    reference = job['response']

    if 'Action:' not in generated:
        return False

    parsed_generated = process_res(generated)
    parsed_reference = process_res(reference)
    return bool(is_callable(parsed_generated, parsed_reference))


def eval_action_input(job, tokenizer):
    """Score the generated ``Action Input`` against the golden one (ROUGE-L).

    Both texts (``job['gen'][0]`` and ``job['response']``) are parsed with
    ``process_res``, so JSON inputs are compared in canonical sorted-key form.
    Each action input is tokenized with *tokenizer* and rendered as a
    space-joined token string. ROUGE is computed on those strings with a
    whitespace tokenizer, i.e. at the model-token level.

    Args:
        job: dict with keys ``'gen'`` (list whose first item is the model
            output) and ``'response'`` (the reference text).
        tokenizer: tokenizer accepted by ``_get_tokenized_string``.

    Returns:
        float: ROUGE-L F-measure of generated vs. reference action input.
    """
    response = process_res(job['gen'][0])
    golden = process_res(job['response'])

    _, (gen_tok_str,) = _get_tokenized_string(tokenizer,
                                              [response['action_input']])
    _, (ref_tok_str,) = _get_tokenized_string(tokenizer,
                                              [golden['action_input']])

    # Only ROUGE-L is reported, so rouge1/rouge2 are not computed.
    scorer = rouge_scorer.RougeScorer(['rougeL'],
                                      tokenizer=_DummyTokenizer())
    score = scorer.score(ref_tok_str, gen_tok_str)
    return score['rougeL'].fmeasure


class QWenAgent(Agent):
    """
    Agent that uses a QWen model and tokenizer to generate code.

    If both ``tokenizer`` and ``model`` are supplied they are used as-is.
    Otherwise ``Qwen/Qwen-7B-Chat`` is loaded with ``device_map='auto'`` and
    configured for greedy decoding.

    Example:

    ```py
    agent = QWenAgent()
    agent.run("Draw me a picture of rivers and lakes.")
    ```
    """
    def __init__(self,
                 chat_prompt_template=None,
                 run_prompt_template=None,
                 additional_tools=None,
                 tokenizer=None,
                 model=None):
        if tokenizer is not None and model is not None:
            self.tokenizer = tokenizer
            self.model = model
        else:
            checkpoint = 'Qwen/Qwen-7B-Chat'
            self.tokenizer = AutoTokenizer.from_pretrained(
                checkpoint, trust_remote_code=True)
            # device_map='auto' already places the weights on the available
            # devices; an extra .cuda() conflicts with that dispatch and fails
            # on machines without a GPU.
            self.model = AutoModelForCausalLM.from_pretrained(
                checkpoint, device_map='auto',
                trust_remote_code=True).eval()
            # Generation length, top_p and other hyperparameters can be
            # customized through this config.
            self.model.generation_config = GenerationConfig.from_pretrained(
                checkpoint, trust_remote_code=True)
            self.model.generation_config.do_sample = False  # greedy

        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    @staticmethod
    def _escape_reserved_words(text):
        """Replace 'Human:'/'Assistant:' with '_HUMAN_:'/'_ASSISTANT_:'."""
        return text.replace('Human:',
                            '_HUMAN_:').replace('Assistant:', '_ASSISTANT_:')

    @staticmethod
    def _unescape_reserved_words(text):
        """Inverse of ``_escape_reserved_words``."""
        return text.replace('_HUMAN_:',
                            'Human:').replace('_ASSISTANT_:', 'Assistant:')

    def generate_one(self, prompt, stop):
        """Generate a single completion for *prompt*, cut at the first stop.

        Args:
            prompt: full prompt text built by the transformers ``Agent``.
            stop: list of stop sequences; the output is truncated at the
                earliest occurrence of any non-empty one.

        Returns:
            str: the completion without the stop sequence.
        """
        # "Human:" and "Assistant:" used to be special reserved words for
        # Qwen, so they are swapped for "_HUMAN_:" and "_ASSISTANT_:" before
        # generation and restored afterwards. This issue is planned to be
        # fixed in a future version.
        prompt = self._escape_reserved_words(prompt)
        stop = [self._escape_reserved_words(item) for item in stop]

        result, _ = self.model.chat(self.tokenizer, prompt, history=None)

        # model.chat does not stop at these sequences itself, so cut the
        # output at the earliest one. Empty stop strings are skipped so they
        # cannot truncate the whole result.
        cut = len(result)
        for stop_seq in stop:
            if not stop_seq:
                continue
            idx = result.find(stop_seq)
            if idx != -1:
                cut = min(cut, idx)
        result = result[:cut]

        return self._unescape_reserved_words(result)


def load_models_tokenizer(args):
    """Load the model and tokenizer found at ``args.checkpoint_path``.

    The model is loaded with ``device_map='auto'`` and the extra loading
    flags ``bf16=True`` / ``use_flash_attn=True`` (forwarded to the
    checkpoint's remote code). It is switched to eval mode, and its
    generation config is loaded from the checkpoint with sampling disabled
    (greedy decoding).

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    ckpt = args.checkpoint_path
    tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)

    model_kwargs = dict(device_map='auto',
                        trust_remote_code=True,
                        bf16=True,
                        use_flash_attn=True)
    model = AutoModelForCausalLM.from_pretrained(ckpt, **model_kwargs)
    model.eval()

    generation_config = GenerationConfig.from_pretrained(
        ckpt, trust_remote_code=True)
    generation_config.do_sample = False  # use greedy decoding
    model.generation_config = generation_config
    return model, tokenizer


def load_jobs(filename):
    """Read all records of the JSON Lines file ``data_root_path/filename``.

    Returns:
        list: one parsed object per line, in file order.
    """
    path = os.path.join(data_root_path, filename)
    with jsonlines.open(path, mode='r') as reader:
        return list(reader)


def react_inference(filename, model, tokenizer):
    """Return the jobs of *filename* with model generations attached.

    If ``<filename>.cache`` exists in ``data_root_path`` it is loaded and
    returned as-is. Otherwise every job's ``'prompt'`` is answered with
    ``model.chat`` and stored as ``job['gen'] = [response]``. The jobs are
    then written to the cache file as JSON Lines.

    The cache is written to a temporary file and moved into place only after
    all jobs finish. An interrupted run therefore never leaves a partial
    cache that a later run would mistake for complete results.

    Raises:
        ValueError: if no cache exists and *model* is None.
    """
    filename_cache = filename + '.cache'
    cache_path = os.path.join(data_root_path, filename_cache)
    if os.path.exists(cache_path):
        jobs = load_jobs(filename=filename_cache)
        print('Loaded from', filename_cache)
        return jobs

    if model is None:
        raise ValueError(
            f'No cached generations for {filename} and no model was loaded.')

    jobs = load_jobs(filename=filename)
    print('Inference:', filename)
    tmp_path = cache_path + '.tmp'
    # Explicit UTF-8: the JSON is dumped with ensure_ascii=False.
    with open(tmp_path, 'w', encoding='utf-8') as f:
        for job in tqdm(jobs):
            response, _ = model.chat(tokenizer, job['prompt'], history=None)
            job['gen'] = [response]
            f.write(json.dumps(job, ensure_ascii=False) + '\n')
    os.replace(tmp_path, cache_path)
    print(filename_cache, 'is saved.')
    return jobs


def _mean(values):
    """Return the arithmetic mean of *values*, or NaN when it is empty."""
    values = list(values)
    if not values:
        # An empty evaluation file would otherwise raise ZeroDivisionError.
        return float('nan')
    return sum(values) / len(values)


def main(args):
    """Run the evaluations selected in *args* and pretty-print the scores.

    - ``--eval-react-positive``: ``action_right_rate`` (share of jobs whose
      predicted action matches the golden one) and ``action_input_rouge``
      (mean ROUGE-L of the action inputs).
    - ``--eval-react-negative``: ``bad_rate``, the share of responses that
      contain ``'\\nAction:'``. Presumably these are prompts where no tool
      call is expected.
    - ``--eval-hfagent``: the transformers ``evaluate_agent`` scores.

    A metric over an empty job file is reported as NaN.
    """
    print('loading model weights')
    if args.checkpoint_path is not None:
        model, tokenizer = load_models_tokenizer(args)
    else:
        model, tokenizer = None, None
    print('model loaded')

    result = {}
    # eval react positive
    if args.eval_react_positive:
        print('eval react positive ...')
        jobs = react_inference(filename=args.eval_react_positive_filename,
                               model=model,
                               tokenizer=tokenizer)
        scores = {
            'action_right_rate':
            _mean(1.0 if eval_action(job) else 0.0 for job in jobs),
            'action_input_rouge':
            _mean(eval_action_input(job, tokenizer) for job in jobs),
        }
        result.update({'react_positive': scores})

    # eval react negative
    if args.eval_react_negative:
        print('eval react negative ...')
        jobs = react_inference(filename=args.eval_react_negative_filename,
                               model=model,
                               tokenizer=tokenizer)
        scores = {
            'bad_rate':
            _mean(1.0 if '\nAction:' in job['gen'][0] else 0.0
                  for job in jobs),
        }
        result.update({'react_negative': scores})

    # eval hfagent
    if args.eval_hfagent:
        print('eval hfagent ...')
        agent = QWenAgent(model=model, tokenizer=tokenizer)
        scores = evaluate_agent(agent, verbose=False, return_errors=False)
        result.update({'hfagent': scores})

    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(result)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test HF checkpoint.')
    parser.add_argument('-c',
                        '--checkpoint-path',
                        type=str,
                        help='Checkpoint path',
                        default='Qwen/Qwen-7B-Chat')
    parser.add_argument('-s',
                        '--seed',
                        type=int,
                        default=1234,
                        help='Random seed')

    # Extra arguments required by the individual evaluation tasks.
    group = parser.add_argument_group(title='Evaluation options')
    group.add_argument('--eval-react-positive',
                       action='store_true',
                       default=False,
                       help='Eval react positive.')
    group.add_argument('--eval-react-positive-filename',
                       type=str,
                       default='exam_plugin_v1_react_positive.jsonl',
                       help='Eval react positive filename.')
    group.add_argument('--eval-react-negative',
                       action='store_true',
                       default=False,
                       help='Eval react negative.')
    group.add_argument('--eval-react-negative-filename',
                       type=str,
                       default='exam_plugin_v1_react_negative.jsonl',
                       help='Eval react negative filename.')
    group.add_argument('--eval-hfagent',
                       action='store_true',
                       default=False,
                       help='Eval hfagent.')

    args = parser.parse_args()
    set_seed(args.seed)

    main(args)