import argparse
import base64
import collections
import logging
import unicodedata
from pathlib import Path

import regex as re
from tqdm.contrib.logging import tqdm_logging_redirect

# Pre-tokenization pattern (GPT/Qwen style): contractions, runs of letters,
# single digits, punctuation, and whitespace are matched as separate pre-tokens.
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG, format="[%(asctime)s] %(levelname)s - %(message)s"
)


def load_tiktoken_bpe(tiktoken_bpe_file: str) -> "dict[bytes, int]":
    # Each non-empty line is "<base64-encoded token> <rank>".
    with open(tiktoken_bpe_file, "rb") as f:
        contents = f.read()
    return {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in contents.splitlines() if line)
    }
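
# Illustrative .tiktoken line, as read by load_tiktoken_bpe and written by
# dump_tiktoken_bpe:
#   "IQ== 0"  ->  {b"!": 0}, since base64.b64encode(b"!") == b"IQ==".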


def dump_tiktoken_bpe(bpe_ranks: "dict[bytes, int]", tiktoken_bpe_file: str) -> None:
    with open(tiktoken_bpe_file, "wb") as f:
        for token, rank in sorted(bpe_ranks.items(), key=lambda x: x[1]):
            f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")


def bytes_to_pieces(the_bytes: bytes) -> "tuple[bytes, ...]":
    # Split a byte string into a tuple of single-byte pieces.
    return tuple(bytes([byte]) for byte in the_bytes)


def get_pairs(pieces: "tuple[bytes, ...]") -> "set[tuple[bytes, bytes]]":
    # All adjacent pairs of pieces.
    return set(zip(pieces[:-1], pieces[1:]))
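
# Illustrative examples (values follow directly from the two helpers above):
#   bytes_to_pieces(b"abc")       == (b"a", b"b", b"c")
#   get_pairs((b"a", b"b", b"c")) == {(b"a", b"b"), (b"b", b"c")}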


def get_stats(
    vocab: "dict[tuple[bytes, ...], int]",
) -> "dict[tuple[bytes, bytes], int]":
    # Count how often each adjacent pair occurs, weighted by word frequency.
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        for i in range(len(word) - 1):
            pairs[(word[i], word[i + 1])] += freq
    return pairs
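
# For example, with vocab == {(b"a", b"b", b"c"): 2, (b"a", b"b"): 3},
# get_stats returns {(b"a", b"b"): 5, (b"b", b"c"): 2}.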


def merge_vocab(
    pair: "tuple[bytes, bytes]", vocab: "dict[tuple[bytes, ...], int]"
) -> "dict[tuple[bytes, ...], int]":
    # Apply one merge to every word in the vocabulary.
    return {apply_bp(pieces, pair): freq for pieces, freq in vocab.items()}


def apply_bp(
    pieces: "tuple[bytes, ...]", pair: "tuple[bytes, bytes]"
) -> "tuple[bytes, ...]":
    # Replace every adjacent occurrence of `pair` in `pieces` with the merged piece.
    new_pieces = []
    first, second = pair
    i = 0
    while i < len(pieces):
        try:
            j = pieces.index(first, i)
            new_pieces.extend(pieces[i:j])
            i = j
        except ValueError:
            new_pieces.extend(pieces[i:])
            break

        if pieces[i] == first and i < len(pieces) - 1 and pieces[i + 1] == second:
            new_pieces.append(first + second)
            i += 2
        else:
            new_pieces.append(pieces[i])
            i += 1

    return tuple(new_pieces)
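
# Illustrative example: applying the merge (b"a", b"b") to (b"a", b"b", b"c")
# yields (b"ab", b"c"); merge_vocab applies the same merge to every word.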


def bpe(word: bytes, merges: "dict[bytes, int]") -> "tuple[bytes, ...]":
    # Greedy BPE encoding: repeatedly apply the lowest-ranked known merge.
    pieces = bytes_to_pieces(word)
    while len(pieces) > 1:
        pairs = get_pairs(pieces)
        pair = min(pairs, key=lambda pair: merges.get(pair[0] + pair[1], float("inf")))

        if pair[0] + pair[1] not in merges:
            break
        pieces = apply_bp(pieces, pair)
        # logger.debug(f"{[(p, p.decode('utf8', errors='replace')) for p in pieces]} {pair} {pieces}")
    return pieces
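
# Illustrative example: with merges == {b"ab": 0}, bpe(b"abc", merges) returns
# (b"ab", b"c"); if no pair is in merges, the word stays as single bytes.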


def best_pair_sort_key(
    item: "tuple[tuple[bytes, bytes], int]",
) -> "tuple[int, int, int, str, bytes]":
    # Prefer the highest frequency, then the shortest decoded/byte length,
    # then lexicographic order; slightly slower than sorting on frequency alone.
    pair, freq = item
    pair_bytes = pair[0] + pair[1]
    pair_byte_length = len(pair_bytes)
    pair_str = pair_bytes.decode("utf-8", errors="replace")
    pair_str_length = len(pair_str)
    return -freq, pair_str_length, pair_byte_length, pair_str, pair_bytes
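
# Used with min() below, so e.g. ((b"a", b"b"), 7) beats ((b"c", b"d"), 5)
# because -7 < -5; ties fall through to the length and lexicographic fields.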


def learn_bpe(
    freqs: "dict[str, int]", existing: "dict[bytes, int]"
) -> "list[tuple[bytes, bytes]]":
    # Encode each new word with the existing merges, then keep learning new
    # merges until every word collapses to a single piece.
    vocab = {bpe(k.encode("utf-8"), existing): v for k, v in freqs.items()}
    vocab = {key: value for key, value in vocab.items() if len(key) > 1}
    new_merges = []
    with tqdm_logging_redirect() as bar:
        while vocab:
            pairs = get_stats(vocab)

            best, freq = min(pairs.items(), key=best_pair_sort_key)

            logger.debug(
                f'{best} ({(best[0] + best[1]).decode("utf-8", errors="replace")}) is selected as the next merge with freq {freq}'
            )
            new_merges.append(best)

            vocab = merge_vocab(best, vocab)
            vocab = {key: value for key, value in vocab.items() if len(key) > 1}
            bar.update()

    return new_merges
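
# Illustrative call (hypothetical data): learn_bpe({"hello": 10}, {}) keeps
# merging the best pair until the word is a single piece, starting with
# (b"e", b"l") (the lexicographically smallest pair at equal frequency) and
# ending with (b"hel", b"lo").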


def load_expand_vocab(path: Path) -> "dict[str, int]":
    freqs = {}
    with open(path, "r", encoding="utf8") as fin:
        for line in fin:
            if not line.strip():
                continue
            word, freq = line.strip().split("\t")
            word = unicodedata.normalize("NFC", word)
            parts = re.findall(PAT_STR, word)
            if len(parts) > 1:
                logger.warning(
                    f"{word} would be pre-tokenized to {parts}, and thus cannot be added to the vocabulary"
                )
                continue
            try:
                freq = int(freq)
            except ValueError:
                freq = 1
            if word in freqs:
                logger.warning(f"{word} is repeated; its frequencies are summed")
                freqs[word] += freq
            else:
                freqs[word] = freq
    return freqs
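
# Illustrative expand-vocab file (tab-separated word and frequency, one per
# line; words that the pre-tokenizer would split, e.g. "foo bar", are rejected
# above):
#   tokenization\t250
#   transformer\t120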


def make_new_merges_by_bpe(
    input_path: Path, output_path: Path, expand_path: Path, start_id: int
) -> None:
    mergeable_ranks = load_tiktoken_bpe(input_path)

    if not start_id or start_id == -1:
        start_id = len(mergeable_ranks)
    elif start_id < len(mergeable_ranks):
        logger.warning(
            f"start_id {start_id} is too small, existing merges would be overridden, do NOT do this; changed to {len(mergeable_ranks)}"
        )
        start_id = len(mergeable_ranks)

    expand_vocab_freqs = load_expand_vocab(expand_path)
    for word in list(expand_vocab_freqs):
        token = word.encode("utf-8")
        if token in mergeable_ranks:
            logger.warning(f"word {word} is already a token {token}, skipping")
            del expand_vocab_freqs[word]

    logger.info(f"number of existing merges: {len(mergeable_ranks)}")
    logger.info(f"number of words for expanding: {len(expand_vocab_freqs)}")

    new_merges = learn_bpe(expand_vocab_freqs, mergeable_ranks)
    logger.info(f"number of newly learned merges: {len(new_merges)}")

    extra_merges = {p[0] + p[1]: i for i, p in enumerate(new_merges, start=start_id)}

    dump_tiktoken_bpe(extra_merges, output_path)
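
# The new tokens get consecutive ids starting at start_id: with
# start_id == 151851, the first learned merge becomes token id 151851,
# the second 151852, and so on.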


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("input_path", type=str, help="Path to the input tiktoken file")
    parser.add_argument(
        "output_path",
        type=str,
        help="Path to the output tiktoken file, containing only the new merges",
    )
    parser.add_argument(
        "vocab_path",
        type=str,
        help="Path to the words to add; each line is a word and its frequency separated by \\t",
    )
    # If the extended vocabulary is for fine-tuning, set this correctly (the default is for qwen.tiktoken).
    # If the extended vocabulary is for pretraining from scratch, it does not matter.
    parser.add_argument(
        "--start_id",
        type=int,
        default=151851,
        help="The start id for new merges. For the Qwen tokenizer, this should be 151851 (skipping the existing special tokens)",
    )

    args = parser.parse_args()

    make_new_merges_by_bpe(
        args.input_path, args.output_path, args.vocab_path, args.start_id
    )


if __name__ == "__main__":
    main()
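
# Illustrative invocation (the script and file names are placeholders):
#   python expand_tiktoken_merges.py qwen.tiktoken extra_merges.tiktoken new_words.tsv --start_id 151851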