From 9139fbdf99a425b97f37a728dcfe302c59d026d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=BC=E6=AC=A3?= Date: Tue, 8 Aug 2023 17:45:41 +0800 Subject: [PATCH] release the evaluation benchmark for tool use; update tool use results to that of the hf version --- LICENSE | 2 +- NOTICE | 2 +- README.md | 7 +- README_CN.md | 7 +- eval/EVALUATION.md | 21 ++- eval/evaluate_plugin.py | 308 +++++++++++++++++++++++++++++++++++++++ examples/react_prompt.md | 2 +- requirements.txt | 2 +- tech_memo.md | 4 +- 9 files changed, 339 insertions(+), 16 deletions(-) create mode 100644 eval/evaluate_plugin.py diff --git a/LICENSE b/LICENSE index 99a43f5..d69279e 100644 --- a/LICENSE +++ b/LICENSE @@ -50,4 +50,4 @@ If you are commercially using the Materials, and your product or service has mor 9. Governing Law and Jurisdiction. a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. - b. The People's Courts in Hangzhou City shall have exclusive jurisdiction over any dispute arising out of this Agreement. + b. The People's Courts in Hangzhou City shall have exclusive jurisdiction over any dispute arising out of this Agreement. \ No newline at end of file diff --git a/NOTICE b/NOTICE index ca737d3..22c063e 100644 --- a/NOTICE +++ b/NOTICE @@ -49,4 +49,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md index d4f82c4..a748d03 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ print(response) # 你好!很高兴为你提供帮助。 # 第二轮对话 2nd dialogue turn -response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=history) +response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=history) print(response) # 这是一个关于一个年轻人奋斗创业最终取得成功的故事。 # 故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。从小,李明就立下了一个目标:要成为一名成功的企业家。 @@ -237,14 +237,14 @@ We provide a CLI demo example in `cli_demo.py`, which supports streaming output ## Tool Usage -Qwen-7B-Chat is specifically optimized for tool usage, including API, database, models, etc., so that users can build their own Qwen-7B-based LangChain, Agent, and Code Interpreter. In the soon-to-be-released internal evaluation benchmark for assessing tool usage capabilities, we find that Qwen-7B reaches stable performance. +Qwen-7B-Chat is specifically optimized for tool usage, including API, database, models, etc., so that users can build their own Qwen-7B-based LangChain, Agent, and Code Interpreter. In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool usage capabilities, we find that Qwen-7B reaches stable performance. [](https://) | Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ | |-------------|------------------------|-----------------------|-----------------------| | GPT-4 | 95% | **0.90** | 15% | | GPT-3.5 | 85% | 0.88 | 75% | -| **Qwen-7B** | **99%** | 0.89 | **8.5%** | +| **Qwen-7B** | **99%** | 0.89 | **9.7%** | For how to write and use prompts for ReAct Prompting, please refer to [the ReAct examples](examples/react_prompt.md). The use of tools can enable the model to better perform tasks. @@ -293,4 +293,3 @@ Researchers and developers are free to use the codes and model weights of both Q ## Contact Us If you are interested to leave a message to either our research team or product team, feel free to send an email to qianwen_opensource@alibabacloud.com. - diff --git a/README_CN.md b/README_CN.md index 5f9201f..816561a 100644 --- a/README_CN.md +++ b/README_CN.md @@ -110,7 +110,7 @@ print(response) # 你好!很高兴为你提供帮助。 # 第二轮对话 2nd dialogue turn -response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=history) +response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=history) print(response) # 这是一个关于一个年轻人奋斗创业最终取得成功的故事。 # 故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。从小,李明就立下了一个目标:要成为一名成功的企业家。 @@ -241,13 +241,13 @@ model = AutoModelForCausalLM.from_pretrained( ## 工具调用 -Qwen-7B-Chat针对包括API、数据库、模型等工具在内的调用进行了优化。用户可以开发基于Qwen-7B的LangChain、Agent甚至Code Interpreter。我们在内部的即将开源的评测数据集上测试模型的工具调用能力,并发现Qwen-7B-Chat能够取得稳定的表现。 +Qwen-7B-Chat针对包括API、数据库、模型等工具在内的调用进行了优化。用户可以开发基于Qwen-7B的LangChain、Agent甚至Code Interpreter。在我们开源的[评测数据集](eval/EVALUATION.md)上测试模型的工具调用能力,并发现Qwen-7B-Chat能够取得稳定的表现。 | Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ | | ------------- | ------------------------- | ------------------------ | ------------------------ | | GPT-4 | 95% | **0.90** | 15% | | GPT-3.5 | 85% | 0.88 | 75% | -| **Qwen-7B** | **99%** | 0.89 | **8.5%** | +| **Qwen-7B** | **99%** | 0.89 | **9.7%** | 我们提供了文档说明如何根据ReAct Prompting的原则写作你的prompt。 @@ -298,4 +298,3 @@ For how to write and use prompts for ReAct Prompting, please refer to [the ReAct ## 联系我们 如果你想给我们的研发团队和产品团队留言,请通过邮件(qianwen_opensource@alibabacloud.com)联系我们。 - diff --git a/eval/EVALUATION.md b/eval/EVALUATION.md index 86dd3ed..44e0af6 100644 --- a/eval/EVALUATION.md +++ b/eval/EVALUATION.md @@ -49,9 +49,9 @@ evaluate_functional_correctness HumanEval_res.jsonl python evaluate_chat_mmlu.py -f HumanEval.jsonl -o HumanEval_res_chat.jsonl evaluate_functional_correctness HumanEval_res_chat.jsonl ``` - + When installing package human-eval, please note its following disclaimer: - + This program exists to run untrusted model-generated code. Users are strongly encouraged not to do so outside of a robust security sandbox. The execution call in execution.py is deliberately commented out to ensure users read this disclaimer before running code in a potentially unsafe manner. See the comment in execution.py for more information and instructions. - GSM8K @@ -64,3 +64,20 @@ python evaluate_gsm8k.py python evaluate_chat_gsm8k.py # zeroshot python evaluate_chat_gsm8k.py --use-fewshot # fewshot ``` + +- PLUGIN + +This script is used to reproduce the results of the ReAct and Hugging Face Agent in the Tool Usage section of the README document. + +```Shell +# Qwen-7B-Chat +mkdir data; +cd data; +wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/opensource_data/exam_plugin_v1/exam_plugin_v1_react_positive.jsonl; +wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/opensource_data/exam_plugin_v1/exam_plugin_v1_react_negative.jsonl; +cd ..; +pip install json5; +pip install jsonlines; +pip install rouge_score; +python evaluate_plugin.py --eval-react-positive --eval-react-negative --eval-hfagent +``` diff --git a/eval/evaluate_plugin.py b/eval/evaluate_plugin.py new file mode 100644 index 0000000..89974ad --- /dev/null +++ b/eval/evaluate_plugin.py @@ -0,0 +1,308 @@ +import argparse +import json +import os +import pprint + +import json5 +import jsonlines +from rouge_score import rouge_scorer +from tqdm import tqdm +from transformers import Agent, AutoModelForCausalLM, AutoTokenizer +from transformers.generation import GenerationConfig +from transformers.tools.evaluate_agent import evaluate_agent +from transformers.trainer_utils import set_seed + +data_root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'data') + + +def is_callable(response, golden): + return response['action'].strip().lower() == golden['action'].strip( + ).lower() + + +def process_res(response): + # parse response + response += '\n' # fix not-find bug + thought = response[:response.find('Action:')].strip() + action = response[response.find('Action:') + + len('Action:'):response.find('Action Input:')].strip() + action_input = response[response.find('Action Input:') + + len('Action Input:'):response.find('Observation:' + )].strip() + #TODO: This parsing result is incorrect if the response contains multiple Actions. To be fixed in the future. + observation = response[response.find('Observation:') + + len('Observation:'):response.rfind('Thought:' + )].strip() + thought_last = response[response.rfind('Thought:') + + len('Thought:'):response.find('Final Answer:' + )].strip() + final_answer = response[response.find('Final Answer:') + + len('Final Answer:'):].strip() + try: + action_input = json.dumps(json5.loads(action_input), + ensure_ascii=False, + sort_keys=True) + except: + # print("JSON Load Error:", action_input) + pass + res_dict = { + 'thought': thought, + 'action': action, + 'action_input': action_input, + 'observation': observation, + 'thought_last': thought_last, + 'final_answer': final_answer + } + return res_dict + + +class _DummyTokenizer: + def tokenize(self, text: str): + return text.split() + + +def _get_tokenized_string(tokenizer, text_list): + token_ids_list, tokenized_string_list = [], [] + for text in text_list: + assert tokenizer is not None + token_ids = tokenizer.encode(text) + tokens_bytes = tokenizer.convert_ids_to_tokens(token_ids) + tokens = [ + token.decode('utf-8', errors='replace') for token in tokens_bytes + ] + tokenized_string = ' '.join(tokens) + token_ids_list.append(token_ids) + tokenized_string_list.append(tokenized_string) + return token_ids_list, tokenized_string_list + + +def eval_action(job): + response = job['gen'][0] + golden = job['response'] + + if 'Action:' in response: + response, golden = process_res(response), process_res(golden) + if is_callable(response, golden): + return True + return False + + +def eval_action_input(job, tokenizer): + response = job['gen'][0] + golden = job['response'] + response, golden = process_res(response), process_res(golden) + query = job['prompt'] + + job = {} + job['prompt'] = query + job['gen'] = response['action_input'] + job['response'] = golden['action_input'] + + job['_gen_tok'], job['_gen_tok_str'] = _get_tokenized_string( + tokenizer, [response['action_input']]) + job['_reference_tok'], job['_reference_tok_str'] = _get_tokenized_string( + tokenizer, [golden['action_input']]) + + scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], + tokenizer=_DummyTokenizer()) + score = scorer.score(job['_reference_tok_str'][0], job['_gen_tok_str'][0]) + + rouge = score['rougeL'].fmeasure + + return rouge + + +class QWenAgent(Agent): + """ + Agent that uses QWen model and tokenizer to generate code. + + Example: + + ```py + agent = QWenAgent() + agent.run("Draw me a picture of rivers and lakes.") + ``` + """ + def __init__(self, + chat_prompt_template=None, + run_prompt_template=None, + additional_tools=None, + tokenizer=None, + model=None): + if tokenizer and model: + self.tokenizer = tokenizer + self.model = model + else: + checkpoint = 'Qwen/Qwen-7B-Chat' + self.tokenizer = AutoTokenizer.from_pretrained( + checkpoint, trust_remote_code=True) + self.model = AutoModelForCausalLM.from_pretrained( + checkpoint, device_map='auto', + trust_remote_code=True).cuda().eval() + self.model.generation_config = GenerationConfig.from_pretrained( + checkpoint, trust_remote_code=True) # 可指定不同的生成长度、top_p等相关超参 + self.model.generation_config.do_sample = False # greedy + + super().__init__( + chat_prompt_template=chat_prompt_template, + run_prompt_template=run_prompt_template, + additional_tools=additional_tools, + ) + + def generate_one(self, prompt, stop): + # "Human:" 和 "Assistant:" 曾为通义千问的特殊保留字,需要替换为 "_HUMAN_:" 和 "_ASSISTANT_:"。这一问题将在未来版本修复。 + prompt = prompt.replace('Human:', + '_HUMAN_:').replace('Assistant:', + '_ASSISTANT_:') + stop = [ + item.replace('Human:', '_HUMAN_:').replace('Assistant:', + '_ASSISTANT_:') + for item in stop + ] + + result, _ = self.model.chat(self.tokenizer, prompt, history=None) + for stop_seq in stop: + if result.endswith(stop_seq): + result = result[:-len(stop_seq)] + + result = result.replace('_HUMAN_:', + 'Human:').replace('_ASSISTANT_:', 'Assistant:') + return result + + +def load_models_tokenizer(args): + tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, + trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, + device_map='auto', + trust_remote_code=True, + bf16=True, + use_flash_attn=True).eval() + model.generation_config = GenerationConfig.from_pretrained( + args.checkpoint_path, trust_remote_code=True) + model.generation_config.do_sample = False # use greedy decoding + return model, tokenizer + + +def load_jobs(filename): + jobs = [] + with jsonlines.open(os.path.join(data_root_path, filename), + mode='r') as reader: + for job in reader: + jobs.append(job) + return jobs + + +def react_inference(filename, model, tokenizer): + filename_cache = filename + '.cache' + if os.path.exists(os.path.join(data_root_path, filename_cache)): + jobs = load_jobs(filename=filename_cache) + print('Loaded from', filename_cache) + else: + with open(os.path.join(data_root_path, filename_cache), 'w') as f: + jobs = load_jobs(filename=filename) + print('Inference:', filename) + for job in tqdm(jobs): + response, history = model.chat(tokenizer, + job['prompt'], + history=None) + job['gen'] = [response] + f.writelines(json.dumps(job, ensure_ascii=False) + '\n') + print(filename_cache, 'is saved.') + return jobs + + +def main(args): + print('loading model weights') + if args.checkpoint_path is not None: + model, tokenizer = load_models_tokenizer(args) + else: + model, tokenizer = None, None + print('model loaded') + + result = {} + # eval react positive + if args.eval_react_positive: + print('eval react positive ...') + acc_count = 0 + rouge_mean = 0 + jobs = react_inference(filename=args.eval_react_positive_filename, + model=model, + tokenizer=tokenizer) + for job in jobs: + if eval_action(job): + acc_count += 1 + rouge = eval_action_input(job, tokenizer) + rouge_mean += (rouge / len(jobs)) + + scores = { + 'action_right_rate': acc_count / len(jobs), + 'action_input_rouge': rouge_mean, + } + + result.update({'react_positive': scores}) + + # eval react negative + if args.eval_react_negative: + print('eval react negative ...') + bad_count = 0 + jobs = react_inference(filename=args.eval_react_negative_filename, + model=model, + tokenizer=tokenizer) + for job in jobs: + if '\nAction:' in job['gen'][0]: + bad_count += 1 + scores = {'bad_rate': bad_count / len(jobs)} + result.update({'react_negative': scores}) + + # eval hfagent + if args.eval_hfagent: + print('eval hfagent ...') + agent = QWenAgent(model=model, tokenizer=tokenizer) + scores = evaluate_agent(agent, verbose=False, return_errors=False) + result.update({'hfagent': scores}) + + pp = pprint.PrettyPrinter(indent=4) + pp.pprint(result) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Test HF checkpoint.') + parser.add_argument('-c', + '--checkpoint-path', + type=str, + help='Checkpoint path', + default='Qwen/Qwen-7B-Chat') + parser.add_argument('-s', + '--seed', + type=int, + default=1234, + help='Random seed') + """Provide extra arguments required for tasks.""" + group = parser.add_argument_group(title='Evaluation options') + group.add_argument('--eval-react-positive', + action='store_true', + default=False, + help='Eval react positive.') + group.add_argument('--eval-react-positive-filename', + type=str, + default='exam_plugin_v1_react_positive.jsonl', + help='Eval react positive filename.') + group.add_argument('--eval-react-negative', + action='store_true', + default=False, + help='Eval react negative.') + group.add_argument('--eval-react-negative-filename', + type=str, + default='exam_plugin_v1_react_negative.jsonl', + help='Eval react negative filename.') + group.add_argument('--eval-hfagent', + action='store_true', + default=False, + help='Eval hfagent.') + + args = parser.parse_args() + set_seed(args.seed) + + main(args) diff --git a/examples/react_prompt.md b/examples/react_prompt.md index 3643171..80e4722 100644 --- a/examples/react_prompt.md +++ b/examples/react_prompt.md @@ -242,4 +242,4 @@ def parse_latest_plugin_call(text: str) -> Tuple[str, str]: return '', '' ``` -此外,如果输出的 Action Input 内容是一段表示 JSON 对象的文本,我们建议使用 `json5` 包的 `json5.loads(...)` 方法加载。 +此外,如果输出的 Action Input 内容是一段表示 JSON 对象的文本,我们建议使用 `json5` 包的 `json5.loads(...)` 方法加载。 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f0e4a66..11ddb14 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ tiktoken einops transformers_stream_generator==0.0.4 bitsandbytes -scipy +scipy \ No newline at end of file diff --git a/tech_memo.md b/tech_memo.md index b10bc14..8c8d733 100644 --- a/tech_memo.md +++ b/tech_memo.md @@ -311,13 +311,13 @@ LLMs have shown capability in coordinating multiple external systems to achieve Qwen supports calling plugins/tools/APIs through [ReAct Prompting](https://arxiv.org/abs/2210.03629). ReAct is also one of the main approaches used by the [LangChain](https://python.langchain.com/) framework. For how to write and use prompts for ReAct Prompting, please refer to [the ReAct examples](examples/react_prompt.md). -In the soon-to-be-released evaluation benchmark for assessing tool usage capabilities, Qwen's performance is as follows: +In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool usage capabilities, Qwen's performance is as follows: | Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ | | :---------- | --------------------------: | -------------------------: | -------------------------: | | GPT-4 | 95% | **0.90** | 15.0% | | GPT-3.5 | 85% | 0.88 | 75.0% | -| **Qwen-7B** | **99%** | 0.89 | **8.5%** | +| **Qwen-7B** | **99%** | 0.89 | **9.7%** | > The plugins that appear in the evaluation set do not appear in the training set of Qwen. > This benchmark evaluates the accuracy of the model in selecting the correct plugin from multiple candidate plugins, the rationality of the parameters passed into the plugin, and the false positive rate.