import argparse
import base64
import collections
import logging
import unicodedata
from pathlib import Path

import regex as re
from tqdm.contrib.logging import tqdm_logging_redirect

# Pre-tokenization regex. load_expand_vocab rejects any word this pattern
# splits into more than one piece, since such a word cannot be added to the
# vocabulary as a single token. Alternatives, in order: case-insensitive
# English contractions; a letter run with at most one leading non-letter,
# non-digit, non-newline char; a single digit; a punctuation run with an
# optional leading space and trailing newlines; whitespace runs.
# NOTE(review): presumably this must match the target tokenizer's pattern — keep in sync.
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

logger = logging.getLogger(__name__)

# Root logging is configured at import time at DEBUG level so that the
# per-merge progress messages from learn_bpe are shown.
logging.basicConfig(
    format="[%(asctime)s] %(levelname)s - %(message)s", level=logging.DEBUG
)


def load_tiktoken_bpe(tiktoken_bpe_file: str) -> "dict[bytes, int]":
    """Read a tiktoken rank file into a ``{token_bytes: rank}`` dict.

    Each non-blank line holds a base64-encoded token and its integer rank,
    separated by whitespace. Blank and whitespace-only lines are ignored.

    Args:
        tiktoken_bpe_file: path to the rank file (``str`` or ``Path``).

    Returns:
        Mapping from decoded token bytes to rank.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(tiktoken_bpe_file, "rb") as f:
        contents = f.read()
    return {
        base64.b64decode(token): int(rank)
        for token, rank in (
            line.split() for line in contents.splitlines() if line.strip()
        )
    }


def dump_tiktoken_bpe(bpe_ranks: "dict[bytes, int]", tiktoken_bpe_file: str) -> None:
    """Write *bpe_ranks* in tiktoken's text format, ordered by ascending rank.

    Every output line is ``<base64 token> <rank>\\n``. The sort is stable, so
    entries with equal rank keep the dict's iteration order.
    """
    with open(tiktoken_bpe_file, "wb") as out:
        by_rank = sorted(bpe_ranks.items(), key=lambda entry: entry[1])
        out.writelines(
            base64.b64encode(token_bytes) + b" " + str(rank).encode() + b"\n"
            for token_bytes, rank in by_rank
        )


def bytes_to_pieces(the_bytes: bytes) -> "tuple[bytes]":
    """Split a byte string into a tuple of one-byte ``bytes`` objects."""
    single_bytes = [bytes((octet,)) for octet in the_bytes]
    return tuple(single_bytes)


def get_pairs(pieces: "tuple[bytes]") -> "set[tuple[bytes, bytes]]":
    """Return the set of distinct adjacent ``(left, right)`` piece pairs."""
    return {(left, right) for left, right in zip(pieces, pieces[1:])}


def get_stats(
    vocab: "dict[tuple[bytes, ...], int]",
) -> "dict[tuple[bytes, bytes], int]":
    """Count adjacent pairs across *vocab*, weighting each by its word's frequency.

    Returns a ``defaultdict(int)`` keyed by ``(left, right)``, with pairs in
    first-seen order.
    """
    counts = collections.defaultdict(int)
    for pieces, weight in vocab.items():
        for pair in zip(pieces, pieces[1:]):
            counts[pair] += weight
    return counts


def merge_vocab(
    pair: "tuple[bytes, bytes]", vocab: "dict[tuple[bytes, ...], int]"
) -> "dict[tuple[bytes, ...], int]":
    """Return a new vocab in which *pair* is merged inside every word.

    Frequencies are carried over unchanged; *vocab* itself is not modified.
    """
    merged = {}
    for pieces, freq in vocab.items():
        merged[apply_bp(pieces, pair)] = freq
    return merged


def apply_bp(
    pieces: "tuple[bytes, ...]", pair: "tuple[bytes, bytes]"
) -> "tuple[bytes, ...]":
    """Merge every occurrence of the adjacent pair *pair* inside *pieces*.

    Occurrences are matched left to right and do not overlap, so
    ``(a, a, a)`` merged with ``(a, a)`` gives ``(aa, a)``.

    Args:
        pieces: the current segmentation of a word.
        pair: ``(first, second)`` to replace with ``first + second``.

    Returns:
        The new segmentation as a tuple.
    """
    first, second = pair
    new_pieces = []
    n = len(pieces)
    i = 0
    while i < n:
        try:
            # Jump straight to the next candidate start position.
            j = pieces.index(first, i)
        except ValueError:
            # No further occurrence of ``first``: keep the rest unchanged.
            new_pieces.extend(pieces[i:])
            break
        new_pieces.extend(pieces[i:j])
        if j < n - 1 and pieces[j + 1] == second:
            new_pieces.append(first + second)
            i = j + 2
        else:
            new_pieces.append(pieces[j])
            i = j + 1

    return tuple(new_pieces)


def bpe(word: bytes, merges: "dict[bytes,int]") -> "tuple[bytes, ...]":
    """Segment *word* using an existing rank table.

    Start from single bytes and repeatedly merge the adjacent pair whose
    concatenation has the lowest rank in *merges*. Stop when no adjacent
    concatenation is a known token.

    Ties (distinct pairs that concatenate to the same bytes, hence the same
    rank) go to the leftmost position. Choosing from a ``set`` instead would
    depend on iteration order, which varies between runs because ``bytes``
    hashing is randomized.

    Args:
        word: the raw bytes to segment.
        merges: token bytes -> rank.

    Returns:
        The resulting pieces.
    """
    inf = float("inf")
    pieces = bytes_to_pieces(word)
    while len(pieces) > 1:
        best_rank, _, best_pair = min(
            (merges.get(left + right, inf), pos, (left, right))
            for pos, (left, right) in enumerate(zip(pieces, pieces[1:]))
        )
        if best_rank == inf:
            break
        pieces = apply_bp(pieces, best_pair)
    return pieces


def best_pair_sort_key(
    item: "tuple[tuple[bytes, bytes], int]",
) -> "tuple[int, int, int, str, bytes]":
    """Sort key for picking the next merge with ``min(pairs.items(), key=...)``.

    Preference order: highest frequency, then shortest decoded string, then
    fewest bytes, then lexicographically smallest decoded string, then the raw
    bytes. This is slightly slower than a frequency-only key.

    Args:
        item: a ``((first, second), frequency)`` entry from ``get_stats``.

    Returns:
        A tuple that sorts the preferred pair first.
    """
    pair, freq = item
    pair_bytes = pair[0] + pair[1]
    # Partial UTF-8 sequences decode with U+FFFD so every pair gets a string key.
    pair_str = pair_bytes.decode("utf-8", errors="replace")
    return -freq, len(pair_str), len(pair_bytes), pair_str, pair_bytes


def learn_bpe(
    freqs: "dict[str,int]", existing: "dict[bytes, int]"
) -> "list[tuple[bytes, bytes]]":
    """Greedily learn new merges on top of an existing rank table.

    Each word is first segmented with the existing merges (``bpe``). Then the
    adjacent pair preferred by ``best_pair_sort_key`` (highest total frequency
    first) is repeatedly recorded and merged everywhere, until every word is a
    single piece.

    Args:
        freqs: word -> frequency.
        existing: token bytes -> rank of the existing tokenizer.

    Returns:
        The learned ``(first, second)`` pairs in the order they were learned.
        Distinct pairs may concatenate to the same bytes.
    """
    vocab = {bpe(k.encode("utf-8"), existing): v for k, v in freqs.items()}
    # Words already covered by a single existing token need no new merges.
    vocab = {key: value for key, value in vocab.items() if len(key) > 1}
    new_merges = []
    with tqdm_logging_redirect() as bar:
        while vocab:
            pairs = get_stats(vocab)

            best, freq = min(pairs.items(), key=best_pair_sort_key)

            logger.debug(
                f'{best} ({(best[0]+best[1]).decode("utf-8", errors="replace")}) is selected as the next merge with freq {freq}'
            )
            new_merges.append(best)

            vocab = merge_vocab(best, vocab)
            # Drop fully merged words; this keeps ``pairs`` non-empty for min().
            vocab = {key: value for key, value in vocab.items() if len(key) > 1}
            bar.update()

    return new_merges


def load_expand_vocab(path: Path) -> "dict[str, int]":
    """Read the words to add and their frequencies.

    Each non-blank line is ``word<TAB>frequency``. Only the line terminator is
    stripped, so a word keeps significant leading spaces (e.g. ``" hello"``).
    A missing or non-integer frequency defaults to 1.

    Words are NFC-normalized. Empty words, and words that ``PAT_STR`` would
    split into more than one pre-token, are skipped with a warning. Repeated
    words have their frequencies summed.

    Args:
        path: UTF-8 text file with one entry per line.

    Returns:
        Mapping from word to total frequency, in first-seen order.
    """
    freqs = {}
    with open(path, "r", encoding="utf8") as fin:
        for line in fin:
            if not line.strip():
                continue
            # Strip only the terminator: leading spaces are part of the token.
            # partition() tolerates a missing tab (freq "" -> default 1 below).
            word, _, freq = line.rstrip("\r\n").partition("\t")
            word = unicodedata.normalize("NFC", word)
            if not word:
                logger.warning(f"empty word in line {line!r}, skipping")
                continue
            parts = re.findall(PAT_STR, word)
            if len(parts) > 1:
                logger.warning(
                    f"{word} would be pre-tokenized to {parts}, and thus cannot be added to vocabulary"
                )
                continue
            try:
                freq = int(freq)
            except ValueError:
                freq = 1
            if word in freqs:
                logger.warning(
                    f"{word} is repeated, the frequency is increased by this much"
                )
                freqs[word] += freq
            else:
                freqs[word] = freq
    return freqs


def make_new_merges_by_bpe(
    input_path: Path, output_path: Path, expand_path: Path, start_id: int
) -> None:
    """Learn merges for the words in *expand_path* and write only the new ones.

    Args:
        input_path: existing tiktoken rank file.
        output_path: destination for the newly learned tokens only.
        expand_path: tab-separated word/frequency file (see load_expand_vocab).
        start_id: id of the first new token. ``0``/``None``/``-1`` mean "right
            after the existing ranks". Smaller values would collide with
            existing ids, so they are raised to that point with a warning.
    """
    mergeable_ranks = load_tiktoken_bpe(input_path)
    num_existing = len(mergeable_ranks)

    if not start_id or start_id == -1:
        start_id = num_existing
    elif start_id < num_existing:
        logger.warning(
            f"start_id {start_id} is too small, existing merges will be overridden, DONOT DO THIS. changed to {len(mergeable_ranks)}"
        )
        start_id = num_existing

    expand_vocab_freqs = load_expand_vocab(expand_path)
    for word in list(expand_vocab_freqs):
        token = word.encode("utf-8")
        if token in mergeable_ranks:
            logger.warning(f"word {word} is already a token {token}, skipping")
            del expand_vocab_freqs[word]

    logger.info(f"number of existing merges: {len(mergeable_ranks)}")
    logger.info(f"number of words for expanding: {len(expand_vocab_freqs)}")

    new_merges = learn_bpe(expand_vocab_freqs, mergeable_ranks)
    logger.info(f"number of newly learned merges: {len(new_merges)}")

    # Distinct pairs can concatenate to the same bytes, e.g. (b"ab", b"c") and
    # (b"a", b"bc"). Emit each new byte sequence once, skip any that already
    # exist, and number the survivors consecutively so the ids have no gaps.
    extra_merges = {}
    for first, second in new_merges:
        token = first + second
        if token in extra_merges or token in mergeable_ranks:
            continue
        extra_merges[token] = start_id + len(extra_merges)
    if len(extra_merges) != len(new_merges):
        logger.info(f"number of distinct new tokens: {len(extra_merges)}")

    dump_tiktoken_bpe(extra_merges, output_path)


def main():
    """Command-line entry point: parse arguments and write the new merges."""
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    positionals = (
        ("input_path", "Path for input tiktoken file"),
        (
            "output_path",
            "Path for output tiktoken file, containing only the new merges",
        ),
        (
            "vocab_path",
            "Path for words needed adding, each line is a word and its frequency separated by \\t",
        ),
    )
    for dest, help_text in positionals:
        cli.add_argument(dest, type=str, help=help_text)
    # When extending a tokenizer for fine-tuning, set --start_id correctly;
    # the default targets qwen.tiktoken. When pretraining from scratch this
    # matters less.
    cli.add_argument(
        "--start_id",
        type=int,
        default=151851,
        help="The start id for new merges. For Qwen tokenizer, this should be 151851 (skipping the existing special tokens)",
    )

    parsed = cli.parse_args()
    make_new_merges_by_bpe(
        parsed.input_path, parsed.output_path, parsed.vocab_path, parsed.start_id
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()