From 3006ef34e99a09b02e1fed3b3076b2de2a183cd9 Mon Sep 17 00:00:00 2001 From: hanpeng Date: Sat, 12 Aug 2023 09:30:19 +0800 Subject: [PATCH 1/8] Add openai_api --- openai_api.py | 209 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 openai_api.py diff --git a/openai_api.py b/openai_api.py new file mode 100644 index 0000000..3ed03d1 --- /dev/null +++ b/openai_api.py @@ -0,0 +1,209 @@ +# coding=utf-8 +# Implements API for Qwen-7B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat) +# Usage: python openai_api.py +# Visit http://localhost:8000/docs for documents. + +from argparse import ArgumentParser +import time +import torch +import uvicorn +from pydantic import BaseModel, Field +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager +from typing import Any, Dict, List, Literal, Optional, Union +from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM +from transformers.generation import GenerationConfig +from sse_starlette.sse import ServerSentEvent, EventSourceResponse + + +@asynccontextmanager +async def lifespan(app: FastAPI): # collects GPU memory + yield + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +app = FastAPI(lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "owner" + root: Optional[str] = None + parent: Optional[str] = None + permission: Optional[list] = None + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class ChatMessage(BaseModel): + role: Literal["user", "assistant", "system"] + content: str + + +class DeltaMessage(BaseModel): + role: Optional[Literal["user", "assistant", "system"]] = None + content: Optional[str] = None + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[ChatMessage] + temperature: Optional[float] = None + top_p: Optional[float] = None + max_length: Optional[int] = None + stream: Optional[bool] = False + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Literal["stop", "length"] + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] + + +class ChatCompletionResponse(BaseModel): + model: str + object: Literal["chat.completion", "chat.completion.chunk"] + choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] + created: Optional[int] = Field(default_factory=lambda: int(time.time())) + + +@app.get("/v1/models", response_model=ModelList) +async def list_models(): + global model_args + model_card = ModelCard(id="gpt-3.5-turbo") + return ModelList(data=[model_card]) + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def create_chat_completion(request: ChatCompletionRequest): + global model, tokenizer + + if request.messages[-1].role != "user": + raise HTTPException(status_code=400, detail="Invalid request") + query = request.messages[-1].content + + prev_messages = request.messages[:-1] + if len(prev_messages) > 0 and prev_messages[0].role == "system": + query = prev_messages.pop(0).content + query + + history = [] + if len(prev_messages) % 2 == 0: + for i in range(0, len(prev_messages), 2): + if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant": + history.append([prev_messages[i].content, prev_messages[i+1].content]) + + if request.stream: + generate = predict(query, history, request.model) + return EventSourceResponse(generate, media_type="text/event-stream") + + response, _ = model.chat_stream(tokenizer, query, history=history) + choice_data = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop" + ) + + return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion") + + +async def predict(query: str, history: List[List[str]], model_id: str): + global model, tokenizer + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role="assistant"), + finish_reason=None + ) + chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + #yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) + yield "{}".format(chunk.model_dump_json(exclude_unset=True)) + + current_length = 0 + + for new_response in model.chat_stream(tokenizer, query, history): + if len(new_response) == current_length: + continue + + new_text = new_response[current_length:] + current_length = len(new_response) + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content=new_text), + finish_reason=None + ) + chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + #yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) + yield "{}".format(chunk.model_dump_json(exclude_unset=True)) + + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(), + finish_reason="stop" + ) + chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + #yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) + yield "{}".format(chunk.model_dump_json(exclude_unset=True)) + yield '[DONE]' + +def _get_args(): + parser = ArgumentParser() + parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH, + help="Checkpoint name or path, default to %(default)r") + parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only") + parser.add_argument("--server-port", type=int, default=8000, + help="Demo server port.") + parser.add_argument("--server-name", type=str, default="127.0.0.1", + help="Demo server name.") + + args = parser.parse_args() + return args + +DEFAULT_CKPT_PATH = 'QWen/QWen-7B-Chat' + +if __name__ == "__main__": + args = _get_args() + + tokenizer = AutoTokenizer.from_pretrained( + args.checkpoint_path, trust_remote_code=True, resume_download=True, + ) + + if args.cpu_only: + device_map = "cpu" + else: + device_map = "auto" + + model = AutoModelForCausalLM.from_pretrained( + args.checkpoint_path, + device_map=device_map, + trust_remote_code=True, + resume_download=True, + ).eval() + + model.generation_config = GenerationConfig.from_pretrained( + args.checkpoint_path, trust_remote_code=True, resume_download=True, + ) + + uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1) \ No newline at end of file From dfced1aec000e434de3b980fc8244f4bb2236562 Mon Sep 17 00:00:00 2001 From: Junyang Lin Date: Sun, 13 Aug 2023 12:28:05 +0800 Subject: [PATCH 2/8] Update openai_api.py --- openai_api.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/openai_api.py b/openai_api.py index 3ed03d1..8b12c61 100644 --- a/openai_api.py +++ b/openai_api.py @@ -35,6 +35,7 @@ app.add_middleware( allow_headers=["*"], ) + class ModelCard(BaseModel): id: str object: str = "model" @@ -136,7 +137,6 @@ async def predict(query: str, history: List[List[str]], model_id: str): finish_reason=None ) chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") - #yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) yield "{}".format(chunk.model_dump_json(exclude_unset=True)) current_length = 0 @@ -154,7 +154,6 @@ async def predict(query: str, history: List[List[str]], model_id: str): finish_reason=None ) chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") - #yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) yield "{}".format(chunk.model_dump_json(exclude_unset=True)) @@ -164,13 +163,12 @@ async def predict(query: str, history: List[List[str]], model_id: str): finish_reason="stop" ) chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") - #yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) yield "{}".format(chunk.model_dump_json(exclude_unset=True)) yield '[DONE]' def _get_args(): parser = ArgumentParser() - parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH, + parser.add_argument("-c", "--checkpoint-path", type=str, default='QWen/QWen-7B-Chat', help="Checkpoint name or path, default to %(default)r") parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only") parser.add_argument("--server-port", type=int, default=8000, @@ -181,7 +179,6 @@ def _get_args(): args = parser.parse_args() return args -DEFAULT_CKPT_PATH = 'QWen/QWen-7B-Chat' if __name__ == "__main__": args = _get_args() @@ -206,4 +203,4 @@ if __name__ == "__main__": args.checkpoint_path, trust_remote_code=True, resume_download=True, ) - uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1) \ No newline at end of file + uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1) From afc571987f4acd9a00010136de970b14fe7f7f9f Mon Sep 17 00:00:00 2001 From: Junyang Lin Date: Sun, 13 Aug 2023 12:38:39 +0800 Subject: [PATCH 3/8] Update README.md --- README.md | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/README.md b/README.md index d6b28cd..8bd0e07 100644 --- a/README.md +++ b/README.md @@ -266,6 +266,36 @@ Then run the command below and click on the generated link: python web_demo.py ``` +## API +We provide methods to deploy local API based on OpenAI API (thanks to @hanpenggit). Before you start, install the required packages: + +``` +pip install fastapi uvicorn openai pydantic sse_starlette +``` +Then run the command to deploy your API: +``` +python openai_api.py +``` +You can change your arguments, e.g., `-c` for checkpoint name or path, `--cpu-only` for CPU deployment, etc. If you meet problems launching your API deployment, updating the packages to the latest version can probably solve them. + +Using the API is also simple. See the example below: + +``` +import openai +openai.api_base = "http://localhost:8000/v1" +openai.api_key = "none" +for chunk in openai.ChatCompletion.create( + model="Qwen-7B", + messages=[ + {"role": "user", "content": "你好"} + ], + stream=True +): + if hasattr(chunk.choices[0].delta, "content"): + print(chunk.choices[0].delta.content, end="", flush=True) +``` + + ## Tool Usage Qwen-7B-Chat is specifically optimized for tool usage, including API, database, models, etc., so that users can build their own Qwen-7B-based LangChain, Agent, and Code Interpreter. In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool usage capabilities, we find that Qwen-7B reaches stable performance. From 53790354987bc4dab99c1db75f36873e93097c63 Mon Sep 17 00:00:00 2001 From: Junyang Lin Date: Sun, 13 Aug 2023 12:41:26 +0800 Subject: [PATCH 4/8] Update README.md --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 8bd0e07..81517bb 100644 --- a/README.md +++ b/README.md @@ -269,18 +269,18 @@ python web_demo.py ## API We provide methods to deploy local API based on OpenAI API (thanks to @hanpenggit). Before you start, install the required packages: -``` +```bash pip install fastapi uvicorn openai pydantic sse_starlette ``` Then run the command to deploy your API: -``` +```bash python openai_api.py ``` You can change your arguments, e.g., `-c` for checkpoint name or path, `--cpu-only` for CPU deployment, etc. If you meet problems launching your API deployment, updating the packages to the latest version can probably solve them. Using the API is also simple. See the example below: -``` +```python import openai openai.api_base = "http://localhost:8000/v1" openai.api_key = "none" From 0bb1df0bc28c17232cf330609bc8f4f2bf7dde9d Mon Sep 17 00:00:00 2001 From: Junyang Lin Date: Sun, 13 Aug 2023 12:44:19 +0800 Subject: [PATCH 5/8] Update README_CN.md --- README_CN.md | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/README_CN.md b/README_CN.md index e844d50..eccf1f6 100644 --- a/README_CN.md +++ b/README_CN.md @@ -268,6 +268,36 @@ pip install -r requirements_web_demo.txt python web_demo.py ``` +## API +我们提供了OpenAI API格式的本地API部署方法(感谢@hanpenggit)。在开始之前先安装必要的代码库: + +```bash +pip install fastapi uvicorn openai pydantic sse_starlette +``` +随后即可运行以下命令部署你的本地API: +```bash +python openai_api.py +``` +你也可以修改参数,比如`-c`来修改模型名称或路径, `--cpu-only`改为CPU部署等等。如果部署出现问题,更新上述代码库往往可以解决大多数问题。 + +使用API同样非常简单,示例如下: + +```python +import openai +openai.api_base = "http://localhost:8000/v1" +openai.api_key = "none" +for chunk in openai.ChatCompletion.create( + model="Qwen-7B", + messages=[ + {"role": "user", "content": "你好"} + ], + stream=True +): + if hasattr(chunk.choices[0].delta, "content"): + print(chunk.choices[0].delta.content, end="", flush=True) +``` + + ## 工具调用 Qwen-7B-Chat针对包括API、数据库、模型等工具在内的调用进行了优化。用户可以开发基于Qwen-7B的LangChain、Agent甚至Code Interpreter。在我们开源的[评测数据集](eval/EVALUATION.md)上测试模型的工具调用能力,并发现Qwen-7B-Chat能够取得稳定的表现。 From f91e7b3c872cd902acd366335750acd34ba85430 Mon Sep 17 00:00:00 2001 From: Junyang Lin Date: Sun, 13 Aug 2023 12:49:26 +0800 Subject: [PATCH 6/8] Update README_JA.md --- README_JA.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/README_JA.md b/README_JA.md index 7d5ad32..2cf61f5 100644 --- a/README_JA.md +++ b/README_JA.md @@ -273,6 +273,38 @@ pip install -r requirements_web_demo.txt python web_demo.py ``` +## API +OpenAI APIをベースにローカルAPIをデプロイする方法を提供する(@hanpenggitに感謝)。始める前に、必要なパッケージをインストールしてください: + +```bash +pip install fastapi uvicorn openai pydantic sse_starlette +``` + +それから、APIをデプロイするコマンドを実行する: + +```bash +python openai_api.py +``` + +チェックポイント名やパスには `-c` 、CPU デプロイメントには `--cpu-only` など、引数を変更できます。APIデプロイメントを起動する際に問題が発生した場合は、パッケージを最新バージョンに更新することで解決できる可能性があります。 + +APIの使い方も簡単だ。以下の例をご覧ください: + +```python +import openai +openai.api_base = "http://localhost:8000/v1" +openai.api_key = "none" +for chunk in openai.ChatCompletion.create( + model="Qwen-7B", + messages=[ + {"role": "user", "content": "你好"} + ], + stream=True +): + if hasattr(chunk.choices[0].delta, "content"): + print(chunk.choices[0].delta.content, end="", flush=True) +``` + ## ツールの使用 Qwen-7B-Chat は、API、データベース、モデルなど、ツールの利用に特化して最適化されており、ユーザは独自の Qwen-7B ベースの LangChain、エージェント、コードインタプリタを構築することができます。ツール利用能力を評価するための評価[ベンチマーク](eval/EVALUATION.md)では、Qwen-7B は安定した性能に達しています。 From c84d2f685ab733751874c88ab3a8268f1efdd62d Mon Sep 17 00:00:00 2001 From: Junyang Lin Date: Sun, 13 Aug 2023 12:54:53 +0800 Subject: [PATCH 7/8] add exception --- openai_api.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/openai_api.py b/openai_api.py index 8b12c61..bae3e90 100644 --- a/openai_api.py +++ b/openai_api.py @@ -113,6 +113,10 @@ async def create_chat_completion(request: ChatCompletionRequest): for i in range(0, len(prev_messages), 2): if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant": history.append([prev_messages[i].content, prev_messages[i+1].content]) + else: + raise HTTPException(status_code=400, detail="Invalid request.") + else: + raise HTTPException(status_code=400, detail="Invalid request.") if request.stream: generate = predict(query, history, request.model) From 31a496fd24c934b319c085fb81d5791a1c6fbb77 Mon Sep 17 00:00:00 2001 From: Junyang Lin Date: Sun, 13 Aug 2023 12:58:42 +0800 Subject: [PATCH 8/8] Update openai_api.py --- openai_api.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/openai_api.py b/openai_api.py index bae3e90..568984f 100644 --- a/openai_api.py +++ b/openai_api.py @@ -105,8 +105,9 @@ async def create_chat_completion(request: ChatCompletionRequest): query = request.messages[-1].content prev_messages = request.messages[:-1] - if len(prev_messages) > 0 and prev_messages[0].role == "system": - query = prev_messages.pop(0).content + query + # Temporarily, the system role does not work as expected. We advise that you write the setups for role-play in your query. + # if len(prev_messages) > 0 and prev_messages[0].role == "system": + # query = prev_messages.pop(0).content + query history = [] if len(prev_messages) % 2 == 0: