# Requirement:
#   pip install "openai<1.0"
# Usage:
#   python openai_api.py
# Visit http://localhost:8000/docs for documents.
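#
# A minimal client-side sketch (an illustration, not part of this server),
# using the legacy openai<1.0 SDK named in the requirement above; the
# api_key value is arbitrary unless --api-auth is enabled, and the model
# name is echoed back rather than validated:
#
#   import openai
#   openai.api_base = 'http://localhost:8000/v1'
#   openai.api_key = 'none'
#   resp = openai.ChatCompletion.create(
#       model='Qwen', messages=[{'role': 'user', 'content': 'Hello!'}])
#   print(resp.choices[0].message.content)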
import base64
import copy
import json
import time
from argparse import ArgumentParser
from contextlib import asynccontextmanager
from pprint import pprint
from typing import Dict, List, Literal, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


class BasicAuthMiddleware(BaseHTTPMiddleware):

    def __init__(self, app, username: str, password: str):
        super().__init__(app)
        self.required_credentials = base64.b64encode(
            f'{username}:{password}'.encode()).decode()

    async def dispatch(self, request: Request, call_next):
        authorization: str = request.headers.get('Authorization')
        if authorization:
            try:
                schema, credentials = authorization.split()
                if credentials == self.required_credentials:
                    return await call_next(request)
            except ValueError:
                pass

        headers = {'WWW-Authenticate': 'Basic'}
        return Response(status_code=401, headers=headers)
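# A client-side sketch (illustrative, not part of this file): when the server
# is started with --api-auth user:pass, requests must carry the matching
# Basic credentials, e.g.
#   import base64, requests
#   token = base64.b64encode(b'user:pass').decode()
#   requests.get('http://localhost:8000/v1/models',
#                headers={'Authorization': f'Basic {token}'})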


def _gc(forced: bool = False):
    global args
    if args.disable_gc and not forced:
        return

    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    yield
    _gc(forced=True)


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)


class ModelCard(BaseModel):
    id: str
    object: str = 'model'
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = 'owner'
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = 'list'
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal['user', 'assistant', 'system', 'function']
    content: Optional[str]
    function_call: Optional[Dict] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal['user', 'assistant', 'system']] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    functions: Optional[List[Dict]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal['stop', 'length', 'function_call']


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal['stop', 'length']]


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal['chat.completion', 'chat.completion.chunk']
    choices: List[Union[ChatCompletionResponseChoice,
                        ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


@app.get('/v1/models', response_model=ModelList)
async def list_models():
    global model_args
    model_card = ModelCard(id='gpt-3.5-turbo')
    return ModelList(data=[model_card])


# To work around that unpleasant leading-\n tokenization issue!
def add_extra_stop_words(stop_words):
    if stop_words:
        _stop_words = []
        _stop_words.extend(stop_words)
        for x in stop_words:
            s = x.lstrip('\n')
            if s and (s not in _stop_words):
                _stop_words.append(s)
        return _stop_words
    return stop_words
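# For example (a sketch of the intended behavior):
#   add_extra_stop_words(['\nObservation:'])
#   -> ['\nObservation:', 'Observation:']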


def trim_stop_words(response, stop_words):
    if stop_words:
        for stop in stop_words:
            idx = response.find(stop)
            if idx != -1:
                response = response[:idx]
    return response
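# For example: trim_stop_words('42\nObservation: done', ['Observation:'])
# returns '42\n', cutting the response at the first stop word found.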


TOOL_DESC = (
    '{name_for_model}: Call this tool to interact with the {name_for_human} API.'
    ' What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}'
)

REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:

{tools_text}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tools_name_text}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!"""
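# A sketch of one rendered TOOL_DESC entry for a hypothetical tool named
# get_weather (name and description here are illustrative only):
#   get_weather: Call this tool to interact with the get_weather API. What is
#   the get_weather API useful for? Queries current weather. Parameters: [...]
# tools_name_text would then render as 'get_weather' inside the Action line.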
_TEXT_COMPLETION_CMD = object()


def parse_messages(messages, functions):
    if all(m.role != 'user' for m in messages):
        raise HTTPException(
            status_code=400,
            detail='Invalid request: Expecting at least one user message.',
        )

    messages = copy.deepcopy(messages)
    if messages[0].role == 'system':
        system = messages.pop(0).content.lstrip('\n').rstrip()
    else:
        system = 'You are a helpful assistant.'

    if functions:
        tools_text = []
        tools_name_text = []
        for func_info in functions:
            name = func_info.get('name', '')
            name_m = func_info.get('name_for_model', name)
            name_h = func_info.get('name_for_human', name)
            desc = func_info.get('description', '')
            desc_m = func_info.get('description_for_model', desc)
            tool = TOOL_DESC.format(
                name_for_model=name_m,
                name_for_human=name_h,
                # Hint: You can add the following format requirements in description:
                #   "Format the arguments as a JSON object."
                #   "Enclose the code within triple backticks (`) at the beginning and end of the code."
                description_for_model=desc_m,
                parameters=json.dumps(func_info['parameters'],
                                      ensure_ascii=False),
            )
            tools_text.append(tool)
            tools_name_text.append(name_m)
        tools_text = '\n\n'.join(tools_text)
        tools_name_text = ', '.join(tools_name_text)
        instruction = (REACT_INSTRUCTION.format(
            tools_text=tools_text,
            tools_name_text=tools_name_text,
        ).lstrip('\n').rstrip())
    else:
        instruction = ''

    messages_with_fncall = messages
    messages = []
    for m_idx, m in enumerate(messages_with_fncall):
        role, content, func_call = m.role, m.content, m.function_call
        content = content or ''
        content = content.lstrip('\n').rstrip()
        if role == 'function':
            if (len(messages) == 0) or (messages[-1].role != 'assistant'):
                raise HTTPException(
                    status_code=400,
                    detail=
                    'Invalid request: Expecting role assistant before role function.',
                )
            messages[-1].content += f'\nObservation: {content}'
            if m_idx == len(messages_with_fncall) - 1:
                # add a prefix for text completion
                messages[-1].content += '\nThought:'
        elif role == 'assistant':
            if len(messages) == 0:
                raise HTTPException(
                    status_code=400,
                    detail=
                    'Invalid request: Expecting role user before role assistant.',
                )
            if func_call is None:
                if functions:
                    content = f'Thought: I now know the final answer.\nFinal Answer: {content}'
            else:
                f_name, f_args = func_call['name'], func_call['arguments']
                if not content.startswith('Thought:'):
                    content = f'Thought: {content}'
                content = f'{content}\nAction: {f_name}\nAction Input: {f_args}'
            if messages[-1].role == 'user':
                messages.append(
                    ChatMessage(role='assistant',
                                content=content.lstrip('\n').rstrip()))
            else:
                messages[-1].content += '\n' + content
        elif role == 'user':
            messages.append(
                ChatMessage(role='user',
                            content=content.lstrip('\n').rstrip()))
        else:
            raise HTTPException(
                status_code=400,
                detail=f'Invalid request: Incorrect role {role}.')

    query = _TEXT_COMPLETION_CMD
    if messages[-1].role == 'user':
        query = messages[-1].content
        messages = messages[:-1]

    if len(messages) % 2 != 0:
        raise HTTPException(status_code=400, detail='Invalid request')

    history = []  # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
    for i in range(0, len(messages), 2):
        if messages[i].role == 'user' and messages[i + 1].role == 'assistant':
            usr_msg = messages[i].content.lstrip('\n').rstrip()
            bot_msg = messages[i + 1].content.lstrip('\n').rstrip()
            if instruction and (i == len(messages) - 2):
                usr_msg = f'{instruction}\n\nQuestion: {usr_msg}'
                instruction = ''
            history.append([usr_msg, bot_msg])
        else:
            raise HTTPException(
                status_code=400,
                detail=
                'Invalid request: Expecting exactly one user (or function) role before every assistant role.',
            )
    if instruction:
        assert query is not _TEXT_COMPLETION_CMD
        query = f'{instruction}\n\nQuestion: {query}'
    return query, history, system
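# An illustrative round trip (values are examples, not asserted anywhere):
# given messages [system, user('hi'), assistant('hello'), user('bye?')] and
# no functions, this returns query='bye?', history=[['hi', 'hello']], and
# system taken from the system message.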


def parse_response(response):
    func_name, func_args = '', ''
    i = response.find('\nAction:')
    j = response.find('\nAction Input:')
    k = response.find('\nObservation:')
    if 0 <= i < j:  # If the text has `Action` and `Action Input`,
        if k < j:  # but does not contain `Observation`,
            # then it is likely that `Observation` is omitted by the LLM,
            # because the output text may have discarded the stop word.
            response = response.rstrip() + '\nObservation:'  # Add it back.
            k = response.find('\nObservation:')
        func_name = response[i + len('\nAction:'):j].strip()
        func_args = response[j + len('\nAction Input:'):k].strip()

    if func_name:
        response = response[:i]
        t = response.find('Thought: ')
        if t >= 0:
            response = response[t + len('Thought: '):]
        response = response.strip()
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(
                role='assistant',
                content=response,
                function_call={
                    'name': func_name,
                    'arguments': func_args
                },
            ),
            finish_reason='function_call',
        )
        return choice_data

    z = response.rfind('\nFinal Answer: ')
    if z >= 0:
        response = response[z + len('\nFinal Answer: '):]
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=response),
        finish_reason='stop',
    )
    return choice_data
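# For example (an illustrative model response, not produced by this file):
# parsing 'Thought: need weather\nAction: get_weather\nAction Input: {"city": "Paris"}'
# yields finish_reason='function_call' with
# function_call={'name': 'get_weather', 'arguments': '{"city": "Paris"}'}.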


# completion mode, not chat mode
def text_complete_last_message(history, stop_words_ids, gen_kwargs, system):
    im_start = '<|im_start|>'
    im_end = '<|im_end|>'
    prompt = f'{im_start}system\n{system}{im_end}'
    for i, (query, response) in enumerate(history):
        query = query.lstrip('\n').rstrip()
        response = response.lstrip('\n').rstrip()
        prompt += f'\n{im_start}user\n{query}{im_end}'
        prompt += f'\n{im_start}assistant\n{response}{im_end}'
    prompt = prompt[:-len(im_end)]

    _stop_words_ids = [tokenizer.encode(im_end)]
    if stop_words_ids:
        for s in stop_words_ids:
            _stop_words_ids.append(s)
    stop_words_ids = _stop_words_ids

    input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
    output = model.generate(input_ids,
                            stop_words_ids=stop_words_ids,
                            **gen_kwargs).tolist()[0]
    output = tokenizer.decode(output, errors='ignore')
    assert output.startswith(prompt)
    output = output[len(prompt):]
    output = trim_stop_words(output, ['<|endoftext|>', im_end])
    print(f'<completion>\n{prompt}\n<!-- *** -->\n{output}\n</completion>')
    return output
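# The prompt assembled above follows the ChatML layout, roughly:
#   <|im_start|>system\n{system}<|im_end|>
#   <|im_start|>user\n{query}<|im_end|>
#   <|im_start|>assistant\n{response}
# The trailing <|im_end|> is stripped so the model continues the last
# assistant turn as a plain text completion.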


@app.post('/v1/chat/completions', response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    gen_kwargs = {}
    if request.top_k is not None:
        gen_kwargs['top_k'] = request.top_k
    if request.temperature is not None:
        if request.temperature < 0.01:
            gen_kwargs['top_k'] = 1  # greedy decoding
        else:
            # Not recommended. Please tune top_p instead.
            gen_kwargs['temperature'] = request.temperature
    if request.top_p is not None:
        gen_kwargs['top_p'] = request.top_p

    stop_words = add_extra_stop_words(request.stop)
    if request.functions:
        stop_words = stop_words or []
        if 'Observation:' not in stop_words:
            stop_words.append('Observation:')

    query, history, system = parse_messages(request.messages,
                                            request.functions)

    if request.stream:
        if request.functions:
            raise HTTPException(
                status_code=400,
                detail=
                'Invalid request: Function calling is not yet implemented for stream mode.',
            )
        generate = predict(query,
                           history,
                           request.model,
                           stop_words,
                           gen_kwargs,
                           system=system)
        return EventSourceResponse(generate, media_type='text/event-stream')

    stop_words_ids = [tokenizer.encode(s)
                      for s in stop_words] if stop_words else None
    if query is _TEXT_COMPLETION_CMD:
        response = text_complete_last_message(history,
                                              stop_words_ids=stop_words_ids,
                                              gen_kwargs=gen_kwargs,
                                              system=system)
    else:
        response, _ = model.chat(
            tokenizer,
            query,
            history=history,
            system=system,
            stop_words_ids=stop_words_ids,
            **gen_kwargs,
        )
        print('<chat>')
        pprint(history, indent=2)
        print(f'{query}\n<!-- *** -->\n{response}\n</chat>')
    _gc()

    response = trim_stop_words(response, stop_words)
    if request.functions:
        choice_data = parse_response(response)
    else:
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role='assistant', content=response),
            finish_reason='stop',
        )
    return ChatCompletionResponse(model=request.model,
                                  choices=[choice_data],
                                  object='chat.completion')
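# A shell-side sketch of calling this endpoint (assumes the default host
# and port; the model name is echoed back rather than validated):
#   curl http://localhost:8000/v1/chat/completions \
#     -H 'Content-Type: application/json' \
#     -d '{"model": "Qwen", "messages": [{"role": "user", "content": "Hi"}]}'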


def _dump_json(data: BaseModel, *args, **kwargs) -> str:
    try:
        return data.model_dump_json(*args, **kwargs)
    except AttributeError:  # pydantic<2.0.0
        return data.json(*args, **kwargs)  # noqa


async def predict(
    query: str,
    history: List[List[str]],
    model_id: str,
    stop_words: List[str],
    gen_kwargs: Dict,
    system: str,
):
    global model, tokenizer
    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(role='assistant'), finish_reason=None)
    chunk = ChatCompletionResponse(model=model_id,
                                   choices=[choice_data],
                                   object='chat.completion.chunk')
    yield '{}'.format(_dump_json(chunk, exclude_unset=True))

    current_length = 0
    stop_words_ids = [tokenizer.encode(s)
                      for s in stop_words] if stop_words else None
    if stop_words:
        # TODO: It's a little bit tricky to trim stop words in the stream mode.
        raise HTTPException(
            status_code=400,
            detail=
            'Invalid request: custom stop words are not yet supported for stream mode.',
        )

    response_generator = model.chat_stream(tokenizer,
                                           query,
                                           history=history,
                                           stop_words_ids=stop_words_ids,
                                           system=system,
                                           **gen_kwargs)
    for new_response in response_generator:
        if len(new_response) == current_length:
            continue
        new_text = new_response[current_length:]
        current_length = len(new_response)
        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=DeltaMessage(content=new_text), finish_reason=None)
        chunk = ChatCompletionResponse(model=model_id,
                                       choices=[choice_data],
                                       object='chat.completion.chunk')
        yield '{}'.format(_dump_json(chunk, exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(index=0,
                                                     delta=DeltaMessage(),
                                                     finish_reason='stop')
    chunk = ChatCompletionResponse(model=model_id,
                                   choices=[choice_data],
                                   object='chat.completion.chunk')
    yield '{}'.format(_dump_json(chunk, exclude_unset=True))
    yield '[DONE]'

    _gc()
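# The resulting SSE stream looks roughly like this (payloads are
# illustrative and abridged):
#   data: {"choices": [{"delta": {"role": "assistant"}, ...}], ...}
#   data: {"choices": [{"delta": {"content": "Hel"}, ...}], ...}
#   data: {"choices": [{"delta": {}, "finish_reason": "stop"}], ...}
#   data: [DONE]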


def _get_args():
    parser = ArgumentParser()
    parser.add_argument(
        '-c',
        '--checkpoint-path',
        type=str,
        default='Qwen/Qwen-7B-Chat',
        help='Checkpoint name or path, default to %(default)r',
    )
    parser.add_argument('--api-auth', help='API authentication credentials')
    parser.add_argument('--cpu-only',
                        action='store_true',
                        help='Run demo with CPU only')
    parser.add_argument('--server-port',
                        type=int,
                        default=8000,
                        help='Demo server port.')
    parser.add_argument(
        '--server-name',
        type=str,
        default='127.0.0.1',
        help=
        'Demo server name. Default: 127.0.0.1, which is only visible from the local computer.'
        ' If you want other computers to access your server, use 0.0.0.0 instead.',
    )
    parser.add_argument(
        '--disable-gc',
        action='store_true',
        help='Disable garbage collection after each generated response.',
    )

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = _get_args()

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        resume_download=True,
    )

    if args.api_auth:
        app.add_middleware(BasicAuthMiddleware,
                           username=args.api_auth.split(':')[0],
                           password=args.api_auth.split(':')[1])

    if args.cpu_only:
        device_map = 'cpu'
    else:
        device_map = 'auto'

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
    ).eval()

    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        resume_download=True,
    )

    uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)
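# An illustrative launch command (flags are defined in _get_args above):
#   python openai_api.py -c Qwen/Qwen-7B-Chat --server-name 0.0.0.0 \
#       --api-auth user:pass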