You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

279 lines
11 KiB
Python

from __future__ import annotations

import json
from typing import Awaitable, Callable, Optional, TypedDict

import aiohttp
import numpy as np
from aiohttp_sse_client2 import client as sse_client

import config
from service.tiktoken import TikTokenService
class ChatCompleteMessageLog(TypedDict):
    """One entry of a chat conversation, as passed to the chat endpoints."""

    # Speaker of the message. Code in this file checks for "summary", "user"
    # and "assistant".
    role: str
    # Text of the message.
    content: str
class ChatCompleteResponse(TypedDict):
    """Result of a chat completion, together with its token accounting."""

    # Text generated by the model.
    message: str
    # Tokens consumed by the prompt (all messages sent to the model).
    prompt_tokens: int
    # Tokens in the generated message.
    message_tokens: int
    # Total tokens for the request (prompt plus generated message).
    total_tokens: int
    # Reason the model stopped generating. The streaming path leaves this as
    # None when the stream ends without reporting one.
    finish_reason: Optional[str]
class OpenAIApi:
    """Async client for the OpenAI / Azure OpenAI REST API.

    The backend is selected by ``config.OPENAI_API_TYPE``: the value
    ``"azure"`` selects the Azure endpoint, key and deployment names, any
    other value selects the OpenAI-style endpoint and bearer token.
    """

    # Characters that count as "this line already ends with punctuation" when
    # flattening multi-line text for embedding. Despite its bracketed shape it
    # is used as a plain set of characters, not compiled as a regular
    # expression.
    _LINE_END_PUNCTUATION = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]"

    # ``api-version`` query parameter sent with Azure requests.
    _AZURE_API_VERSION = "2023-05-15"

    @staticmethod
    def create():
        """Return a new client configured from ``config``."""
        return OpenAIApi()

    def __init__(self):
        """Read the endpoint and credentials for the configured backend."""
        if config.OPENAI_API_TYPE == "azure":
            self.api_url = config.AZURE_OPENAI_ENDPOINT
            self.api_key = config.AZURE_OPENAI_KEY
        else:
            self.api_url = config.OPENAI_API or "https://api.openai.com"
            self.api_key = config.OPENAI_TOKEN

    def build_header(self):
        """Return the HTTP headers (including credentials) for JSON requests."""
        if config.OPENAI_API_TYPE == "azure":
            return {
                "Content-Type": "application/json",
                "Accept": "application/json",
                "api-key": self.api_key,
            }
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    def get_url(self, method: str):
        """Return the full URL for an API method.

        Args:
            method: API path such as ``"embeddings"``, ``"completions"`` or
                ``"chat/completions"``.

        Raises:
            ValueError: On Azure, when no deployment is configured for
                ``method`` (previously this silently returned ``None``).
        """
        if config.OPENAI_API_TYPE != "azure":
            return self.api_url + "/v1/" + method

        if method in ("completions", "chat/completions"):
            deployment = config.AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME
        elif method == "embeddings":
            deployment = config.AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME
        else:
            raise ValueError(f"Unsupported Azure OpenAI method: {method!r}")
        return self.api_url + "/openai/deployments/" + deployment + "/" + method

    @staticmethod
    def _normalize_text(text: str) -> str:
        """Flatten multi-line text into a single line for embedding.

        Line endings are normalized to ``\\n``. If the text has several
        lines, each line is stripped, blank lines are dropped, a ``.`` is
        appended to every line that does not already end with a punctuation
        character, and the lines are joined with single spaces. Single-line
        text is returned unchanged apart from line-ending normalization.
        """
        text = text.replace("\r\n", "\n").replace("\r", "\n")
        if "\n" not in text:
            return text

        sentences = []
        for line in text.split("\n"):
            line = line.strip()
            if not line:
                # Blank lines carry no content; indexing line[-1] on them
                # used to raise IndexError.
                continue
            if line[-1] not in OpenAIApi._LINE_END_PUNCTUATION:
                line += "."
            sentences.append(line)
        return " ".join(sentences)

    async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], Awaitable[None]]] = None):
        """Compute embeddings for every document and store them in place.

        Args:
            doc_list: Documents as dicts with a ``"text"`` entry. Each
                document that gets a result receives an ``"embedding"``
                entry holding a NumPy array.
            on_index_progress: Optional callback awaited as
                ``on_index_progress(index, total)``. On Azure it is called
                after every document; otherwise once after the single batch
                request, with the last index found in the response.

        Returns:
            A ``(doc_list, token_usage)`` tuple, where ``token_usage`` is the
            total token count reported by the API.
        """
        if not doc_list:
            # Nothing to embed; avoid sending a request with empty input.
            return (doc_list, 0)

        text_list = [self._normalize_text(doc["text"]) for doc in doc_list]
        total = len(text_list)
        token_usage = 0

        async with aiohttp.ClientSession() as session:
            url = self.get_url("embeddings")

            if config.OPENAI_API_TYPE == "azure":
                params = {"api-version": self._AZURE_API_VERSION}
                # Azure api does not support batch, so send one request per text.
                for index, text in enumerate(text_list):
                    async with session.post(url,
                                            headers=self.build_header(),
                                            params=params,
                                            json={"input": text},
                                            timeout=30,
                                            proxy=config.REQUEST_PROXY) as resp:
                        data = await resp.json()

                    items = data["data"]
                    if len(items) > 0:
                        embedding = items[0]["embedding"]
                        if embedding is not None:
                            doc_list[index]["embedding"] = np.array(embedding)

                    token_usage += int(data["usage"]["total_tokens"])

                    if on_index_progress is not None:
                        await on_index_progress(index, total)
            else:
                post_data = {
                    "input": text_list,
                    "model": "text-embedding-ada-002",
                }
                async with session.post(url,
                                        headers=self.build_header(),
                                        params={},
                                        json=post_data,
                                        timeout=30,
                                        proxy=config.REQUEST_PROXY) as resp:
                    data = await resp.json()

                last_index = None
                for item in data["data"]:
                    embedding = item["embedding"]
                    last_index = item["index"]
                    if last_index < len(doc_list) and embedding is not None:
                        doc_list[last_index]["embedding"] = np.array(embedding)

                token_usage = int(data["usage"]["total_tokens"])

                # The callback is optional, and there is no index to report
                # when the response contained no items.
                if on_index_progress is not None and last_index is not None:
                    await on_index_progress(last_index, total)

        return (doc_list, token_usage)

    async def make_message_list(self, question: str, system_prompt: str, conversation: Optional[list[ChatCompleteMessageLog]] = None) -> list[ChatCompleteMessageLog]:
        """Build the message list sent to the chat endpoint.

        Messages with role ``"user"`` or ``"assistant"`` are kept in order.
        The content of a ``"summary"`` message (the last one, if several) is
        appended to the system prompt; other roles are dropped. The prompt is
        placed first and ``question`` last, as a ``"user"`` message.

        Args:
            question: The new user question.
            system_prompt: Instruction text placed at the start.
            conversation: Previous messages; ``None`` means no history.
        """
        summary_content = None
        message_list: list[ChatCompleteMessageLog] = []

        for message in conversation or []:
            role = message["role"]
            if role == "summary":
                summary_content = message["content"]
            elif role in ("user", "assistant"):
                message_list.append(message)

        if summary_content:
            system_prompt += "\n\n" + summary_content

        # NOTE(review): the system prompt is sent with role "assistant", not
        # "system". Kept as-is because it may be deliberate - confirm.
        message_list.insert(0, ChatCompleteMessageLog(role="assistant", content=system_prompt))
        message_list.append(ChatCompleteMessageLog(role="user", content=question))

        return message_list

    async def chat_complete(self, question: str, system_prompt: str, conversation: Optional[list[ChatCompleteMessageLog]] = None, user=None) -> Optional[ChatCompleteResponse]:
        """Request a chat completion and return it in one piece.

        Args:
            question: The new user question.
            system_prompt: Instruction text placed at the start.
            conversation: Previous messages; ``None`` means no history.
            user: Optional end-user identifier forwarded to the API; omitted
                from the request when ``None``.

        Returns:
            The first choice as a ``ChatCompleteResponse``, or ``None`` when
            the response contains no choices.
        """
        message_list = await self.make_message_list(question, system_prompt, conversation)

        # A "messages" payload belongs to the chat completions endpoint, not
        # the plain completions endpoint.
        url = self.get_url("chat/completions")

        params = {}
        post_data = {
            "messages": message_list,
            "user": user,
        }

        if config.OPENAI_API_TYPE == "azure":
            params["api-version"] = self._AZURE_API_VERSION
        else:
            post_data["model"] = "gpt-3.5-turbo"

        # Drop unset optional fields rather than sending nulls.
        post_data = {k: v for k, v in post_data.items() if v is not None}

        async with aiohttp.ClientSession() as session:
            async with session.post(url,
                                    headers=self.build_header(),
                                    params=params,
                                    json=post_data,
                                    timeout=30,
                                    proxy=config.REQUEST_PROXY) as resp:
                data = await resp.json()

        if "choices" in data and len(data["choices"]) > 0:
            choice = data["choices"][0]
            usage = data["usage"]
            return ChatCompleteResponse(message=choice["message"]["content"],
                                        prompt_tokens=int(usage["prompt_tokens"]),
                                        message_tokens=int(usage["completion_tokens"]),
                                        total_tokens=int(usage["total_tokens"]),
                                        finish_reason=choice["finish_reason"])

        return None

    async def chat_complete_stream(self, question: str, system_prompt: str, conversation: Optional[list[ChatCompleteMessageLog]] = None, on_message: Optional[Callable[[str], Awaitable[None]]] = None, user=None):
        """Request a chat completion and stream it back chunk by chunk.

        Token counts are computed locally with ``TikTokenService`` rather
        than read from the API response.

        Args:
            question: The new user question.
            system_prompt: Instruction text placed at the start.
            conversation: Previous messages; ``None`` means no history.
            on_message: Optional callback awaited with each text chunk.
            user: Optional end-user identifier forwarded to the API; omitted
                from the request when ``None``.

        Returns:
            A ``ChatCompleteResponse`` holding the concatenated message.
            ``finish_reason`` is ``None`` if the stream never reported one.
        """
        tiktoken = await TikTokenService.create()

        message_list = await self.make_message_list(question, system_prompt, conversation)

        prompt_tokens = 0
        for message in message_list:
            prompt_tokens += await tiktoken.get_tokens(message["content"])

        params = {
            "model": "gpt-3.5-turbo",
            "messages": message_list,
            "stream": True,
            "user": user,
        }
        # Drop unset optional fields rather than sending nulls.
        params = {k: v for k, v in params.items() if v is not None}

        res_message: list[str] = []
        finish_reason = None
        # Must live outside the event loop: resetting it on every event made
        # every newline-only chunk get dropped, not just the leading ones.
        content_started = False

        # NOTE(review): this path always uses the OpenAI-style URL and bearer
        # header, regardless of config.OPENAI_API_TYPE - confirm whether
        # Azure streaming is meant to be supported.
        async with sse_client.EventSource(
            self.api_url + "/v1/chat/completions",
            option={
                "method": "POST"
            },
            headers={"Authorization": f"Bearer {self.api_key}"},
            json=params,
            proxy=config.REQUEST_PROXY
        ) as session:
            async for event in session:
                # Example event payloads:
                # {"id":"something","object":"chat.completion.chunk","created":1681261845,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}
                # {"id":"something","object":"chat.completion.chunk","created":1681261845,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"Thank"},"index":0,"finish_reason":null}]}
                # {"id":"something","object":"chat.completion.chunk","created":1681261845,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{},"index":0,"finish_reason":"stop"}]}
                # [DONE]
                payload = event.data
                if payload == "[DONE]":
                    break
                # startswith/endswith are safe on an empty payload, unlike
                # indexing payload[0].
                if not (payload.startswith("{") and payload.endswith("}")):
                    continue

                data = json.loads(payload)
                if "choices" not in data or len(data["choices"]) == 0:
                    continue

                choice = data["choices"][0]

                if choice.get("finish_reason") is not None:
                    finish_reason = choice["finish_reason"]

                delta_content = choice.get("delta", {})
                if "content" not in delta_content:
                    continue

                delta_message: str = delta_content["content"]

                # Skip empty lines before content
                if not content_started:
                    if delta_message.replace("\n", "") == "":
                        continue
                    content_started = True

                res_message.append(delta_message)

                if config.DEBUG:
                    print(delta_message, end="", flush=True)

                if on_message is not None:
                    await on_message(delta_message)

        res_message_str = "".join(res_message)
        message_tokens = await tiktoken.get_tokens(res_message_str)
        total_tokens = prompt_tokens + message_tokens

        return ChatCompleteResponse(message=res_message_str,
                                    prompt_tokens=prompt_tokens,
                                    message_tokens=message_tokens,
                                    total_tokens=total_tokens,
                                    finish_reason=finish_reason)