You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
328 lines
13 KiB
Python
328 lines
13 KiB
Python
from __future__ import annotations
|
|
import asyncio
|
|
import json
|
|
from typing import Callable, Optional, TypedDict
|
|
|
|
import aiohttp
|
|
from libs.config import Config
|
|
import numpy as np
|
|
from aiohttp_sse_client2 import client as sse_client
|
|
|
|
from service.tiktoken import TikTokenService
|
|
|
|
# "api-version" query-parameter values sent with requests when api_type is
# "azure" (the Azure OpenAI REST API requires an explicit version per call).
AZURE_CHATCOMPLETE_API_VERSION = "2023-07-01-preview"
AZURE_EMBEDDING_API_VERSION = "2023-05-15"
|
|
|
|
class OpenAIApiTypeInvalidException(Exception):
    """Raised when no ``api_type`` is configured for the requested api_id."""

    def __init__(self, api_id: str):
        # Keep the offending id around so callers can report/log it.
        self.api_id = api_id

    def __str__(self) -> str:
        return "Invalid api_id: %s" % self.api_id
|
|
|
|
# One message in a conversation log.  Roles observed elsewhere in this file:
# "user", "assistant", "system" and the pseudo-role "summary".
ChatCompleteMessageLog = TypedDict("ChatCompleteMessageLog", {
    "role": str,
    "content": str,
})
|
|
|
|
# Result of a chat completion: the generated text, token accounting, and the
# API's finish_reason (e.g. "stop").
ChatCompleteResponse = TypedDict("ChatCompleteResponse", {
    "message": str,
    "prompt_tokens": int,
    "message_tokens": int,
    "total_tokens": int,
    "finish_reason": str,
})
|
|
|
|
class OpenAIApi:
    """Thin async client for the OpenAI / Azure OpenAI REST APIs.

    Supports embeddings plus blocking and SSE-streamed chat completions.
    All settings (endpoint, key, proxy, Azure deployments) are read from the
    project ``Config`` under the ``chatcomplete.<api_id>`` namespace.
    """

    @staticmethod
    def create(api_id: str) -> OpenAIApi:
        """Factory alias for the constructor."""
        return OpenAIApi(api_id)

    def __init__(self, api_id: str):
        """Load settings for ``api_id`` from Config.

        Raises:
            OpenAIApiTypeInvalidException: no ``api_type`` is configured.
        """
        self.api_id = api_id

        # "azure" switches auth headers, URL layout and api-version params;
        # any other value is treated as the public OpenAI API.
        self.api_type = Config.get(f"chatcomplete.{api_id}.api_type", None, str)

        if self.api_type is None:
            raise OpenAIApiTypeInvalidException(api_id)

        self.request_proxy = Config.get(f"chatcomplete.{api_id}.request_proxy", type=str, empty_is_none=True)

        self.api_url = Config.get(f"chatcomplete.{api_id}.api_endpoint", type=str)
        self.api_key = Config.get(f"chatcomplete.{api_id}.key", type=str)

    def build_header(self) -> dict:
        """Return request headers; the auth scheme depends on api_type."""
        if self.api_type == "azure":
            # Azure authenticates with an "api-key" header.
            return {
                "content-type": "application/json",
                "accept": "application/json",
                "api-key": self.api_key
            }
        else:
            # OpenAI uses a Bearer token.
            return {
                "authorization": f"Bearer {self.api_key}",
                "content-type": "application/json",
                "accept": "application/json",
            }

    def get_url(self, method: str):
        """Build the request URL for ``method`` ("chat/completions" or "embeddings").

        For azure, requests are routed through the configured per-model
        deployment; for other api types the standard /v1/ prefix is used.
        Note: returns None for an unknown method on azure (original behavior).
        """
        if self.api_type == "azure":
            if method == "chat/completions":
                deployment = Config.get(f"chatcomplete.{self.api_id}.deployment_chatcomplete")
                if deployment is None:
                    raise Exception("deployment for chatcomplete is not set")

                return self.api_url + "/openai/deployments/" + deployment + "/" + method
            elif method == "embeddings":
                deployment = Config.get(f"chatcomplete.{self.api_id}.deployment_embedding")
                if deployment is None:
                    raise Exception("deployment for embedding is not set")

                return self.api_url + "/openai/deployments/" + deployment + "/" + method
        else:
            return self.api_url + "/v1/" + method

    async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
        """Fetch embeddings for each doc in ``doc_list`` (modified in place).

        Each doc is a dict with a "text" key; its "embedding" key is set to a
        numpy array.  ``on_index_progress`` must be an async callable; it is
        awaited as (index, total) after each batch/item.

        Returns:
            (doc_list, token_usage) where token_usage is the API-reported
            total token count.
        """
        token_usage = 0

        text_list = [doc["text"] for doc in doc_list]

        async with aiohttp.ClientSession() as session:
            url = self.get_url("embeddings")
            params = {}
            post_data = {
                "input": text_list,
            }

            if self.api_type == "azure":
                params["api-version"] = AZURE_EMBEDDING_API_VERSION
            else:
                post_data["model"] = "text-embedding-ada-002"

            if self.api_type == "azure":
                # Azure api does not support batch: request one text at a
                # time, retrying each item up to max_retry_num times.
                for index, text in enumerate(text_list):
                    retry_num = 0
                    max_retry_num = 3
                    while retry_num < max_retry_num:
                        try:
                            async with session.post(url,
                                                    headers=self.build_header(),
                                                    params=params,
                                                    json={"input": text},
                                                    timeout=30,
                                                    proxy=self.request_proxy) as resp:

                                data = await resp.json()

                                # Fix: surface API-level errors explicitly
                                # instead of failing on a KeyError below.
                                if "error" in data:
                                    raise Exception(data["error"])

                                one_data = data["data"]
                                if len(one_data) > 0:
                                    embedding = one_data[0]["embedding"]
                                    if embedding is not None:
                                        doc_list[index]["embedding"] = np.array(embedding)

                                token_usage += int(data["usage"]["total_tokens"])

                                if on_index_progress is not None:
                                    await on_index_progress(index, len(text_list))

                                break
                        except Exception as e:
                            retry_num += 1

                            if retry_num >= max_retry_num:
                                raise e

                            print("Error: %s" % e)
                            print("Retrying...")
                            await asyncio.sleep(0.5)
            else:
                # Non-azure: single batched request for all texts.
                async with session.post(url,
                                        headers=self.build_header(),
                                        params=params,
                                        json=post_data,
                                        timeout=30,
                                        proxy=self.request_proxy) as resp:

                    data = await resp.json()

                    if "error" in data:
                        raise Exception(data["error"])

                    for one_data in data["data"]:
                        embedding = one_data["embedding"]
                        index = one_data["index"]

                        if index < len(doc_list):
                            if embedding is not None:
                                doc_list[index]["embedding"] = np.array(embedding)

                        token_usage = int(data["usage"]["total_tokens"])

                        if on_index_progress is not None:
                            await on_index_progress(index, len(text_list))

        return (doc_list, token_usage)

    async def make_message_list(self, question: str, system_prompt: str, conversation: Optional[list[ChatCompleteMessageLog]] = None) -> list[ChatCompleteMessageLog]:
        """Assemble the message list sent to the chat API.

        Conversation entries with the pseudo-role "summary" are folded into
        the system prompt; only "user"/"assistant" entries are kept.  The
        (possibly extended) system prompt is prepended and the new question
        appended.  Fix: ``conversation`` no longer uses a mutable default.
        """
        summaryContent = None

        messageList: list[ChatCompleteMessageLog] = []
        for message in conversation or []:
            if message["role"] == "summary":
                summaryContent = message["content"]
            elif message["role"] == "user" or message["role"] == "assistant":
                messageList.append(message)

        if summaryContent:
            system_prompt += "\n\n" + summaryContent

        # Fix: the system prompt must be sent with role "system"; it was
        # previously sent with role "assistant", which misrepresents it to
        # the chat completions API.
        messageList.insert(0, ChatCompleteMessageLog(role="system", content=system_prompt))
        messageList.append(ChatCompleteMessageLog(role="user", content=question))

        return messageList

    async def chat_complete(self, question: str, system_prompt: str, model: str, conversation: Optional[list[ChatCompleteMessageLog]] = None, user = None):
        """Run a blocking (non-streamed) chat completion.

        ``model`` is ignored for azure (the deployment picks the model).
        Returns a ChatCompleteResponse; raises when the reply has no choices.
        """
        messageList = await self.make_message_list(question, system_prompt, conversation)

        url = self.get_url("chat/completions")

        params = {}
        post_data = {
            "messages": messageList,
            "user": user,
        }

        if self.api_type == "azure":
            params["api-version"] = AZURE_CHATCOMPLETE_API_VERSION
        else:
            post_data["model"] = model

        # Drop unset (None-valued) fields from the payload.
        post_data = {k: v for k, v in post_data.items() if v is not None}

        async with aiohttp.ClientSession() as session:
            async with session.post(url,
                                    headers=self.build_header(),
                                    params=params,
                                    json=post_data,
                                    timeout=30,
                                    proxy=self.request_proxy) as resp:

                data = await resp.json()

                if "choices" in data and len(data["choices"]) > 0:
                    choice = data["choices"][0]

                    message = choice["message"]["content"]
                    finish_reason = choice["finish_reason"]

                    prompt_tokens = int(data["usage"]["prompt_tokens"])
                    message_tokens = int(data["usage"]["completion_tokens"])
                    total_tokens = int(data["usage"]["total_tokens"])

                    return ChatCompleteResponse(message=message,
                                                prompt_tokens=prompt_tokens,
                                                message_tokens=message_tokens,
                                                total_tokens=total_tokens,
                                                finish_reason=finish_reason)
                else:
                    print(data)
                    raise Exception("Invalid response from chat complete api")

        return None

    async def chat_complete_stream(self, question: str, system_prompt: str, model: str, conversation: Optional[list[ChatCompleteMessageLog]] = None, on_message = None, user = None):
        """Run a streamed (SSE) chat completion.

        ``on_message`` must be an async callable; it is awaited with each
        content delta.  Token counts are estimated locally with tiktoken
        because the streaming API does not report usage.
        """
        tiktoken = await TikTokenService.create()

        messageList = await self.make_message_list(question, system_prompt, conversation)

        prompt_tokens = 0
        for message in messageList:
            prompt_tokens += await tiktoken.get_tokens(message["content"])

        url = self.get_url("chat/completions")

        params = {}
        post_data = {
            "messages": messageList,
            "user": user,
            "stream": True,
            "n": 1,
            "max_tokens": 768,
            "stop": None,
            "temperature": 1,
            "top_p": 0.95
        }

        if self.api_type == "azure":
            params["api-version"] = AZURE_CHATCOMPLETE_API_VERSION
        else:
            post_data["model"] = model

        # Drop unset (None-valued) fields from the payload.
        post_data = {k: v for k, v in post_data.items() if v is not None}

        res_message: list[str] = []
        finish_reason = None

        # Fix: this flag must persist across events.  It was previously reset
        # inside the loop, which dropped *every* whitespace-only delta in the
        # stream instead of only the leading ones.
        content_started = False

        async with sse_client.EventSource(
            url,
            option={
                "method": "POST"
            },
            headers=self.build_header(),
            params=params,
            json=post_data,
            proxy=self.request_proxy
        ) as session:
            async for event in session:
                # Each event carries one JSON chunk, e.g.:
                # {"choices":[{"delta":{"content":"Thank"},"index":0,"finish_reason":null}]}
                # and the stream terminates with the literal "[DONE]".
                event_data = event.data.strip()

                if event_data == "[DONE]":
                    break
                # Fix: guard against empty event payloads before indexing.
                elif event_data and event_data[0] == "{" and event_data[-1] == "}":
                    data = json.loads(event_data)
                    if "choices" in data and len(data["choices"]) > 0:
                        choice = data["choices"][0]

                        if choice["finish_reason"] is not None:
                            finish_reason = choice["finish_reason"]

                        delta_content = choice["delta"]
                        if "content" in delta_content:
                            delta_message: str = delta_content["content"]

                            # Skip empty lines before content
                            if not content_started:
                                if delta_message.replace("\n", "") == "":
                                    continue
                                else:
                                    content_started = True

                            res_message.append(delta_message)

                            if on_message is not None:
                                await on_message(delta_message)

                if finish_reason is not None:
                    break

        res_message_str = "".join(res_message)
        message_tokens = await tiktoken.get_tokens(res_message_str)
        total_tokens = prompt_tokens + message_tokens

        return ChatCompleteResponse(message=res_message_str,
                                    prompt_tokens=prompt_tokens,
                                    message_tokens=message_tokens,
                                    total_tokens=total_tokens,
                                    finish_reason=finish_reason)