# coding=utf-8
# Implements API for Qwen-7B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:8000/docs for documents.

from argparse import ArgumentParser
import time
import torch
import uvicorn
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers.generation import GenerationConfig
from sse_starlette.sse import ServerSentEvent, EventSourceResponse


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: nothing on startup; frees cached GPU memory on shutdown."""
    yield
    # collects GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    """One entry of the `/v1/models` listing."""

    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    """Response body of `/v1/models`."""

    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    """A complete chat message (request input or non-streamed reply)."""

    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    """Incremental message fragment sent in streamed chunks."""

    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    """Request body of `/v1/chat/completions`."""

    model: str
    messages: List[ChatMessage]
    # NOTE(review): temperature, top_p and max_length are accepted for OpenAI
    # compatibility but are not forwarded to the model by this server.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False
    # As in the OpenAI API, `stop` may be a single string or a list of strings.
    stop: Optional[Union[str, List[str]]] = []


class ChatCompletionResponseChoice(BaseModel):
    """A single choice in a non-streamed completion."""

    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    """A single choice in a streamed completion chunk."""

    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]


class ChatCompletionResponse(BaseModel):
    """Completion response body, used both for full replies and for stream chunks."""

    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """List the single model id this server advertises."""
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])


def _normalize_stop_words(stop: Optional[Union[str, List[str]]]) -> List[str]:
    """Build the list of stop strings for a request without mutating it.

    Accepts ``None``, a single string or a list of strings. For every stop word
    starting with a newline, the variant without that newline is appended too.
    Empty strings are dropped: an empty stop string matches every response and
    would truncate it to nothing (e.g. a stop word of "\\n" used to yield "").
    """
    if stop is None:
        stop = []
    elif isinstance(stop, str):
        stop = [stop]
    stop_words = [s for s in stop if s]
    # Materialize first so the list is not extended while being iterated.
    stripped = [s[1:] for s in stop_words if s.startswith("\n")]
    stop_words.extend(s for s in stripped if s)
    return stop_words


def _trim_stop_suffix(response: str, stop_words: List[str]) -> str:
    """Remove any trailing stop word(s) from a generated response, in list order."""
    for stop_ in stop_words:
        # stop_ is never empty (see _normalize_stop_words), so the slice is safe.
        if response.endswith(stop_):
            response = response[: -len(stop_)]
    return response


def _parse_history(prev_messages: List[ChatMessage]) -> List[List[str]]:
    """Turn strictly alternating user/assistant messages into [query, response] pairs.

    Raises:
        HTTPException: 400 if the messages are not an even-length sequence of
            (user, assistant) pairs. A leading "system" message also fails here.
    """
    if len(prev_messages) % 2 != 0:
        raise HTTPException(status_code=400, detail="Invalid request.")
    history = []
    for user_msg, assistant_msg in zip(prev_messages[0::2], prev_messages[1::2]):
        if user_msg.role != "user" or assistant_msg.role != "assistant":
            raise HTTPException(status_code=400, detail="Invalid request.")
        history.append([user_msg.content, assistant_msg.content])
    return history


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion, either as one response or as an SSE stream.

    The last message must come from the user; earlier messages must alternate
    user/assistant and are passed to the model as history.

    Raises:
        HTTPException: 400 for an empty message list, a last message not from
            the user, malformed history, or stop words combined with streaming.
    """
    global model, tokenizer

    if not request.messages or request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content

    stop_words = _normalize_stop_words(request.stop)

    prev_messages = request.messages[:-1]
    # Temporarily, the system role does not work as expected.
    # We advise that you write the setups for role-play in your query.
    # if len(prev_messages) > 0 and prev_messages[0].role == "system":
    #     query = prev_messages.pop(0).content + query
    history = _parse_history(prev_messages)

    if request.stream:
        if stop_words:
            # Reject up front: once the SSE response has started, an error can no
            # longer be reported with a proper HTTP status, and streamed text
            # containing the stop word cannot be taken back.
            raise HTTPException(
                status_code=400,
                detail="Invalid request: custom stop words are not supported in stream mode.",
            )
        generate = predict(query, history, request.model, stop_words)
        return EventSourceResponse(generate, media_type="text/event-stream")

    if stop_words:
        react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words]
        response, _ = model.chat(
            tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens
        )
        response = _trim_stop_suffix(response, stop_words)
    else:
        response, _ = model.chat(tokenizer, query, history=history)

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop",
    )
    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")


def _stream_chunk(model_id: str, delta: DeltaMessage, finish_reason: Optional[str]) -> str:
    """Serialize one `chat.completion.chunk` to JSON, omitting fields never set."""
    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=delta, finish_reason=finish_reason
    )
    chunk = ChatCompletionResponse(
        model=model_id, choices=[choice_data], object="chat.completion.chunk"
    )
    return chunk.model_dump_json(exclude_unset=True)


async def predict(query: str, history: List[List[str]], model_id: str, stop_words: List[str]):
    """Yield SSE payloads for a streamed reply.

    Sends a role chunk first, then one chunk per newly generated piece of text,
    then a final chunk with finish_reason "stop", then "[DONE]".

    Raises:
        ValueError: if stop_words is non-empty (not supported in stream mode).
    """
    global model, tokenizer
    if stop_words:
        # Replaces a former `assert`, which would be stripped under `python -O`.
        raise ValueError("in stream format, stop word is output")

    yield _stream_chunk(model_id, DeltaMessage(role="assistant"), None)

    current_length = 0
    # chat_stream appears to yield the cumulative response so far, so only the
    # unseen tail is emitted each time.
    for new_response in model.chat_stream(tokenizer, query, history=history):
        if len(new_response) == current_length:
            continue
        new_text = new_response[current_length:]
        current_length = len(new_response)
        yield _stream_chunk(model_id, DeltaMessage(content=new_text), None)

    yield _stream_chunk(model_id, DeltaMessage(), "stop")
    yield "[DONE]"


def _get_args():
    """Parse command-line options for the server."""
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default="QWen/QWen-7B-Chat",
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    parser.add_argument("--server-port", type=int, default=8000, help="Demo server port.")
    parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Demo server name.")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = _get_args()

    # `tokenizer` and `model` become module globals read by the endpoints above.
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        resume_download=True,
    )

    device_map = "cpu" if args.cpu_only else "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
    ).eval()

    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        resume_download=True,
    )

    uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)