You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

155 lines
6.7 KiB
Python

from fastllm_pytools import llm;
import torch;
import ctypes;
import numpy as np;
# Maps the ``dtype`` argument accepted by ``create`` to the integer data-type
# code handed to the fastllm native library for linear-layer weights.
fastllm_data_type_dict = {
    "int4": 8,
    "int8": 3,
    "float16": 7,
}

# Maps the module kind assigned while scanning ``model.named_modules()`` to the
# integer weight-type code handed to the fastllm native library.
fastllm_weight_type_dict = {
    "linear": 1,
    "embedding": 2,
    "QuantizedLinear": 111,
}


def _build_model_info(model, tokenizer, pre_prompt, user_role, bot_role, history_sep):
    """Return a new dict of metadata describing ``model`` for the native library."""
    config = model.config
    model_type = config.model_type
    # Heuristic: a "chatglm" config written by transformers 4.30.2 is
    # registered as "chatglm3".
    if model_type == "chatglm" and config.transformers_version == "4.30.2":
        model_type = "chatglm3"

    # Copy rather than alias ``config.__dict__``: every write below would
    # otherwise become an attribute of the caller's config object.
    model_info = dict(config.__dict__)
    model_info["model_type"] = model_type

    generation_config = getattr(model, "generation_config", None)
    if generation_config is not None:
        model_info.update(generation_config.__dict__)

    # Caller-supplied prompt-template overrides; falsy values are ignored.
    overrides = {
        "pre_prompt": pre_prompt,
        "user_role": user_role,
        "bot_role": bot_role,
        "history_sep": history_sep,
    }
    model_info.update({key: value for key, value in overrides.items() if value})

    if (model_info["model_type"] == "baichuan"
            and hasattr(model, "model")
            and hasattr(model.model, "get_alibi_mask")):
        # Baichuan 2 (recognised by ``model.model.get_alibi_mask``). These
        # template values replace any overrides passed by the caller.
        model_info["use_alibi"] = "1"
        model_info["pre_prompt"] = ""
        model_info["user_role"] = (
            "<FLM_FIX_TOKEN_" + str(generation_config.user_token_id) + "> "
            if hasattr(generation_config, "user_token_id") else "")
        model_info["bot_role"] = (
            "<FLM_FIX_TOKEN_" + str(generation_config.assistant_token_id) + ">"
            if hasattr(generation_config, "assistant_token_id") else "")
        model_info["history_sep"] = ""

    if model_info["model_type"] == "qwen" and model_info.get("chat_format") == "chatml":
        # Record the tokenizer's chatml special-token ids (requires ``tokenizer``).
        model_info["im_end_id"] = tokenizer.im_end_id
        model_info["im_start_id"] = tokenizer.im_start_id
    return model_info


def _collect_weight_types(model):
    """Label weights by the kind of module that owns them; also collect quantized bit widths."""
    weight_types = {}
    weight_bits = {}
    for name, module in model.named_modules():
        if "QuantizedLinear" in str(type(module)):
            weight_types[name + ".weight"] = "QuantizedLinear"
            weight_bits[name + ".weight"] = module.weight_bit_width
        # NOTE(review): if a QuantizedLinear class subclasses torch.nn.Linear,
        # this check relabels its weight as "linear" -- confirm that is intended.
        if isinstance(module, torch.nn.Linear):
            weight_types[name + ".weight"] = "linear"
        if isinstance(module, torch.nn.Embedding):
            # NOTE(review): keyed by the module name, not "<name>.weight", so
            # it never matches a state_dict key and embeddings are sent with
            # weight type 0. Left unchanged; confirm against the native side.
            weight_types[name] = "embedding"
    return weight_types, weight_bits


def _add_tokenizer(handle, tokenizer, model_type):
    """Copy the vocabulary of ``tokenizer`` (if any) into the native model ``handle``."""
    if not tokenizer:
        return
    # Unwrap tokenizers that keep the underlying one in a ``tokenizer``
    # attribute; for Qwen the outer object is used as-is.
    if hasattr(tokenizer, "tokenizer") and model_type != "qwen":
        tokenizer = tokenizer.tokenizer
    add_word = llm.fastllm_lib.add_tokenizer_word_llm_model

    if hasattr(tokenizer, "sp_model"):
        # SentencePiece: register every piece with its id and score.
        sp_model = tokenizer.sp_model
        for piece_id in range(sp_model.piece_size()):
            add_word(handle, sp_model.id_to_piece(piece_id).encode(), piece_id,
                     ctypes.c_float(sp_model.get_score(piece_id)))
        return

    # Otherwise register every ``get_vocab()`` entry with score 1.0.
    byte_decoder = tokenizer.byte_decoder if model_type == "moss" else None
    for token, token_id in tokenizer.get_vocab().items():
        if model_type == "moss":
            # Translate each character through ``byte_decoder`` (ord() fallback).
            # NOTE(review): this passes a list of ints where the other
            # branches pass bytes -- confirm the native binding accepts it.
            word = [byte_decoder.get(ch, ord(ch)) for ch in token]
        elif model_type == "qwen":
            # Passed through unconverted; presumably already bytes -- confirm.
            word = token
        else:
            word = token.encode()
        add_word(handle, word, token_id, ctypes.c_float(1.0))


def _to_numpy(tensor, np_dtype=None):
    """Return CPU ``tensor`` as a C-contiguous numpy array, optionally cast to ``np_dtype``."""
    # numpy has no bfloat16, so Tensor.numpy() rejects it; widen first.
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    array = tensor.numpy()
    if np_dtype is not None:
        array = array.astype(np_dtype)
    # Only a raw data pointer (plus the shape) reaches the native side, so the
    # buffer must be dense row-major whatever the tensor's strides were.
    return np.ascontiguousarray(array)


def _add_weights(handle, state_dict, weight_types, weight_bits, dtype, strip_peft_prefix):
    """Copy every tensor of ``state_dict`` into ``handle``, printing progress."""
    linear_type = fastllm_weight_type_dict["linear"]
    qlinear_type = fastllm_weight_type_dict["QuantizedLinear"]
    float16_type = fastllm_data_type_dict["float16"]
    total = len(state_dict)
    for index, (key, tensor) in enumerate(state_dict.items(), start=1):
        weight_type = fastllm_weight_type_dict.get(weight_types.get(key), 0)
        # Target storage type: only linear weights honour ``dtype``; everything
        # else (including embeddings -- TODO bfloat) uses code 0.
        to_data_type = fastllm_data_type_dict[dtype] if weight_type == linear_type else 0
        # The buffer handed over is float16 (code 7) when the target is
        # float16, otherwise float32 (code 0).
        if to_data_type == float16_type:
            ori_data_type, ori_np_data_type = float16_type, np.float16
        else:
            ori_data_type, ori_np_data_type = 0, np.float32

        name = key.replace("base_model.model.", "") if strip_peft_prefix else key
        shape = list(tensor.shape)
        c_shape = (ctypes.c_int * len(shape))(*shape)
        # Arrays are bound to locals so they stay alive for the native call.
        if weight_type == qlinear_type:
            data = _to_numpy(tensor)
            scale = _to_numpy(state_dict[key + "_scale"], np.float32)
            llm.fastllm_lib.add_qlinear_weight_llm_model(
                handle, name.encode(), len(shape), c_shape, weight_bits[key],
                scale.ctypes.data_as(ctypes.c_void_p),
                data.ctypes.data_as(ctypes.c_void_p))
        else:
            data = _to_numpy(tensor, ori_np_data_type)
            llm.fastllm_lib.add_weight_llm_model(
                handle, name.encode(), len(shape), c_shape,
                to_data_type, weight_type, ori_data_type,
                data.ctypes.data_as(ctypes.c_void_p))
        print("convert (", index, "/", total, end=" )\r")
    print("")


def create(model,
           tokenizer=None,
           pre_prompt=None,
           user_role=None,
           bot_role=None,
           history_sep=None,
           dtype="float16"):
    """Convert an in-memory PyTorch/HuggingFace model into a fastllm model.

    The model configuration (merged with its generation config), any PEFT
    adapter configs, the tokenizer vocabulary and every tensor of
    ``model.state_dict()`` are copied into a new native fastllm model. That
    model is then initialised, warmed up and wrapped in ``llm.model``.

    Args:
        model: Model exposing ``config``, ``generation_config``,
            ``named_modules()``, ``cpu()`` and ``state_dict()``. ``model.cpu()``
            is called on it; for a ``torch.nn.Module`` that moves it in place.
        tokenizer: Optional tokenizer whose vocabulary is copied over. Qwen
            models with ``chat_format == "chatml"`` require it.
        pre_prompt, user_role, bot_role, history_sep: Optional prompt-template
            overrides; falsy values are ignored. Baichuan 2 models overwrite them.
        dtype: Storage type for linear-layer weights; a key of
            ``fastllm_data_type_dict``.

    Returns:
        The ``llm.model`` wrapper around the converted native model.

    Raises:
        ValueError: If ``dtype`` is not a key of ``fastllm_data_type_dict``.
    """
    if dtype not in fastllm_data_type_dict:
        # Raise rather than exit: a library call must not terminate the host
        # process (and exit(0) would even report success).
        raise ValueError(
            f"dtype should be one of {list(fastllm_data_type_dict)}, got {dtype!r}")

    model_info = _build_model_info(model, tokenizer, pre_prompt, user_role,
                                   bot_role, history_sep)
    weight_types, weight_bits = _collect_weight_types(model)

    peft_config = getattr(model, "peft_config", None) or {}
    active_adapter = getattr(model, "active_adapter", "")
    if not isinstance(active_adapter, str):
        # In transformers >= 4.33.0 ``active_adapter`` is a method; ignore it.
        active_adapter = ""

    model = model.cpu()
    state_dict = model.state_dict()

    model_type = model_info["model_type"]
    handle = llm.fastllm_lib.create_empty_llm_model(model_type.encode())
    for key, value in model_info.items():
        llm.fastllm_lib.add_dict_llm_model(handle, str(key).encode(), str(value).encode())
    for adapter_name, adapter_config in peft_config.items():
        for key, value in adapter_config.__dict__.items():
            llm.fastllm_lib.add_adapter_dict_llm_model(
                handle, str(adapter_name).encode(), str(key).encode(), str(value).encode())
    if active_adapter:
        llm.fastllm_lib.set_adapter(handle, str(active_adapter).encode())

    _add_tokenizer(handle, tokenizer, model_type)
    # ``peft_config`` is {} for non-PEFT models, so test truthiness (an
    # ``is not None`` test would always strip the prefix).
    _add_weights(handle, state_dict, weight_types, weight_bits, dtype,
                 strip_peft_prefix=bool(peft_config))

    llm.fastllm_lib.init_params_llm_model(handle)
    llm.fastllm_lib.warmup_llm_model(handle)
    return llm.model("", id=handle)