You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

70 lines
2.5 KiB
Python

1 year ago
import random
import tqdm
import os
import sys
import torch
import jsonlines
import argparse
import jsonlines
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
"""
Score the generated samples with OpenAI's human-eval harness:
$ git clone https://github.com/openai/human-eval
$ pip install -e human-eval
$ evaluate_functional_correctness <sample-output-file>
"""
def decode(tokens_list, tokenizer, raw_text_len):
    """Decode generated token sequences into truncated HumanEval completions.

    Args:
        tokens_list: iterable of 1-D token-id tensors (e.g. the rows returned
            by ``model.generate``); each row is expected to start with the
            ``raw_text_len`` prompt tokens.
        tokenizer: object exposing ``.tokenizer.decode(list_of_ids) -> str``
            (presumably the raw encoder wrapped by the Qwen tokenizer --
            confirm for other checkpoints).
        raw_text_len: number of leading prompt tokens to drop from each row.

    Returns:
        list[str]: one completion per row, cut at the earliest occurrence of
        ``<|endoftext|>``, a double newline, or ``def ``.
    """
    stop_markers = ('<|endoftext|>', '\n\n', 'def ')
    sents = []
    for tokens in tokens_list:
        # Tensor.tolist() already yields plain Python ints; no NumPy hop needed.
        token_ids = tokens.tolist()
        sent = tokenizer.tokenizer.decode(token_ids[raw_text_len:])
        # Keep only the text before each stop marker. A separate "\n\n\n"
        # split is unnecessary: any triple newline contains "\n\n", so the
        # double-newline cut already lands at or before it.
        for marker in stop_markers:
            sent = sent.split(marker)[0]
        sents.append(sent)
    return sents
def generate_sample(model, tokenizer, input_txt):
    """Run the model on a single prompt and return the decoded completion.

    The prompt is tokenized with ``tokenizer.tokenizer.encode`` and passed to
    ``model.generate`` using whatever generation config the model carries.
    The prompt tokens are stripped and the continuation is converted to text
    by ``decode``. Both the prompt and the resulting completion are echoed
    to stdout.
    """
    prompt_ids = tokenizer.tokenizer.encode(input_txt)
    prompt_len = len(prompt_ids)
    # A batch of one: shape (1, prompt_len), placed on the model's device.
    batch = torch.tensor([prompt_ids]).to(model.device)
    print(f"Input text: {input_txt}\n")
    generated = model.generate(batch)
    completion, *_ = decode(generated, tokenizer, prompt_len)
    print(f"\nOutput text: \n{completion}\n")
    return completion
if __name__ == '__main__':
    # Generate one greedy completion per HumanEval task and write them as
    # {"task_id", "completion"} JSON lines for human-eval scoring.
    parser = argparse.ArgumentParser(description='Test HF checkpoint.')
    parser.add_argument("-c", "--checkpoint-path", type=str, help='Checkpoint path', default="Qwen/Qwen-7B")
    parser.add_argument("-f", "--sample-input-file", type=str, default=None, help="data path to HumanEval.jsonl")
    parser.add_argument("-o", "--sample-output-file", type=str, default="HumanEval_res.jsonl")
    args = parser.parse_args()
    # Fail fast with a usage message instead of crashing in jsonlines.open(None)
    # only after the (slow) model load has finished.
    if args.sample_input_file is None:
        parser.error("-f/--sample-input-file is required (path to HumanEval.jsonl)")

    print('Loading tokenizer ...')
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    print('Loading model ...')
    model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
    model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    # Disable sampling so the completions are deterministic.
    model.generation_config.do_sample = False

    # Files opened through jsonlines.open() are owned by the reader/writer,
    # so leaving each `with` closes (and flushes) them. jsonlines.Writer does
    # not close a file object it was handed. Opening the input first also
    # avoids truncating the output file when the input path is bad.
    with jsonlines.open(args.sample_input_file) as reader:
        with jsonlines.open(args.sample_output_file, mode='w') as writer:
            for jobj in tqdm.tqdm(reader, desc='task_idx'):
                prompt = jobj['prompt']
                task_id = jobj['task_id']
                completion = generate_sample(model, tokenizer, prompt)
                writer.write({'task_id': task_id, "completion": completion})