from __future__ import annotations

import json
from typing import TypedDict

import aiohttp
import numpy as np
from aiohttp_sse_client2 import client as sse_client

import config
from service.tiktoken import TikTokenService
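# config is expected to provide OPENAI_API (optional base-URL override),
# OPENAI_TOKEN, REQUEST_PROXY and DEBUG, as used below.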

class ChatCompleteMessageLog(TypedDict):
    role: str
    content: str


class ChatCompleteResponse(TypedDict):
    message: str
    prompt_tokens: int
    message_tokens: int
    total_tokens: int
    finish_reason: str
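# Note: stored conversations may also contain a pseudo-role "summary";
# make_message_list folds that entry into the system prompt instead of
# sending it to the API as a message.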

class OpenAIApi:
    """Minimal async client for the OpenAI embeddings and chat completion endpoints."""

    @staticmethod
    def create() -> OpenAIApi:
        return OpenAIApi(config.OPENAI_API or "https://api.openai.com", config.OPENAI_TOKEN)

    def __init__(self, api_url: str, token: str):
        self.api_url = api_url
        self.token = token
    async def get_embeddings(self, doc_list: list):
        """Embed each doc's "text" field in place and return (doc_list, token_usage)."""
        token_usage = 0

        async with aiohttp.ClientSession() as session:
            text_list = [doc["text"] for doc in doc_list]
            params = {
                "model": "text-embedding-ada-002",
                "input": text_list,
            }
            async with session.post(self.api_url + "/v1/embeddings",
                                    headers={"Authorization": f"Bearer {self.token}"},
                                    json=params,
                                    timeout=30,
                                    proxy=config.REQUEST_PROXY) as resp:
                data = await resp.json()

                # The response's "index" field maps each embedding back to
                # its position in doc_list.
                for one_data in data["data"]:
                    embedding = one_data["embedding"]
                    index = one_data["index"]

                    if index < len(doc_list):
                        if embedding is not None:
                            embedding = np.array(embedding)
                            doc_list[index]["embedding"] = embedding

                token_usage = int(data["usage"]["total_tokens"])

        return (doc_list, token_usage)
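    # Usage sketch (illustrative, not part of the original file):
    #
    #   api = OpenAIApi.create()
    #   docs = [{"text": "Hello world"}]
    #   docs, used = await api.get_embeddings(docs)
    #   docs[0]["embedding"]  # np.ndarray; ada-002 vectors have 1536 dims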
    async def make_message_list(self, question: str, system_prompt: str,
                                conversation: list[ChatCompleteMessageLog] | None = None) -> list[ChatCompleteMessageLog]:
        """Build the message list: system prompt (plus any stored summary),
        prior user/assistant turns, then the new question."""
        conversation = conversation or []
        summaryContent = None

        messageList: list[ChatCompleteMessageLog] = []
        for message in conversation:
            if message["role"] == "summary":
                summaryContent = message["content"]
            elif message["role"] == "user" or message["role"] == "assistant":
                messageList.append(message)

        if summaryContent:
            system_prompt += "\n\n" + summaryContent

        # The prompt is an instruction to the model, so it belongs in the "system" role.
        messageList.insert(0, ChatCompleteMessageLog(role="system", content=system_prompt))
        messageList.append(ChatCompleteMessageLog(role="user", content=question))

        return messageList
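    # For question="Hi", system_prompt="Be terse", and a conversation holding
    # one summary entry, the result is shaped like (illustrative):
    #
    #   [{"role": "system", "content": "Be terse\n\n<summary>"},
    #    ...prior user/assistant turns...,
    #    {"role": "user", "content": "Hi"}]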
    async def chat_complete(self, question: str, system_prompt: str,
                            conversation: list[ChatCompleteMessageLog] | None = None, user=None):
        """Run a non-streaming chat completion; returns a ChatCompleteResponse,
        or None if the API returned no choices."""
        messageList = await self.make_message_list(question, system_prompt, conversation)

        params = {
            "model": "gpt-3.5-turbo",
            "messages": messageList,
            "user": user,
        }
        # Drop None values so optional fields are omitted from the request.
        params = {k: v for k, v in params.items() if v is not None}

        async with aiohttp.ClientSession() as session:
            async with session.post(self.api_url + "/v1/chat/completions",
                                    headers={"Authorization": f"Bearer {self.token}"},
                                    json=params,
                                    timeout=30,
                                    proxy=config.REQUEST_PROXY) as resp:
                data = await resp.json()

                if "choices" in data and len(data["choices"]) > 0:
                    choice = data["choices"][0]

                    message = choice["message"]["content"]
                    finish_reason = choice["finish_reason"]

                    prompt_tokens = int(data["usage"]["prompt_tokens"])
                    message_tokens = int(data["usage"]["completion_tokens"])
                    total_tokens = int(data["usage"]["total_tokens"])

                    return ChatCompleteResponse(message=message,
                                                prompt_tokens=prompt_tokens,
                                                message_tokens=message_tokens,
                                                total_tokens=total_tokens,
                                                finish_reason=finish_reason)

        return None
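    # Usage sketch (illustrative):
    #
    #   api = OpenAIApi.create()
    #   res = await api.chat_complete("What is SSE?", "You are concise.")
    #   if res is not None:
    #       print(res["message"], res["total_tokens"])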
    async def chat_complete_stream(self, question: str, system_prompt: str,
                                   conversation: list[ChatCompleteMessageLog] | None = None,
                                   on_message=None, user=None):
        """Stream a chat completion over SSE, awaiting on_message(delta) for each
        content chunk; token counts are estimated locally with tiktoken because
        the streaming API does not report usage."""
        tiktoken = await TikTokenService.create()

        messageList = await self.make_message_list(question, system_prompt, conversation)

        prompt_tokens = 0
        for message in messageList:
            prompt_tokens += await tiktoken.get_tokens(message["content"])

        params = {
            "model": "gpt-3.5-turbo",
            "messages": messageList,
            "stream": True,
            "user": user,
        }
        # Drop None values so optional fields are omitted from the request.
        params = {k: v for k, v in params.items() if v is not None}

        res_message: list[str] = []
        finish_reason = None
        content_started = False  # set once the first non-blank delta arrives

        # Each SSE event data payload looks like:
        #   {"id":"something","object":"chat.completion.chunk","created":1681261845,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}
        #   {"id":"something","object":"chat.completion.chunk","created":1681261845,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"Thank"},"index":0,"finish_reason":null}]}
        #   {"id":"something","object":"chat.completion.chunk","created":1681261845,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{},"index":0,"finish_reason":"stop"}]}
        #   [DONE]
        async with sse_client.EventSource(
            self.api_url + "/v1/chat/completions",
            option={
                "method": "POST"
            },
            headers={"Authorization": f"Bearer {self.token}"},
            json=params,
            proxy=config.REQUEST_PROXY
        ) as session:
            async for event in session:
                if event.data == "[DONE]":
                    break
                elif event.data.startswith("{") and event.data.endswith("}"):
                    data = json.loads(event.data)
                    if "choices" in data and len(data["choices"]) > 0:
                        choice = data["choices"][0]

                        if choice["finish_reason"] is not None:
                            finish_reason = choice["finish_reason"]

                        delta_content = choice["delta"]
                        if "content" in delta_content:
                            delta_message: str = delta_content["content"]

                            # Skip empty lines before the content starts.
                            if not content_started:
                                if delta_message.replace("\n", "") == "":
                                    continue
                                else:
                                    content_started = True

                            res_message.append(delta_message)

                            if config.DEBUG:
                                print(delta_message, end="", flush=True)

                            if on_message is not None:
                                await on_message(delta_message)

        res_message_str = "".join(res_message)
        message_tokens = await tiktoken.get_tokens(res_message_str)
        total_tokens = prompt_tokens + message_tokens

        return ChatCompleteResponse(message=res_message_str,
                                    prompt_tokens=prompt_tokens,
                                    message_tokens=message_tokens,
                                    total_tokens=total_tokens,
                                    finish_reason=finish_reason)
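
# Streaming usage sketch (illustrative; `print_delta` is a hypothetical callback):
#
#   async def print_delta(chunk: str):
#       print(chunk, end="", flush=True)
#
#   res = await OpenAIApi.create().chat_complete_stream(
#       "Tell me a joke.", "You are a comedian.", on_message=print_delta)
#   print("\n", res["finish_reason"])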