release the evaluation benchmark for tool use; update tool use results to that of the hf version

main
兼欣 1 year ago
parent fa33db2a26
commit 9139fbdf99

@ -237,14 +237,14 @@ We provide a CLI demo example in `cli_demo.py`, which supports streaming output
## Tool Usage
Qwen-7B-Chat is specifically optimized for tool usage, including API, database, models, etc., so that users can build their own Qwen-7B-based LangChain, Agent, and Code Interpreter. In the soon-to-be-released internal evaluation benchmark for assessing tool usage capabilities, we find that Qwen-7B reaches stable performance.
Qwen-7B-Chat is specifically optimized for tool usage, including API, database, models, etc., so that users can build their own Qwen-7B-based LangChain, Agent, and Code Interpreter. In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool usage capabilities, we find that Qwen-7B reaches stable performance.
[](https://)
| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
|-------------|------------------------|-----------------------|-----------------------|
| GPT-4 | 95% | **0.90** | 15% |
| GPT-3.5 | 85% | 0.88 | 75% |
| **Qwen-7B** | **99%** | 0.89 | **8.5%** |
| **Qwen-7B** | **99%** | 0.89 | **9.7%** |
For how to write and use prompts for ReAct Prompting, please refer to [the ReAct examples](examples/react_prompt.md). The use of tools can enable the model to better perform tasks.
@ -293,4 +293,3 @@ Researchers and developers are free to use the codes and model weights of both Q
## Contact Us
If you are interested to leave a message to either our research team or product team, feel free to send an email to qianwen_opensource@alibabacloud.com.

@ -241,13 +241,13 @@ model = AutoModelForCausalLM.from_pretrained(
## 工具调用
Qwen-7B-Chat针对包括API、数据库、模型等工具在内的调用进行了优化。用户可以开发基于Qwen-7B的LangChain、Agent甚至Code Interpreter。我们在内部的即将开源的评测数据集上测试模型的工具调用能力并发现Qwen-7B-Chat能够取得稳定的表现。
Qwen-7B-Chat针对包括API、数据库、模型等工具在内的调用进行了优化。用户可以开发基于Qwen-7B的LangChain、Agent甚至Code Interpreter。我们开源的[评测数据集](eval/EVALUATION.md)上测试模型的工具调用能力并发现Qwen-7B-Chat能够取得稳定的表现。
| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
| ------------- | ------------------------- | ------------------------ | ------------------------ |
| GPT-4 | 95% | **0.90** | 15% |
| GPT-3.5 | 85% | 0.88 | 75% |
| **Qwen-7B** | **99%** | 0.89 | **8.5%** |
| **Qwen-7B** | **99%** | 0.89 | **9.7%** |
我们提供了文档说明如何根据ReAct Prompting的原则写作你的prompt。
@ -298,4 +298,3 @@ For how to write and use prompts for ReAct Prompting, please refer to [the ReAct
## 联系我们
如果你想给我们的研发团队和产品团队留言请通过邮件qianwen_opensource@alibabacloud.com联系我们。

@ -64,3 +64,20 @@ python evaluate_gsm8k.py
python evaluate_chat_gsm8k.py # zeroshot
python evaluate_chat_gsm8k.py --use-fewshot # fewshot
```
- PLUGIN
This script is used to reproduce the results of the ReAct and Hugging Face Agent in the Tool Usage section of the README document.
```Shell
# Qwen-7B-Chat
mkdir data;
cd data;
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/opensource_data/exam_plugin_v1/exam_plugin_v1_react_positive.jsonl;
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/opensource_data/exam_plugin_v1/exam_plugin_v1_react_negative.jsonl;
cd ..;
pip install json5;
pip install jsonlines;
pip install rouge_score;
python evaluate_plugin.py --eval-react-positive --eval-react-negative --eval-hfagent
```

@ -0,0 +1,308 @@
import argparse
import json
import os
import pprint
import json5
import jsonlines
from rouge_score import rouge_scorer
from tqdm import tqdm
from transformers import Agent, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.tools.evaluate_agent import evaluate_agent
from transformers.trainer_utils import set_seed
data_root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'data')
def is_callable(response, golden):
return response['action'].strip().lower() == golden['action'].strip(
).lower()
def process_res(response):
# parse response
response += '\n' # fix not-find bug
thought = response[:response.find('Action:')].strip()
action = response[response.find('Action:') +
len('Action:'):response.find('Action Input:')].strip()
action_input = response[response.find('Action Input:') +
len('Action Input:'):response.find('Observation:'
)].strip()
#TODO: This parsing result is incorrect if the response contains multiple Actions. To be fixed in the future.
observation = response[response.find('Observation:') +
len('Observation:'):response.rfind('Thought:'
)].strip()
thought_last = response[response.rfind('Thought:') +
len('Thought:'):response.find('Final Answer:'
)].strip()
final_answer = response[response.find('Final Answer:') +
len('Final Answer:'):].strip()
try:
action_input = json.dumps(json5.loads(action_input),
ensure_ascii=False,
sort_keys=True)
except:
# print("JSON Load Error:", action_input)
pass
res_dict = {
'thought': thought,
'action': action,
'action_input': action_input,
'observation': observation,
'thought_last': thought_last,
'final_answer': final_answer
}
return res_dict
class _DummyTokenizer:
def tokenize(self, text: str):
return text.split()
def _get_tokenized_string(tokenizer, text_list):
token_ids_list, tokenized_string_list = [], []
for text in text_list:
assert tokenizer is not None
token_ids = tokenizer.encode(text)
tokens_bytes = tokenizer.convert_ids_to_tokens(token_ids)
tokens = [
token.decode('utf-8', errors='replace') for token in tokens_bytes
]
tokenized_string = ' '.join(tokens)
token_ids_list.append(token_ids)
tokenized_string_list.append(tokenized_string)
return token_ids_list, tokenized_string_list
def eval_action(job):
response = job['gen'][0]
golden = job['response']
if 'Action:' in response:
response, golden = process_res(response), process_res(golden)
if is_callable(response, golden):
return True
return False
def eval_action_input(job, tokenizer):
response = job['gen'][0]
golden = job['response']
response, golden = process_res(response), process_res(golden)
query = job['prompt']
job = {}
job['prompt'] = query
job['gen'] = response['action_input']
job['response'] = golden['action_input']
job['_gen_tok'], job['_gen_tok_str'] = _get_tokenized_string(
tokenizer, [response['action_input']])
job['_reference_tok'], job['_reference_tok_str'] = _get_tokenized_string(
tokenizer, [golden['action_input']])
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'],
tokenizer=_DummyTokenizer())
score = scorer.score(job['_reference_tok_str'][0], job['_gen_tok_str'][0])
rouge = score['rougeL'].fmeasure
return rouge
class QWenAgent(Agent):
"""
Agent that uses QWen model and tokenizer to generate code.
Example:
```py
agent = QWenAgent()
agent.run("Draw me a picture of rivers and lakes.")
```
"""
def __init__(self,
chat_prompt_template=None,
run_prompt_template=None,
additional_tools=None,
tokenizer=None,
model=None):
if tokenizer and model:
self.tokenizer = tokenizer
self.model = model
else:
checkpoint = 'Qwen/Qwen-7B-Chat'
self.tokenizer = AutoTokenizer.from_pretrained(
checkpoint, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(
checkpoint, device_map='auto',
trust_remote_code=True).cuda().eval()
self.model.generation_config = GenerationConfig.from_pretrained(
checkpoint, trust_remote_code=True) # 可指定不同的生成长度、top_p等相关超参
self.model.generation_config.do_sample = False # greedy
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
def generate_one(self, prompt, stop):
# "Human:" 和 "Assistant:" 曾为通义千问的特殊保留字,需要替换为 "_HUMAN_:" 和 "_ASSISTANT_:"。这一问题将在未来版本修复。
prompt = prompt.replace('Human:',
'_HUMAN_:').replace('Assistant:',
'_ASSISTANT_:')
stop = [
item.replace('Human:', '_HUMAN_:').replace('Assistant:',
'_ASSISTANT_:')
for item in stop
]
result, _ = self.model.chat(self.tokenizer, prompt, history=None)
for stop_seq in stop:
if result.endswith(stop_seq):
result = result[:-len(stop_seq)]
result = result.replace('_HUMAN_:',
'Human:').replace('_ASSISTANT_:', 'Assistant:')
return result
def load_models_tokenizer(args):
tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path,
trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path,
device_map='auto',
trust_remote_code=True,
bf16=True,
use_flash_attn=True).eval()
model.generation_config = GenerationConfig.from_pretrained(
args.checkpoint_path, trust_remote_code=True)
model.generation_config.do_sample = False # use greedy decoding
return model, tokenizer
def load_jobs(filename):
jobs = []
with jsonlines.open(os.path.join(data_root_path, filename),
mode='r') as reader:
for job in reader:
jobs.append(job)
return jobs
def react_inference(filename, model, tokenizer):
filename_cache = filename + '.cache'
if os.path.exists(os.path.join(data_root_path, filename_cache)):
jobs = load_jobs(filename=filename_cache)
print('Loaded from', filename_cache)
else:
with open(os.path.join(data_root_path, filename_cache), 'w') as f:
jobs = load_jobs(filename=filename)
print('Inference:', filename)
for job in tqdm(jobs):
response, history = model.chat(tokenizer,
job['prompt'],
history=None)
job['gen'] = [response]
f.writelines(json.dumps(job, ensure_ascii=False) + '\n')
print(filename_cache, 'is saved.')
return jobs
def main(args):
print('loading model weights')
if args.checkpoint_path is not None:
model, tokenizer = load_models_tokenizer(args)
else:
model, tokenizer = None, None
print('model loaded')
result = {}
# eval react positive
if args.eval_react_positive:
print('eval react positive ...')
acc_count = 0
rouge_mean = 0
jobs = react_inference(filename=args.eval_react_positive_filename,
model=model,
tokenizer=tokenizer)
for job in jobs:
if eval_action(job):
acc_count += 1
rouge = eval_action_input(job, tokenizer)
rouge_mean += (rouge / len(jobs))
scores = {
'action_right_rate': acc_count / len(jobs),
'action_input_rouge': rouge_mean,
}
result.update({'react_positive': scores})
# eval react negative
if args.eval_react_negative:
print('eval react negative ...')
bad_count = 0
jobs = react_inference(filename=args.eval_react_negative_filename,
model=model,
tokenizer=tokenizer)
for job in jobs:
if '\nAction:' in job['gen'][0]:
bad_count += 1
scores = {'bad_rate': bad_count / len(jobs)}
result.update({'react_negative': scores})
# eval hfagent
if args.eval_hfagent:
print('eval hfagent ...')
agent = QWenAgent(model=model, tokenizer=tokenizer)
scores = evaluate_agent(agent, verbose=False, return_errors=False)
result.update({'hfagent': scores})
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(result)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Test HF checkpoint.')
parser.add_argument('-c',
'--checkpoint-path',
type=str,
help='Checkpoint path',
default='Qwen/Qwen-7B-Chat')
parser.add_argument('-s',
'--seed',
type=int,
default=1234,
help='Random seed')
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title='Evaluation options')
group.add_argument('--eval-react-positive',
action='store_true',
default=False,
help='Eval react positive.')
group.add_argument('--eval-react-positive-filename',
type=str,
default='exam_plugin_v1_react_positive.jsonl',
help='Eval react positive filename.')
group.add_argument('--eval-react-negative',
action='store_true',
default=False,
help='Eval react negative.')
group.add_argument('--eval-react-negative-filename',
type=str,
default='exam_plugin_v1_react_negative.jsonl',
help='Eval react negative filename.')
group.add_argument('--eval-hfagent',
action='store_true',
default=False,
help='Eval hfagent.')
args = parser.parse_args()
set_seed(args.seed)
main(args)

@ -311,13 +311,13 @@ LLMs have shown capability in coordinating multiple external systems to achieve
Qwen supports calling plugins/tools/APIs through [ReAct Prompting](https://arxiv.org/abs/2210.03629).
ReAct is also one of the main approaches used by the [LangChain](https://python.langchain.com/) framework.
For how to write and use prompts for ReAct Prompting, please refer to [the ReAct examples](examples/react_prompt.md).
In the soon-to-be-released evaluation benchmark for assessing tool usage capabilities, Qwen's performance is as follows:
In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool usage capabilities, Qwen's performance is as follows:
| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
| :---------- | --------------------------: | -------------------------: | -------------------------: |
| GPT-4 | 95% | **0.90** | 15.0% |
| GPT-3.5 | 85% | 0.88 | 75.0% |
| **Qwen-7B** | **99%** | 0.89 | **8.5%** |
| **Qwen-7B** | **99%** | 0.89 | **9.7%** |
> The plugins that appear in the evaluation set do not appear in the training set of Qwen.
> This benchmark evaluates the accuracy of the model in selecting the correct plugin from multiple candidate plugins, the rationality of the parameters passed into the plugin, and the false positive rate.

Loading…
Cancel
Save