# coding=utf-8
# Implements API for Qwen-7B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:8000/docs for documents.
#
# Provenance: contents of openai_api.py as created by the patch
# "Add openai_api" (hanpeng, Sat, 12 Aug 2023 09:30:19 +0800,
# commit 3006ef34e99a09b02e1fed3b3076b2de2a183cd9).

import time
from argparse import ArgumentParser
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers.generation import GenerationConfig

# Checkpoint used when --checkpoint-path is not given on the command line.
DEFAULT_CKPT_PATH = 'QWen/QWen-7B-Chat'


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Release cached CUDA memory when the application shuts down.

    Nothing happens at startup; the code after ``yield`` runs once the
    server is stopping.
    """
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

# NOTE(review): every origin/method/header is allowed *and* credentials are
# enabled. Fine for a local demo; tighten before exposing the server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    """One entry of the ``/v1/models`` listing."""

    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    """Response body of ``GET /v1/models``."""

    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    """A complete chat message with a mandatory role and content."""

    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    """Incremental message fragment carried by a streaming chunk."""

    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    """Request body of ``POST /v1/chat/completions``.

    ``temperature``, ``top_p`` and ``max_length`` are accepted but are not
    read anywhere in this module.
    """

    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    """Single choice of a non-streaming completion."""

    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    """Single choice of a streaming chunk; ``finish_reason`` is None until the end."""

    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]


class ChatCompletionResponse(BaseModel):
    """Envelope shared by full completions and streaming chunks."""

    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Return a model list containing the single fixed id ``gpt-3.5-turbo``."""
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """Answer the last user message, either as one response or as an SSE stream.

    The final message must have role ``user``. A leading ``system`` message
    is prepended to the query text. The remaining messages are folded into
    ``[user, assistant]`` history pairs.

    Raises:
        HTTPException: 400 when ``messages`` is empty or does not end with
            a user message.
    """
    # An empty list used to raise IndexError (HTTP 500); report it as a bad request.
    if not request.messages or request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content

    # Slicing copies the list, so popping below does not mutate the request.
    prev_messages = request.messages[:-1]
    if prev_messages and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query

    # Best effort: history is only used when the earlier messages form an
    # even number of turns, and pairs that are not (user, assistant) are skipped.
    history = []
    if len(prev_messages) % 2 == 0:
        for asked, answered in zip(prev_messages[::2], prev_messages[1::2]):
            if asked.role == "user" and answered.role == "assistant":
                history.append([asked.content, answered.content])

    if request.stream:
        generate = predict(query, history, request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    # `chat` returns (response, history). `chat_stream` is a generator and
    # cannot be unpacked into a response/history pair.
    response, _ = model.chat(tokenizer, query, history=history)
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")


def _stream_chunk(model_id: str, delta: DeltaMessage, finish_reason: Optional[str] = None) -> str:
    """Serialize one ``chat.completion.chunk`` carrying ``delta`` to JSON."""
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=delta,
        finish_reason=finish_reason
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    # exclude_unset drops fields that were never assigned (e.g. `created`,
    # or `content` on the opening role-only delta).
    return chunk.model_dump_json(exclude_unset=True)


async def predict(query: str, history: List[List[str]], model_id: str):
    """Yield the SSE payloads of a streamed completion.

    Emits a role-only chunk, then one chunk per newly generated piece of
    text, then an empty delta with ``finish_reason="stop"``, and finally the
    literal ``[DONE]`` marker.
    """
    yield _stream_chunk(model_id, DeltaMessage(role="assistant"))

    # Each item from chat_stream is treated as the full text generated so far;
    # only the part beyond what was already sent is forwarded.
    # NOTE(review): this loop is synchronous, so generation presumably blocks
    # the event loop while it runs - confirm this is acceptable with workers=1.
    sent_length = 0
    for new_response in model.chat_stream(tokenizer, query, history):
        if len(new_response) == sent_length:
            continue

        new_text = new_response[sent_length:]
        sent_length = len(new_response)
        yield _stream_chunk(model_id, DeltaMessage(content=new_text))

    yield _stream_chunk(model_id, DeltaMessage(), finish_reason="stop")
    yield '[DONE]'


def _get_args():
    """Parse the command-line options of the demo server."""
    parser = ArgumentParser()
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    parser.add_argument("--server-port", type=int, default=8000,
                        help="Demo server port.")
    parser.add_argument("--server-name", type=str, default="127.0.0.1",
                        help="Demo server name.")

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = _get_args()

    # `tokenizer` and `model` are module globals read by the request handlers.
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
    ).eval()

    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)