You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

328 lines
13 KiB
Python

from __future__ import annotations
import asyncio
import json
from typing import Callable, Optional, TypedDict
import aiohttp
from libs.config import Config
import numpy as np
from aiohttp_sse_client2 import client as sse_client
from service.tiktoken import TikTokenService
AZURE_CHATCOMPLETE_API_VERSION = "2023-07-01-preview"
AZURE_EMBEDDING_API_VERSION = "2023-05-15"
class OpenAIApiTypeInvalidException(Exception):
    """Raised when the configuration has no ``api_type`` for the given api_id."""

    def __init__(self, api_id: str):
        # Populate Exception.args so repr(), logging, and pickling carry the id.
        super().__init__(api_id)
        self.api_id = api_id

    def __str__(self):
        return f"Invalid api_id: {self.api_id}"
class ChatCompleteMessageLog(TypedDict):
    """One message in a conversation log."""
    # Speaker role: "user" or "assistant"; a synthetic "summary" role is also
    # stored in conversation history and folded into the system prompt.
    role: str
    # Message text.
    content: str
class ChatCompleteResponse(TypedDict):
    """Result of a chat completion request (plain or streamed)."""
    # Full assistant reply text.
    message: str
    # Tokens consumed by the prompt side.
    prompt_tokens: int
    # Tokens in the generated reply.
    message_tokens: int
    # Total tokens (prompt + completion).
    total_tokens: int
    # Finish reason reported by the API, e.g. "stop".
    finish_reason: str
class OpenAIApi:
    """Async client for OpenAI-compatible chat-completion and embedding APIs.

    The backend is selected by the config key ``chatcomplete.<api_id>.api_type``:
    ``"azure"`` uses deployment-based URLs and an ``api-key`` header; any other
    value is treated as an OpenAI-style ``/v1/`` endpoint with a Bearer token.
    """

    @staticmethod
    def create(api_id: str) -> OpenAIApi:
        """Factory method; equivalent to calling the constructor."""
        return OpenAIApi(api_id)

    def __init__(self, api_id: str):
        """Load per-endpoint configuration.

        Raises:
            OpenAIApiTypeInvalidException: if ``chatcomplete.<api_id>.api_type``
                is missing from the configuration.
        """
        self.api_id = api_id
        self.api_type = Config.get(f"chatcomplete.{api_id}.api_type", None, str)
        if self.api_type is None:
            raise OpenAIApiTypeInvalidException(api_id)

        self.request_proxy = Config.get(f"chatcomplete.{api_id}.request_proxy", type=str, empty_is_none=True)
        self.api_url = Config.get(f"chatcomplete.{api_id}.api_endpoint", type=str)
        self.api_key = Config.get(f"chatcomplete.{api_id}.key", type=str)

    def build_header(self) -> dict:
        """Return the HTTP headers appropriate for the configured backend."""
        if self.api_type == "azure":
            # Azure authenticates with a dedicated "api-key" header.
            return {
                "content-type": "application/json",
                "accept": "application/json",
                "api-key": self.api_key
            }
        else:
            return {
                "authorization": f"Bearer {self.api_key}",
                "content-type": "application/json",
                "accept": "application/json",
            }

    def get_url(self, method: str) -> str:
        """Build the full request URL for ``method``.

        ``method`` is "chat/completions" or "embeddings". Azure routes through
        per-deployment paths; other backends use ``/v1/<method>``.

        Raises:
            Exception: if the required Azure deployment name is not configured.
        """
        if self.api_type == "azure":
            if method == "chat/completions":
                deployment = Config.get(f"chatcomplete.{self.api_id}.deployment_chatcomplete")
                if deployment is None:
                    raise Exception("deployment for chatcomplete is not set")
                return self.api_url + "/openai/deployments/" + deployment + "/" + method
            elif method == "embeddings":
                deployment = Config.get(f"chatcomplete.{self.api_id}.deployment_embedding")
                if deployment is None:
                    raise Exception("deployment for embedding is not set")
                return self.api_url + "/openai/deployments/" + deployment + "/" + method
            # NOTE(review): other methods fall through and return None on
            # Azure, matching historical behavior — confirm no caller relies
            # on additional methods here.
        else:
            return self.api_url + "/v1/" + method

    async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
        """Fetch embeddings for every doc in ``doc_list``.

        Each doc is a dict with at least a "text" key; the doc is mutated in
        place, gaining an "embedding" key (numpy array, or None if the API
        returned none). Azure is queried one text at a time (its API does not
        support batching), with up to 3 attempts per text; other backends get
        a single batch request.

        Args:
            doc_list: list of dicts each containing at least "text".
            on_index_progress: optional async callback ``(index, total)``
                awaited after each processed item.

        Returns:
            ``(doc_list, token_usage)`` — the same list object and the total
            token count reported by the API.
        """
        token_usage = 0

        text_list = [doc["text"] for doc in doc_list]

        async with aiohttp.ClientSession() as session:
            url = self.get_url("embeddings")
            params = {}
            post_data = {
                "input": text_list,
            }

            if self.api_type == "azure":
                params["api-version"] = AZURE_EMBEDDING_API_VERSION
            else:
                post_data["model"] = "text-embedding-ada-002"

            if self.api_type == "azure":
                # Azure api does not support batch
                for index, text in enumerate(text_list):
                    retry_num = 0
                    max_retry_num = 3
                    while retry_num < max_retry_num:
                        try:
                            async with session.post(url,
                                                    headers=self.build_header(),
                                                    params=params,
                                                    json={"input": text},
                                                    timeout=30,
                                                    proxy=self.request_proxy) as resp:
                                data = await resp.json()

                                one_data = data["data"]
                                if len(one_data) > 0:
                                    embedding = one_data[0]["embedding"]
                                    if embedding is not None:
                                        embedding = np.array(embedding)
                                    doc_list[index]["embedding"] = embedding
                                token_usage += int(data["usage"]["total_tokens"])

                                if on_index_progress is not None:
                                    await on_index_progress(index, len(text_list))
                                break
                        except Exception as e:
                            retry_num += 1
                            # Re-raise only after exhausting all attempts.
                            if retry_num >= max_retry_num:
                                raise e
                            print("Error: %s" % e)
                            print("Retrying...")
                            await asyncio.sleep(0.5)
            else:
                async with session.post(url,
                                        headers=self.build_header(),
                                        params=params,
                                        json=post_data,
                                        timeout=30,
                                        proxy=self.request_proxy) as resp:
                    data = await resp.json()

                    if "error" in data:
                        raise Exception(data["error"])

                    for one_data in data["data"]:
                        embedding = one_data["embedding"]
                        index = one_data["index"]

                        if index < len(doc_list):
                            if embedding is not None:
                                embedding = np.array(embedding)
                            doc_list[index]["embedding"] = embedding

                        # Usage is reported once for the whole batch.
                        token_usage = int(data["usage"]["total_tokens"])

                        if on_index_progress is not None:
                            await on_index_progress(index, len(text_list))

        return (doc_list, token_usage)

    async def make_message_list(self, question: str, system_prompt: str,
                                conversation: Optional[list[ChatCompleteMessageLog]] = None) -> list[ChatCompleteMessageLog]:
        """Assemble the message list sent to the chat API.

        Past "summary" entries are folded into the system prompt; "user" and
        "assistant" entries are kept verbatim; ``question`` is appended last.

        Note: the default for ``conversation`` was a mutable ``[]``; using a
        None sentinel avoids the shared-mutable-default pitfall while keeping
        the same behavior for all callers.
        """
        if conversation is None:
            conversation = []

        summaryContent = None
        messageList: list[ChatCompleteMessageLog] = []
        for message in conversation:
            if message["role"] == "summary":
                summaryContent = message["content"]
            elif message["role"] == "user" or message["role"] == "assistant":
                messageList.append(message)

        if summaryContent:
            system_prompt += "\n\n" + summaryContent

        # NOTE(review): the system prompt is sent with role "assistant", not
        # "system" — looks deliberate, but confirm against the target models.
        messageList.insert(0, ChatCompleteMessageLog(role="assistant", content=system_prompt))
        messageList.append(ChatCompleteMessageLog(role="user", content=question))

        return messageList

    async def chat_complete(self, question: str, system_prompt: str, model: str,
                            conversation: Optional[list[ChatCompleteMessageLog]] = None, user=None):
        """Run a non-streaming chat completion and return a ChatCompleteResponse.

        Raises:
            Exception: if the API response contains no choices.
        """
        messageList = await self.make_message_list(question, system_prompt, conversation)

        url = self.get_url("chat/completions")

        params = {}
        post_data = {
            "messages": messageList,
            "user": user,
        }

        if self.api_type == "azure":
            params["api-version"] = AZURE_CHATCOMPLETE_API_VERSION
        else:
            post_data["model"] = model

        # Drop None-valued fields (e.g. "user") from the request body.
        post_data = {k: v for k, v in post_data.items() if v is not None}

        async with aiohttp.ClientSession() as session:
            async with session.post(url,
                                    headers=self.build_header(),
                                    params=params,
                                    json=post_data,
                                    timeout=30,
                                    proxy=self.request_proxy) as resp:
                data = await resp.json()

                if "choices" in data and len(data["choices"]) > 0:
                    choice = data["choices"][0]

                    message = choice["message"]["content"]
                    finish_reason = choice["finish_reason"]

                    prompt_tokens = int(data["usage"]["prompt_tokens"])
                    message_tokens = int(data["usage"]["completion_tokens"])
                    total_tokens = int(data["usage"]["total_tokens"])

                    return ChatCompleteResponse(message=message,
                                                prompt_tokens=prompt_tokens,
                                                message_tokens=message_tokens,
                                                total_tokens=total_tokens,
                                                finish_reason=finish_reason)
                else:
                    print(data)
                    raise Exception("Invalid response from chat complete api")

    async def chat_complete_stream(self, question: str, system_prompt: str, model: str,
                                   conversation: Optional[list[ChatCompleteMessageLog]] = None,
                                   on_message=None, user=None):
        """Run a streaming (SSE) chat completion.

        Deltas are accumulated; ``on_message(delta)`` is awaited for each
        content chunk. Prompt tokens are counted locally with tiktoken since
        streamed responses carry no usage block. Generation parameters
        (max_tokens=768, temperature=1, top_p=0.95) are currently hard-coded.

        Returns:
            ChatCompleteResponse with locally computed token counts.
        """
        tiktoken = await TikTokenService.create()

        messageList = await self.make_message_list(question, system_prompt, conversation)

        prompt_tokens = 0
        for message in messageList:
            prompt_tokens += await tiktoken.get_tokens(message["content"])

        url = self.get_url("chat/completions")

        params = {}
        post_data = {
            "messages": messageList,
            "user": user,
            "stream": True,
            "n": 1,
            "max_tokens": 768,
            "stop": None,
            "temperature": 1,
            "top_p": 0.95
        }

        if self.api_type == "azure":
            params["api-version"] = AZURE_CHATCOMPLETE_API_VERSION
        else:
            post_data["model"] = model

        # Drop None-valued fields (e.g. "user", "stop") from the request body.
        post_data = {k: v for k, v in post_data.items() if v is not None}

        res_message: list[str] = []
        finish_reason = None
        # BUGFIX: this flag was previously reset on every SSE event, which
        # made the "skip leading empty lines" guard drop newline-only deltas
        # in the middle of the message as well. It must persist across events.
        content_started = False

        async with sse_client.EventSource(
            url,
            option={
                "method": "POST"
            },
            headers=self.build_header(),
            params=params,
            json=post_data,
            proxy=self.request_proxy
        ) as session:
            async for event in session:
                """
                {"id":"something","object":"chat.completion.chunk","created":1681261845,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}
                {"id":"something","object":"chat.completion.chunk","created":1681261845,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"Thank"},"index":0,"finish_reason":null}]}
                {"id":"something","object":"chat.completion.chunk","created":1681261845,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{},"index":0,"finish_reason":"stop"}]}
                [DONE]
                """
                event_data = event.data.strip()

                if event_data == "[DONE]":
                    break
                # startswith/endswith also safely reject empty events, where
                # the old event_data[0] check raised IndexError.
                elif event_data.startswith("{") and event_data.endswith("}"):
                    data = json.loads(event_data)
                    if "choices" in data and len(data["choices"]) > 0:
                        choice = data["choices"][0]

                        if choice["finish_reason"] is not None:
                            finish_reason = choice["finish_reason"]

                        delta_content = choice["delta"]
                        if "content" in delta_content:
                            delta_message: str = delta_content["content"]

                            # Skip empty lines before content
                            if not content_started:
                                if delta_message.replace("\n", "") == "":
                                    continue
                                else:
                                    content_started = True

                            res_message.append(delta_message)

                            if on_message is not None:
                                await on_message(delta_message)

                if finish_reason is not None:
                    break

        res_message_str = "".join(res_message)
        message_tokens = await tiktoken.get_tokens(res_message_str)
        total_tokens = prompt_tokens + message_tokens

        return ChatCompleteResponse(message=res_message_str,
                                    prompt_tokens=prompt_tokens,
                                    message_tokens=message_tokens,
                                    total_tokens=total_tokens,
                                    finish_reason=finish_reason)