update cache GC in demo and add vocab expansion example

main
yangapku 1 year ago
parent 343017c4ce
commit cbfaada8de

@ -11,6 +11,7 @@ import platform
import shutil
from copy import deepcopy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.trainer_utils import set_seed
@ -64,6 +65,13 @@ def _load_model_tokenizer(args):
return model, tokenizer, config
def _gc():
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _clear_screen():
if platform.system() == "Windows":
os.system("cls")
@ -129,10 +137,12 @@ def main():
elif command in ['clear', 'cl']:
_clear_screen()
print(_WELCOME_MSG)
_gc()
continue
elif command in ['clear-history', 'clh']:
print(f'[INFO] All {len(history)} history cleared')
history.clear()
_gc()
continue
elif command in ['help', 'h']:
print(_HELP_MSG)

@ -0,0 +1,226 @@
import argparse
import base64
import collections
import logging
import unicodedata
from pathlib import Path
import regex as re
from tqdm.contrib.logging import tqdm_logging_redirect
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
logger = logging.getLogger(__name__)
logging.basicConfig(
level=logging.DEBUG, format="[%(asctime)s] %(levelname)s - %(message)s"
)
def load_tiktoken_bpe(tiktoken_bpe_file: str) -> "dict[bytes, int]":
contents = open(tiktoken_bpe_file, "rb").read()
return {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
def dump_tiktoken_bpe(bpe_ranks: "dict[bytes, int]", tiktoken_bpe_file: str) -> None:
with open(tiktoken_bpe_file, "wb") as f:
for token, rank in sorted(bpe_ranks.items(), key=lambda x: x[1]):
f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")
def bytes_to_pieces(the_bytes: bytes) -> "tuple[bytes]":
return tuple(bytes([byte]) for byte in the_bytes)
def get_pairs(pieces: "tuple[bytes]") -> "set[tuple[bytes, bytes]]":
return set(zip(pieces[:-1], pieces[1:]))
def get_stats(
vocab: "dict[tuple[bytes, ...], int]",
) -> "dict[tuple[bytes, bytes], int]":
pairs = collections.defaultdict(int)
for word, freq in vocab.items():
for i in range(len(word) - 1):
pairs[(word[i], word[i + 1])] += freq
return pairs
def merge_vocab(
pair: "tuple[bytes, bytes]", vocab: "dict[tuple[bytes, ...], int]"
) -> "dict[tuple[bytes, ...], int]":
return {apply_bp(pieces, pair): freq for pieces, freq in vocab.items()}
def apply_bp(
pieces: "tuple[bytes, ...]", pair: "tuple[bytes, bytes]"
) -> "tuple[bytes, ...]":
new_pieces = []
first, second = pair
i = 0
while i < len(pieces):
try:
j = pieces.index(first, i)
new_pieces.extend(pieces[i:j])
i = j
except:
new_pieces.extend(pieces[i:])
break
if pieces[i] == first and i < len(pieces) - 1 and pieces[i + 1] == second:
new_pieces.append(first + second)
i += 2
else:
new_pieces.append(pieces[i])
i += 1
return tuple(new_pieces)
def bpe(word: bytes, merges: "dict[bytes,int]") -> "tuple[bytes, ...]":
pieces = bytes_to_pieces(word)
while len(pieces) > 1:
pairs = get_pairs(pieces)
pair = min(pairs, key=lambda pair: merges.get(pair[0] + pair[1], float("inf")))
if pair[0] + pair[1] not in merges:
break
pieces = apply_bp(pieces, pair)
# logger.debug(f"{[(p, p.decode('utf8', errors='replace')) for p in pieces]} {pair} {pieces}")
return pieces
def best_pair_sort_key(
item: "tuple[dict[bytes, bytes], int]",
) -> "tuple[int, int, int, str, bytes]":
# prefer to use the highest frequency or shortest length or lexi sort, sligtly slower
pair, freq = item
pair_bytes = pair[0] + pair[1]
pair_byte_length = len(pair_bytes)
pair_str = pair_bytes.decode("utf-8", errors="replace")
pair_str_length = len(pair_str)
return -freq, pair_str_length, pair_byte_length, pair_str, pair_bytes
def learn_bpe(
freqs: "dict[str,int]", existing: "dict[bytes, int]"
) -> "tuple[bytes, bytes]":
vocab = {bpe(k.encode("utf-8"), existing): v for k, v in freqs.items()}
vocab = {key: value for key, value in vocab.items() if len(key) > 1}
new_merges = []
with tqdm_logging_redirect() as bar:
while vocab:
pairs = get_stats(vocab)
best, freq = min(pairs.items(), key=best_pair_sort_key)
logger.debug(
f'{best} ({(best[0]+best[1]).decode("utf-8", errors="replace")}) is selected as the next merge with freq {freq}'
)
new_merges.append(best)
vocab = merge_vocab(best, vocab)
vocab = {key: value for key, value in vocab.items() if len(key) > 1}
bar.update()
return new_merges
def load_expand_vocab(path: Path) -> "dict[str, int]":
freqs = {}
with open(path, "r", encoding="utf8") as fin:
for line in fin:
if not line.strip():
continue
word, freq = line.strip().split("\t")
word = unicodedata.normalize("NFC", word)
parts = re.findall(PAT_STR, word)
if len(parts) > 1:
logger.warning(
f"{word} would be pre-tokenized to {parts}, and thus cannot be added to vocabulary"
)
continue
try:
freq = int(freq)
except ValueError as _:
freq = 1
if word in freqs:
logger.warning(
f"{word} is repeated, the frequency is increased by this much"
)
freqs[word] += freq
else:
freqs[word] = freq
return freqs
def make_new_merges_by_bpe(
input_path: Path, output_path: Path, expand_path: Path, start_id: int
) -> None:
mergeable_ranks = load_tiktoken_bpe(input_path)
if not start_id or start_id == -1:
start_id = len(mergeable_ranks)
elif start_id < len(mergeable_ranks):
logger.warning(
f"start_id {start_id} is too small, existing merges will be overridden, DONOT DO THIS. changed to {len(mergeable_ranks)}"
)
start_id = len(mergeable_ranks)
else:
start_id = start_id
expand_vocab_freqs = load_expand_vocab(expand_path)
for word in list(expand_vocab_freqs):
token = word.encode("utf-8")
if token in mergeable_ranks:
logger.warning(f"word {word} is already a token {token}, skipping")
del expand_vocab_freqs[word]
logger.info(f"number of existing merges: {len(mergeable_ranks)}")
logger.info(f"number of words for expanding: {len(expand_vocab_freqs)}")
new_merges = learn_bpe(expand_vocab_freqs, mergeable_ranks)
logger.info(f"number of newly learned merges: {len(new_merges)}")
extra_merges = {p[0] + p[1]: i for i, p in enumerate(new_merges, start=start_id)}
dump_tiktoken_bpe(extra_merges, output_path)
def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("input_path", type=str, help="Path for input tiktoken file")
parser.add_argument(
"output_path",
type=str,
help="Path for output tiktoken file, containing only the new merges",
)
parser.add_argument(
"vocab_path",
type=str,
help="Path for words needed adding, each line is a word and its frequency separated by \\t",
)
# if the extended vocabulary is for fine-tuning, you better set those correctly (the default is for qwen.tiktoken)
# if the extended vocabulary is for pretraining from the start, no need
parser.add_argument(
"--start_id",
type=int,
default=151851,
help="The start id for new merges. For Qwen tokenizer, this should be 151851 (skipping the existing special tokens)",
)
args = parser.parse_args()
make_new_merges_by_bpe(
args.input_path, args.output_path, args.vocab_path, args.start_id
)
if __name__ == "__main__":
main()

@ -0,0 +1,6 @@
5LiA5Y+q54yr 151851
5Y+q54yr 151852
5piv5LiA5Y+q54yr 151853
5oiR5piv5LiA5Y+q54yr 151854
5L2g5piv5LiA5Y+q54yr 151855
5LuW5piv5LiA5Y+q54yr 151856

@ -0,0 +1,6 @@
我是一只猫 20
你是一只猫 10
他是一只猫 5
一只 200
一只猫 100
夸张的 比喻手法 20

@ -9,8 +9,6 @@
"name": "stderr",
"output_type": "stream",
"text": [
"TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
]
}
@ -414,6 +412,142 @@
"source": [
"tokenizer._convert_token_to_id('<|extra_204|>')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Vocabulary Expansion"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [35946, 99639, 91680, 100472], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer(\"我是一只猫\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[99639, 91680, 100472]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.encode(\"是一只猫\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-7B', trust_remote_code=True, extra_vocab_file=\"qwen_extra.tiktoken\")\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"151857"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [151854], 'token_type_ids': [0], 'attention_mask': [1]}"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer(\"我是一只猫\")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'我是一只猫'"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.decode(tokenizer.encode(\"我是一只猫\"))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[151853]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.encode(\"是一只猫\")"
]
}
],
"metadata": {
@ -432,7 +566,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.12"
},
"orig_nbformat": 4
},

@ -21,12 +21,21 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig
def _gc(forced: bool = False):
global args
if args.disable_gc and not forced:
return
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
yield
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
_gc(forced=True)
app = FastAPI(lifespan=lifespan)
@ -392,6 +401,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
**gen_kwargs
)
print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
_gc()
response = trim_stop_words(response, stop_words)
if request.functions:
choice_data = parse_response(response)
@ -453,6 +464,8 @@ async def predict(
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
yield "[DONE]"
_gc()
def _get_args():
parser = ArgumentParser()
@ -476,6 +489,8 @@ def _get_args():
help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
" If you want other computers to access your server, use 0.0.0.0 instead.",
)
parser.add_argument("--disable-gc", action="store_true",
help="Disable GC after each response generated.")
args = parser.parse_args()
return args

@ -125,3 +125,122 @@ The new default is the same as
{'input_ids': [1350, 445, 151643, 899], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}
```
## Vocabulary Expansion
> WARNING: Read carefully, be aware of what you are doing, and use at your own risk.
> There are certain caveats regarding how your vocabulary is produced.
The tokenizer of Qwen models are based on BPE and you cannot directly expand the vocabulary by adding words to the vocabulary.
The intermediate merges are needed for tokenization.
Please follow the steps to obtain such information.
1. Prepare a plain text file `qwen_extra_vocab.txt`, where each line contains a token and its frequency separated by `\t`.
An example is given below:
```
我是一只猫 20
你是一只猫 10
他是一只猫 5
一只 200
一只猫 100
夸张的 比喻手法 20
```
The frequencies are needed to compute the BPE.
2. Prepare the base vocabulary file, e.g., `qwen.tiktoken`, and determine the start index for new tokens.
There are 151,643 regular tokens and 208 control tokens in the vocabulary for Qwen models.
For simplicity, the start index can be set as 151,851, which is the default value.
You can, of course, override the many inactive control tokens, but you will need to modify the tokenizer code.
3. Run the following command:
```
python add_merges.py qwen.tiktoken qwen_extra.tiktoken qwen_extra_vocab.txt
```
`add_merges.py` can be found [here](examples/add_merges.py).
It will learn the new merges based on the provided `qwen_extra_vocab.txt`.
The new tokens and their indices will be stored in `qwen_extra.tiktoken`.
Modify the paths as you wish.
It is a pure Python implementation, so please expect it to be slow if you are adding a lot of words.
Please note that not all words can be added due to pre-tokenization.
You will get warnings if you try to add such word:
```
WARNING - 夸张的 比喻手法 would be pre-tokenized to ['夸张的', ' 比喻手法'], and thus cannot be added to vocabulary
WARNING - word 一只 is already a token b'\xe4\xb8\x80\xe5\x8f\xaa', skipping
INFO - number of existing merges: 151643
INFO - number of words for expanding: 4
DEBUG - (b'\xe4\xb8\x80\xe5\x8f\xaa', b'\xe7\x8c\xab') (一只猫) is selected as the next merge with freq 100
DEBUG - (b'\xe5\x8f\xaa', b'\xe7\x8c\xab') (只猫) is selected as the next merge with freq 35
DEBUG - (b'\xe6\x98\xaf\xe4\xb8\x80', b'\xe5\x8f\xaa\xe7\x8c\xab') (是一只猫) is selected as the next merge with freq 35
DEBUG - (b'\xe6\x88\x91', b'\xe6\x98\xaf\xe4\xb8\x80\xe5\x8f\xaa\xe7\x8c\xab') (我是一只猫) is selected as the next merge with freq 20
DEBUG - (b'\xe4\xbd\xa0', b'\xe6\x98\xaf\xe4\xb8\x80\xe5\x8f\xaa\xe7\x8c\xab') (你是一只猫) is selected as the next merge with freq 10
DEBUG - (b'\xe4\xbb\x96', b'\xe6\x98\xaf\xe4\xb8\x80\xe5\x8f\xaa\xe7\x8c\xab') (他是一只猫) is selected as the next merge with freq 5
INFO - number of newly learned merges: 6
```
The `qwen_extra.tiktoken` will contain the following lines:
```
5LiA5Y+q54yr 151851
5Y+q54yr 151852
5piv5LiA5Y+q54yr 151853
5oiR5piv5LiA5Y+q54yr 151854
5L2g5piv5LiA5Y+q54yr 151855
5LuW5piv5LiA5Y+q54yr 151856
```
You may use the file as follows in your code:
``` python
from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True, extra_vocab_file="qwen_extra.tiktoken")
>>> len(tokenizer)
151857
>>> tokenizer("我是一只猫")
{'input_ids': [151854], 'token_type_ids': [0], 'attention_mask': [1]}
```
Note: You need the latest tokenizer code, i.e., after 2013-10-08, to use the `extra_vocab_file` argument.
Otherwise, you need to manually append `qwen.tiktoken` (of which path varies with your configuration) with the content from `qwen_extra.tiktoken`.
Certainly, you will need to finetune the model for the new tokens to work.
### Caveats
The tokenizer of Qwen operates directly on UTF-8 byte sequences, unlike others, e.g., SentencePiece that operates on UTF-8 codepoints/characters and falls back to UTF-8 byte sequences for the unknown (IIRC).
The thing is if the frequencies are computed on limited data, the UTF-8 codepoint boundary may not be correctly recognized.
In theory, it could be a problem for fine-tuned models using the expanded vocabulary with limited data.
For example, it could happen that `b'\x80\xe5'` might be merged first for the UTF-8 byte sequence `b'\xe4\xb8\x80\xe5\x8f\xaa'` of the string `一只`, across the UTF-8 codepoint of `一`(`b'\xe4\xb8\x80'`) and `只` (`b'\xe5\x8f\xaa'`).
Normally, this would work just fine for known words, but for actually unknown words, unusual merges may happen, which may not be well understood for the pre-trained model.
Our advice is that to be safe, you should gather the UTF-8 codepoints from all the words you need to add, and also add them to the file with frequencies higher than the sum of the frequencies of the corresponding words.
But since Qwen has most of the Chinese words, it could be okay to just add the Chinese words alone.
For curious minds, you will also notice that in the given example, `一只` is a token and `只猫` is also learned as a new token.
The reason is that `是一` is also a token in Qwen and has higher merging priority than `一只`, such that the merging path for `是|一|只|猫` is `是一|只|猫 -> 是一|只猫 -> 是一只猫` (omitting the UTF-8 byte merges).
This is the characteristic for plain BPE: it is based solely on distribution, meaning it does not have knowledge of which bytes can form a valid UTF-8 codepoint, character, or meaningful word.
The byproduct is that text may be sub-tokenized differently in different contexts, even for words containing only ASCII characters.
```python
>>> tokenizer.tokenize("Panda")
[b'P', b'anda']
>>> tokenizer.tokenize(" Panda")
[b' Panda']
>>> tokenizer.tokenize("Pandas")
[b'P', b'andas']
>>> tokenizer.tokenize(" Pandas")
[b' Pand', b'as']
```
This simply suggests that those combinations occur more frequently in the data.
If you have vast amount of training data, it should not be a problem.

@ -107,6 +107,13 @@ def _parse_text(text):
return text
def _gc():
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _launch_demo(args, model, tokenizer, config):
def predict(_query, _chatbot, _task_history):
@ -138,9 +145,7 @@ def _launch_demo(args, model, tokenizer, config):
def reset_state(_chatbot, _task_history):
_task_history.clear()
_chatbot.clear()
import gc
gc.collect()
torch.cuda.empty_cache()
_gc()
return _chatbot
with gr.Blocks() as demo:

Loading…
Cancel
Save