# Repository page notice (not part of the script): "You cannot select more than
# 25 topics. Topics must start with a letter or number, can include dashes
# ('-') and can be up to 35 characters long."
#
# File listing: 86 lines, 2.6 KiB, Python.

import argparse
# 1 year ago  (blame timestamp left over from the web listing; not code)
import tqdm
import torch
import jsonlines
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
"""
Set up the OpenAI HumanEval harness and score the generated samples:

    $ git clone https://github.com/openai/human-eval
    $ pip install -e human-eval
    $ evaluate_functional_correctness sample-output-file
"""
# 1 year ago  (blame timestamp left over from the web listing; not code)
def decode(tokens_list, tokenizer, raw_text_len):
    """Decode generated token sequences and trim each one at the first stop marker.

    Args:
        tokens_list: Iterable of token-id tensors; each element must support
            ``.cpu().numpy().tolist()``.
        tokenizer: Object exposing ``.tokenizer.decode(ids)``, which turns a
            list of token ids into text.
        raw_text_len: Number of leading tokens to drop from every sequence
            before decoding (the caller passes the prompt length).

    Returns:
        list[str]: One string per input sequence, truncated at the earliest
        occurrence of ``"<|endoftext|>"``, a blank line (``"\\n\\n"``) or
        ``"def "``.
    """
    # Applied in order; each split keeps only the text before the marker.
    # "\n\n" also covers longer runs of newlines such as "\n\n\n".
    stop_markers = ("<|endoftext|>", "\n\n", "def ")

    sents = []
    for tokens in tokens_list:
        token_ids = tokens.cpu().numpy().tolist()
        # Skip the prompt tokens so only the newly generated text is decoded.
        sent = tokenizer.tokenizer.decode(token_ids[raw_text_len:])
        for marker in stop_markers:
            sent = sent.split(marker)[0]
        sents.append(sent)
    return sents
# 1 year ago  (blame timestamp left over from the web listing; not code)
def generate_sample(model, tokenizer, input_txt):
    """Generate a completion for one prompt and return the trimmed text.

    The prompt and the resulting completion are both printed to stdout.

    Args:
        model: Object exposing ``.device`` and ``.generate(input_ids)``.
        tokenizer: Object exposing ``.tokenizer.encode(text)``; it is also
            handed to ``decode`` for the reverse direction.
        input_txt: Prompt text to complete.

    Returns:
        str: The first decoded sequence, with the prompt tokens removed and
        the text trimmed by ``decode``.
    """
    input_ids = tokenizer.tokenizer.encode(input_txt)
    # Remember the prompt length so decode() can strip the echoed prompt.
    raw_text_len = len(input_ids)
    # Wrap in a list to add the batch dimension (batch of one).
    context_enc = torch.tensor([input_ids]).to(model.device)

    print(f"Input text: {input_txt}\n")
    outputs = model.generate(context_enc)
    output_text = decode(outputs, tokenizer, raw_text_len)[0]
    print(f"\nOutput text: \n{output_text}\n")
    return output_text
if __name__ == "__main__":
    # Script entry point: read HumanEval prompts from a JSON Lines file,
    # generate one completion per task, and write
    # {"task_id": ..., "completion": ...} records to the output file.
    parser = argparse.ArgumentParser(description="Test HF checkpoint.")
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        help="Checkpoint path",
        default="Qwen/Qwen-7B",
    )
    parser.add_argument(
        "-f",
        "--sample-input-file",
        type=str,
        default=None,
        help="data path to HumanEval.jsonl",
    )
    parser.add_argument(
        "-o", "--sample-output-file", type=str, default="HumanEval_res.jsonl"
    )
    args = parser.parse_args()

    # Fail fast with a usage message instead of crashing on open(None)
    # after the model has already been loaded.
    if args.sample_input_file is None:
        parser.error("-f/--sample-input-file is required (path to HumanEval.jsonl)")

    print("Loading tokenizer ...")
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )

    print("Loading model ...")
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path, device_map="auto", trust_remote_code=True
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )
    # Turn sampling off so generation does not draw random samples.
    model.generation_config.do_sample = False

    # jsonlines.open owns the underlying file handles, so leaving the
    # `with` block closes both the reader and the writer, even on error.
    # The input is opened first so a bad input path cannot truncate an
    # existing output file.
    with jsonlines.open(args.sample_input_file) as reader, jsonlines.open(
        args.sample_output_file, mode="w"
    ) as writer:
        for jobj in tqdm.tqdm(reader, desc="task_idx"):
            prompt = jobj["prompt"]
            task_id = jobj["task_id"]
            gen_sents = generate_sample(model, tokenizer, prompt)
            gen_jobjs = {"task_id": task_id, "completion": gen_sents}
            writer.write(gen_jobjs)