增加noawait,支持Azure API

master
落雨楓 2 years ago
parent e21a28a85f
commit 2f68357c1d

@ -0,0 +1,4 @@
#!/bin/sh
# Enter the directory that contains this script and activate the project's
# Python virtual environment located in ./.venv.

# Quote "$0" and the result so paths containing spaces are handled correctly.
DIRNAME=$(dirname "$0")

# The venv "activate" file only exports variables into the shell that reads
# it, so it has to be sourced with "." -- running it as a child process (as
# the previous version did) leaves the environment untouched.
# Only attempt activation when the directory change succeeded, otherwise the
# relative path would point at the wrong place.
cd "$DIRNAME" && . ./.venv/bin/activate

@ -2,80 +2,29 @@ import asyncio
import json
import time
import traceback
from local import noawait
from typing import Optional
from aiohttp import WSMsgType, web
from sqlalchemy import select
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool
from service.chat_complete import ChatCompleteService
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from service.tiktoken import TikTokenService
import utils.web
class ChatCompleteTaskList:
def __init__(self, dbs: DatabaseService):
self.on_message = None
self.chunks: list[str] = []
class ChatCompleteWebSocketController:
def __init__(self, request: web.Request):
self.request = request
self.ws = None
self.db = None
self.chat_complete = None
self.closed = False
self.refreshed_time = 0
async def run(self):
self.ws = web.WebSocketResponse()
await self.ws.prepare(self.request)
self.refreshed_time = time.time()
self.db = await DatabaseService.create(self.request.app)
self.query = self.request.query
if self.request.get("caller") == "user":
user_id = self.request.get("user")
else:
user_id = self.query.get("user_id")
title = self.query.get("title")
# create heartbeat task
asyncio.ensure_future(self._timeout_task())
async for msg in self.ws:
if msg.type == WSMsgType.TEXT:
try:
data = json.loads(msg.data)
event = data.get('event')
self.refreshed_time = time.time()
if event == 'chatcomplete':
asyncio.ensure_future(self._chatcomplete(data))
if event == 'ping':
await self.ws.send_json({
'event': 'pong'
})
except Exception as e:
print(e)
traceback.print_exc()
await self.ws.send_json({
'event': 'error',
'error': str(e)
})
elif msg.type == WSMsgType.ERROR:
print('ws connection closed with exception %s' %
self.ws.exception())
async def _timeout_task(self):
while not self.closed:
if time.time() - self.refreshed_time > 30:
self.closed = True
await self.ws.close()
return
await asyncio.sleep(1)
async def _chatcomplete(self, params: dict):
question = params.get("question")
conversation_id = params.get("conversation_id")
async def run():
pass
@noawait.wrap
async def start(self):
await self.run()
class ChatComplete:
@staticmethod
@ -244,8 +193,11 @@ class ChatComplete:
tokens = await tiktoken.get_tokens(question)
transatcion_id = None
point_cost = 0
if request.get("caller") == "user":
transatcion_id = await mwapi.chat_complete_start_transaction(user_id, "chatcomplete", tokens, extra_limit)
usage_res = await mwapi.chat_complete_start_transaction(user_id, "chatcomplete", tokens, extra_limit)
transatcion_id = usage_res.get("transaction_id")
point_cost = usage_res.get("point_cost")
async def on_message(text: str):
# Send message to client, start with "+" to indicate it's a message
@ -261,8 +213,7 @@ class ChatComplete:
try:
chat_res = await chat_complete_service \
.chat_complete(question, on_message, on_extracted_doc,
conversation_id=conversation_id, user_id=user_id, embedding_search={
.prepare_chat_complete(question, conversation_id=conversation_id, user_id=user_id, embedding_search={
"limit": extra_limit,
"in_collection": in_collection,
})
@ -272,6 +223,8 @@ class ChatComplete:
**chat_res,
})
await chat_complete_service.set_latest_point_cost(point_cost)
if transatcion_id:
result = await mwapi.chat_complete_end_transaction(transatcion_id, chat_res["total_tokens"])
except Exception as e:

@ -31,7 +31,8 @@ class EmbeddingSearch:
if await embedding_search.should_update_page_index():
if request.get("caller") == "user":
user_id = request.get("user")
transatcion_id = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage")
usage_res = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage")
transatcion_id = usage_res.get("transaction_id")
await embedding_search.prepare_update_index()
@ -107,7 +108,8 @@ class EmbeddingSearch:
if await embedding_search.should_update_page_index():
if request.get("caller") == "user":
user_id = request.get("user")
transatcion_id = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage")
usage_res = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage")
transatcion_id = usage_res.get("transaction_id")
await embedding_search.prepare_update_index()

@ -13,7 +13,7 @@ class ConversationChunkModel(BaseModel):
__tablename__ = "chat_complete_conversation_chunk"
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
conversation_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("chat_complete_conversation.id"), index=True)
conversation_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(ConversationModel.id), index=True)
message_data: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True)
tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0)
updated_at: Mapped[int] = mapped_column(sqlalchemy.TIMESTAMP, index=True)

@ -1,12 +1,13 @@
from __future__ import annotations
import hashlib
from typing import Optional
from typing import Optional, Type
import asyncpg
from api.model.base import BaseModel
import config
import numpy as np
import sqlalchemy
from sqlalchemy import select, update, delete
from sqlalchemy import Index, select, update, delete, Select
from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession
from pgvector.asyncpg import register_vector
@ -14,19 +15,38 @@ from pgvector.sqlalchemy import Vector
from service.database import DatabaseService
class PageIndexModel(BaseModel):
page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}
class AbstractPageIndexModel(BaseModel):
__abstract__ = True
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
text: Mapped[str] = mapped_column(sqlalchemy.Text)
text_len: Mapped[int] = mapped_column(sqlalchemy.Integer)
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
markdown: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=True)
markdown_len: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
temp_doc_session_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
def create_page_index_model(collection_id: int):
    """Return the ORM model class bound to one collection's page-index table.

    Model classes are cached in ``page_index_model_list`` keyed by
    ``collection_id``; a class is built on first request and reused
    afterwards.
    """
    cached_model = page_index_model_list.get(collection_id)
    if cached_model is not None:
        return cached_model

    table_name = "embedding_search_page_index_%s" % str(collection_id)

    class PageIndexModel(AbstractPageIndexModel):
        __tablename__ = table_name

        # ivfflat index over the embedding column using cosine distance ops.
        embedding_index = sqlalchemy.Index(
            table_name + "_embedding_idx",
            AbstractPageIndexModel.embedding,
            postgresql_using='ivfflat',
            postgresql_ops={'embedding': 'vector_cosine_ops'},
        )

    page_index_model_list[collection_id] = PageIndexModel
    return PageIndexModel
class PageIndexHelper:
columns = [
"id",
@ -60,12 +80,19 @@ class PageIndexHelper:
await register_vector(self.dbi)
self.create_session = self.dbs.create_session
self.session = self.dbs.create_session()
await self.session.__aenter__()
self.orm = create_page_index_model(self.collection_id)
self.initialized = True
return self
async def __aexit__(self, exc_type, exc, tb):
await self.dbpool.__aexit__(exc_type, exc, tb)
await self.session.__aexit__(exc_type, exc, tb)
async def table_exists(self):
exists = await self.dbi.fetchval("""SELECT EXISTS (
@ -83,27 +110,31 @@ class PageIndexHelper:
# create table if not exists
if not await self.table_exists():
await self.dbi.execute(("""CREATE TABLE IF NOT EXISTS /*_*/ (
id SERIAL PRIMARY KEY,
page_id INTEGER NOT NULL,
sha1 VARCHAR(40) NOT NULL,
text TEXT NOT NULL,
text_len INTEGER NOT NULL,
embedding VECTOR(%d) NOT NULL,
markdown TEXT NULL,
markdown_len INTEGER NULL,
temp_doc_session_id INTEGER NULL
);
CREATE INDEX /*_*/_page_id_idx ON /*_*/ (page_id);
CREATE INDEX /*_*/_sha1_idx ON /*_*/ (sha1);
CREATE INDEX /*_*/_temp_doc_session_id_idx ON /*_*/ (temp_doc_session_id);
""" % config.EMBEDDING_VECTOR_SIZE).replace("/*_*/", self.table_name))
self.table_initialized = False
async with self.dbs.engine.begin() as conn:
await conn.run_sync(self.orm.__table__.create)
# await self.dbi.execute(("""CREATE TABLE IF NOT EXISTS /*_*/ (
# id SERIAL PRIMARY KEY,
# page_id INTEGER NOT NULL,
# sha1 VARCHAR(40) NOT NULL,
# text TEXT NOT NULL,
# text_len INTEGER NOT NULL,
# embedding VECTOR(%d) NOT NULL,
# markdown TEXT NULL,
# markdown_len INTEGER NULL,
# temp_doc_session_id INTEGER NULL
# );
# CREATE INDEX /*_*/_page_id_idx ON /*_*/ (page_id);
# CREATE INDEX /*_*/_sha1_idx ON /*_*/ (sha1);
# CREATE INDEX /*_*/_temp_doc_session_id_idx ON /*_*/ (temp_doc_session_id);
# """ % config.EMBEDDING_VECTOR_SIZE).replace("/*_*/", self.table_name))
self.table_initialized = True
async def create_embedding_index(self):
await self.dbi.execute("CREATE INDEX IF NOT EXISTS /*_*/_embedding_idx ON /*_*/ USING ivfflat (embedding vector_cosine_ops);"
.replace("/*_*/", self.table_name))
pass
# await self.dbi.execute("CREATE INDEX IF NOT EXISTS /*_*/_embedding_idx ON /*_*/ USING ivfflat (embedding vector_cosine_ops);"
# .replace("/*_*/", self.table_name))
def sha1_doc(self, doc: list):
for item in doc:
@ -113,25 +144,20 @@ class PageIndexHelper:
async def get_indexed_sha1(self, with_temporary: bool = True, in_collection: bool = False):
indexed_sha1_list = []
sql = "SELECT sha1 FROM %s" % (self.table_name)
where = []
params = []
stmt = select(self.orm).column(self.orm.sha1)
if not with_temporary:
where.append("temp_doc_session_id IS NULL")
stmt = stmt.where(self.orm.temp_doc_session_id == None)
if not in_collection:
params.append(self.page_id)
where.append("page_id = $%d" % len(params))
stmt = stmt.where(self.orm.page_id == self.page_id)
if len(where) > 0:
sql += " WHERE " + (" AND ".join(where))
ret = await self.dbi.fetch(sql, *params)
ret: list[AbstractPageIndexModel] = await self.session.scalars(stmt)
for row in ret:
indexed_sha1_list.append(row[0])
indexed_sha1_list.append(row.sha1)
return indexed_sha1_list
async def get_unindexed_doc(self, doc: list, with_temporary: bool = True):

@ -11,7 +11,7 @@ class TitleCollectionModel(BaseModel):
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True, nullable=True)
class TitleCollectionHelper:
def __init__(self, dbs: DatabaseService):
@ -29,7 +29,6 @@ class TitleCollectionHelper:
async def __aexit__(self, exc_type, exc, tb):
await self.session.__aexit__(exc_type, exc, tb)
pass
async def add(self, title: str, page_id: Optional[int] = None) -> Union[int, bool]:
stmt = select(TitleCollectionModel.id).where(TitleCollectionModel.title == title)

@ -20,7 +20,11 @@ class TitleIndexModel(BaseModel):
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE), index=True)
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding,
postgresql_using='ivfflat',
postgresql_ops={'embedding': 'vector_cosine_ops'})
class TitleIndexHelper:
__tablename__ = "embedding_search_title_index"

@ -14,8 +14,15 @@ DATABASE = {
EMBEDDING_VECTOR_SIZE = 1536
OPENAI_API_TYPE = "openai" # openai or azure
OPENAI_API = "https://api.openai.com"
OPENAI_TOKEN = "sk-"
OPENAI_API = None
OPENAI_TOKEN = ""
AZURE_OPENAI_ENDPOINT = "https://your-instance.openai.azure.com"
AZURE_OPENAI_KEY = ""
AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME = ""
AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME = ""
CHATCOMPLETE_MAX_MEMORY_TOKENS = 1024
CHATCOMPLETE_MAX_INPUT_TOKENS = 768

@ -0,0 +1,5 @@
import asyncio
from noawait import NoAwaitPool
# Event loop created once at import time; other modules import it from here
# instead of asking asyncio for the current loop.
loop = asyncio.new_event_loop()

# Pool for fire-and-forget coroutines, scheduled on the loop above.
noawait = NoAwaitPool(loop=loop)

@ -1,13 +1,18 @@
import asyncio
from typing import TypedDict
from local import loop, noawait
from aiohttp import web
import asyncpg
import config
import api.route
import utils.web
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
# Auto create Table
from api.model.base import BaseModel
from api.model.toolkit_ui.conversation import ConversationModel as _
from api.model.chat_complete.conversation import ConversationChunkModel as _
from api.model.embedding_search.title_collection import TitleCollectionModel as _
from api.model.embedding_search.title_index import TitleIndexModel as _
from service.tiktoken import TikTokenService
@ -17,7 +22,10 @@ async def index(request: web.Request):
async def init_mw_api(app: web.Application):
mw_api = MediaWikiApi.create()
if config.MW_BOT_LOGIN_USERNAME and config.MW_BOT_LOGIN_PASSWORD:
try:
await mw_api.robot_login(config.MW_BOT_LOGIN_USERNAME, config.MW_BOT_LOGIN_PASSWORD)
except Exception as e:
print("Cannot login to Robot account, please check config.")
site_meta = await mw_api.get_site_meta()
@ -27,13 +35,17 @@ async def init_database(app: web.Application):
dbs = await DatabaseService.create(app)
print("Database connected.")
async with dbs.engine.begin() as conn:
await conn.run_sync(BaseModel.metadata.create_all)
async def init_tiktoken(app: web.Application):
    """Startup hook: create the TikTokenService and report that it is ready.

    ``app`` is accepted because aiohttp passes the application to startup
    callbacks; it is not used here.
    """
    await TikTokenService.create()
    print("Tiktoken model loaded.")
if __name__ == '__main__':
loop = asyncio.get_event_loop()
async def stop_noawait_pool(app: web.Application):
    """Shutdown hook: stop the shared NoAwaitPool and wait for it to finish.

    ``app`` is accepted because aiohttp passes the application to shutdown
    callbacks; it is not used here.
    """
    await noawait.end()
if __name__ == '__main__':
app = web.Application()
if config.DATABASE:
@ -45,7 +57,9 @@ if __name__ == '__main__':
if config.OPENAI_TOKEN:
app.on_startup.append(init_tiktoken)
app.on_shutdown.append(stop_noawait_pool)
app.router.add_route('*', '/', index)
api.route.init(app)
web.run_app(app, host='0.0.0.0', port=config.PORT, loop=loop)
web.run_app(app, host='0.0.0.0', port=config.PORT, loop=loop)

@ -0,0 +1,72 @@
from __future__ import annotations
from asyncio import AbstractEventLoop, Task
import asyncio
from functools import wraps
import sys
import traceback
from typing import Callable, Coroutine
class NoAwaitPool:
    """Pool of fire-and-forget asyncio tasks.

    Coroutines handed to :meth:`add_task` (or produced by functions decorated
    with :meth:`wrap`) are scheduled on ``loop`` without the caller awaiting
    them. A background task periodically collects finished tasks and routes
    any exception they raised to the callables in ``on_error``. Call
    :meth:`end` on shutdown to cancel what is still running.
    """

    def __init__(self, loop: AbstractEventLoop):
        """Create the pool and schedule its collector task on ``loop``."""
        self.task_list: list[Task] = []
        self.loop = loop
        self.running = True
        # Callables invoked with the exception raised by a pooled task.
        # They may be plain functions or coroutine functions.
        self.on_error: list[Callable] = []
        self.gc_task = loop.create_task(self._run_gc())

    async def end(self):
        """Stop the collector, cancel unfinished tasks and wait for all of them."""
        print("Stopping NoAwait Tasks...")
        self.running = False
        # Detach the tasks from the shared list before awaiting anything so
        # the collector (which may still be mid-pass) never processes the
        # same task a second time. Loop in case tasks are added meanwhile.
        while self.task_list:
            remaining = list(self.task_list)
            self.task_list.clear()
            for task in remaining:
                await self._finish_task(task)
        await self.gc_task

    def add_task(self, coroutine: Coroutine):
        """Schedule ``coroutine`` on the pool's loop and track the task."""
        task = self.loop.create_task(coroutine)
        self.task_list.append(task)

    def wrap(self, f):
        """Decorator: calling the wrapped coroutine function schedules it
        on the pool instead of returning a coroutine to await."""
        @wraps(f)
        def decorated_function(*args, **kwargs):
            coroutine = f(*args, **kwargs)
            self.add_task(coroutine)
        return decorated_function

    async def _finish_task(self, task: Task):
        """Wait for ``task`` (cancelling it first if still running) and
        dispatch any exception it raised to the ``on_error`` handlers."""
        try:
            if not task.done():
                task.cancel()
            await task
        except asyncio.CancelledError:
            # CancelledError is not an ``Exception`` subclass on Python 3.8+,
            # so it must be handled explicitly. A cancelled pooled task is an
            # expected outcome; anything else means the caller itself is
            # being cancelled and the error has to propagate.
            if not task.cancelled():
                raise
        except Exception as e:
            handled = False
            for handler in self.on_error:
                try:
                    handler_ret = handler(e)
                    # Await async handlers. (Comparing the result against the
                    # ``typing.Coroutine`` class with ``is`` never matches.)
                    if asyncio.iscoroutine(handler_ret) or asyncio.isfuture(handler_ret):
                        await handler_ret
                    handled = True
                except Exception as handler_err:
                    print("Exception on error handler: " + str(handler_err), file=sys.stderr)
                    traceback.print_exc()

            if not handled:
                print(e, file=sys.stderr)
                traceback.print_exc()

    async def _run_gc(self):
        """Collector loop: every 0.1 s drop finished tasks from the pool and
        report their exceptions."""
        while self.running:
            finished = []
            pending = []
            for task in self.task_list:
                (finished if task.done() else pending).append(task)

            if finished:
                # Update the list in place (single pass, same list object)
                # before awaiting, so concurrent readers see a consistent view.
                self.task_list[:] = pending
                for task in finished:
                    await self._finish_task(task)

            await asyncio.sleep(0.1)

@ -14,3 +14,4 @@ sqlalchemy==2.0.9
aiohttp-sse-client2==0.3.0
OpenCC==1.1.6
event-emitter-asyncio==1.0.4
tiktoken-async==0.3.2

@ -19,6 +19,11 @@ from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService
# Result returned by ChatCompleteService.prepare_chat_complete.
#   extract_doc:     documents returned by the embedding search, or None
#                    when no search was requested.
#   question_tokens: token count of the question, plus the embedding-search
#                    token usage when documents were extracted.
class ChatCompleteServicePrepareResponse(TypedDict):
extract_doc: list
question_tokens: int
class ChatCompleteServiceResponse(TypedDict):
message: str
message_tokens: int
@ -44,9 +49,18 @@ class ChatCompleteService:
self.tiktoken: TikTokenService = None
self.extract_doc: list = None
self.mwapi = MediaWikiApi.create()
self.openai_api = OpenAIApi.create()
self.user_id = 0
self.question = ""
self.question_tokens: Optional[int] = None
self.conversation_id: Optional[int] = None
self.delta_data = {}
async def __aenter__(self):
self.tiktoken = await TikTokenService.create()
@ -67,26 +81,55 @@ class ChatCompleteService:
async def get_question_tokens(self, question: str):
return await self.tiktoken.get_tokens(question)
async def chat_complete(self, question: str, on_message: Optional[callable] = None, on_extracted_doc: Optional[callable] = None,
conversation_id: Optional[str] = None, user_id: Optional[int] = None, question_tokens: Optional[int] = None,
async def prepare_chat_complete(self, question: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None,
question_tokens: Optional[int] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServiceResponse:
if user_id is not None:
user_id = int(user_id)
self.user_id = user_id
self.question = question
self.conversation_info = None
if conversation_id is not None:
conversation_id = int(conversation_id)
self.conversation_info = await self.conversation_helper.get_conversation(conversation_id)
self.conversation_id = int(conversation_id)
self.conversation_info = await self.conversation_helper.find_by_id(self.conversation_id)
else:
self.conversation_id = None
if self.conversation_info is not None:
if self.conversation_info.user_id != user_id:
raise web.HTTPUnauthorized()
if question_tokens is None:
self.question_tokens = await self.get_question_tokens(question)
else:
self.question_tokens = question_tokens
if (len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS and
self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS):
# If the question is too long, we need to truncate it
raise web.HTTPRequestEntityTooLarge()
# Extract document from wiki page index
self.extract_doc = None
if embedding_search is not None:
self.extract_doc, token_usage = await self.embedding_search.search(question, **embedding_search)
if self.extract_doc is not None:
self.question_tokens += token_usage
return ChatCompleteServicePrepareResponse(
extract_doc=self.extract_doc,
question_tokens=self.question_tokens
)
async def finish_chat_complete(self, on_message: Optional[callable] = None) -> ChatCompleteServiceResponse:
delta_data = {}
self.conversation_chunk = None
message_log = []
if self.conversation_info is not None:
if self.conversation_info.user_id != user_id:
raise web.HTTPUnauthorized()
self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(conversation_id)
self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(self.conversation_id)
# If the conversation is too long, we need to make a summary
if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS:
@ -95,9 +138,9 @@ class ChatCompleteService:
{"role": "summary", "content": summary, "tokens": tokens}
]
self.conversation_chunk = await self.conversation_chunk_helper.add(conversation_id, new_message_log, tokens)
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_id, new_message_log, tokens)
delta_data["conversation_chunk_id"] = self.conversation_chunk.id
self.delta_data["conversation_chunk_id"] = self.conversation_chunk.id
message_log = []
for message in self.conversation_chunk.message_data:
@ -106,23 +149,9 @@ class ChatCompleteService:
"content": message["content"],
})
if question_tokens is None:
question_tokens = await self.get_question_tokens(question)
if (len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS and
question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS):
# If the question is too long, we need to truncate it
raise web.HTTPRequestEntityTooLarge()
extract_doc = None
if embedding_search is not None:
extract_doc, token_usage = await self.embedding_search.search(question, **embedding_search)
if extract_doc is not None:
if on_extracted_doc is not None:
await on_extracted_doc(extract_doc)
question_tokens = token_usage
if self.extract_doc is not None:
doc_prompt_content = "\n".join(["%d. %s" % (
i + 1, doc["markdown"] or doc["text"]) for i, doc in enumerate(extract_doc)])
i + 1, doc["markdown"] or doc["text"]) for i, doc in enumerate(self.extract_doc)])
doc_prompt = utils.config.get_prompt("extracted_doc", "prompt", {
"content": doc_prompt_content})
@ -132,14 +161,14 @@ class ChatCompleteService:
# Start chat complete
if on_message is not None:
response = await self.openai_api.chat_complete_stream(question, system_prompt, message_log, on_message)
response = await self.openai_api.chat_complete_stream(self.question, system_prompt, message_log, on_message)
else:
response = await self.openai_api.chat_complete(question, system_prompt, message_log)
response = await self.openai_api.chat_complete(self.question, system_prompt, message_log)
if self.conversation_info is None:
# Create a new conversation
message_log_list = [
{"role": "user", "content": question, "tokens": question_tokens},
{"role": "user", "content": self.question, "tokens": self.question_tokens},
{"role": "assistant",
"content": response["message"], "tokens": response["message_tokens"]},
]
@ -152,21 +181,21 @@ class ChatCompleteService:
print(str(e), file=sys.stderr)
traceback.print_exc(file=sys.stderr)
total_token_usage = question_tokens + response["message_tokens"]
total_token_usage = self.question_tokens + response["message_tokens"]
title_info = self.embedding_search.title_info
self.conversation_info = await self.conversation_helper.add(user_id, "chatcomplete", page_id=title_info["page_id"], rev_id=title_info["rev_id"], title=title)
self.conversation_info = await self.conversation_helper.add(self.user_id, "chatcomplete", page_id=title_info["page_id"], rev_id=title_info["rev_id"], title=title)
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_info.id, message_log_list, total_token_usage)
else:
# Update the conversation chunk
await self.conversation_helper.refresh_updated_at(conversation_id)
await self.conversation_helper.refresh_updated_at(self.conversation_id)
self.conversation_chunk.message_data.append(
{"role": "user", "content": question, "tokens": question_tokens})
{"role": "user", "content": self.question, "tokens": self.question_tokens})
self.conversation_chunk.message_data.append(
{"role": "assistant", "content": response["message"], "tokens": response["message_tokens"]})
flag_modified(self.conversation_chunk, "message_data")
self.conversation_chunk.tokens += question_tokens + \
self.conversation_chunk.tokens += self.question_tokens + \
response["message_tokens"]
await self.conversation_chunk_helper.update(self.conversation_chunk)

@ -1,5 +1,5 @@
from __future__ import annotations
import asyncio
import local
from urllib.parse import quote_plus
from aiohttp import web
import asyncpg
@ -38,7 +38,7 @@ class DatabaseService:
self.create_session: async_sessionmaker[AsyncSession] = None
async def init(self):
loop = asyncio.get_event_loop()
loop = local.loop
self.pool = asyncpg.create_pool(**config.DATABASE, loop=loop)
await self.pool.__aenter__()

@ -129,10 +129,24 @@ class EmbeddingSearchService:
if self.unindexed_docs is None:
return False
chunk_limit = 500
chunk_len = 0
doc_chunk = []
total_token_usage = 0
processed_len = 0
async def on_embedding_progress(current, length):
nonlocal processed_len
indexed_docs = processed_len + current
if on_progress is not None:
await on_progress(indexed_docs, len(self.unindexed_docs))
async def embedding_doc(doc_chunk):
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk)
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk, on_embedding_progress)
await self.page_index.index_doc(doc_chunk)
return token_usage
@ -141,11 +155,6 @@ class EmbeddingSearchService:
if on_progress is not None:
await on_progress(0, len(self.unindexed_docs))
chunk_limit = 500
chunk_len = 0
processed_len = 0
doc_chunk = []
for doc in self.unindexed_docs:
chunk_len += len(doc)

@ -1,7 +1,7 @@
import json
import sys
import time
from typing import Optional
from typing import Optional, TypedDict
import aiohttp
import config
@ -18,6 +18,10 @@ class MediaWikiApiException(Exception):
class MediaWikiPageNotFoundException(MediaWikiApiException):
pass
# Result of MediaWikiApi.chat_complete_start_transaction, built from the
# "reportusage" section of the API response.
#   point_cost:     the "pointcost" value, converted to int.
#   transaction_id: the "transactionid" value.
class ChatCompleteReportUsageResponse(TypedDict):
point_cost: int
transaction_id: str
class MediaWikiApi:
cookie_jar = aiohttp.CookieJar(unsafe=True)
@ -27,7 +31,7 @@ class MediaWikiApi:
def __init__(self, api_url: str):
self.api_url = api_url
self.login_time = 0
self.login_time = 0.0
self.login_identity = None
async def get_page_info(self, title: str):
@ -142,7 +146,7 @@ class MediaWikiApi:
async def refresh_login(self):
if self.login_identity is None:
return False
if time.time() - self.login_time > 10:
if time.time() - self.login_time > 30:
return await self.robot_login(self.login_identity["username"], self.login_identity["password"])
async def chat_complete_user_info(self, user_id: int):
@ -166,7 +170,7 @@ class MediaWikiApi:
return data["chatcompletebot"]["userinfo"]
async def chat_complete_start_transaction(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> str:
async def chat_complete_start_transaction(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteReportUsageResponse:
await self.refresh_login()
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
@ -189,7 +193,8 @@ class MediaWikiApi:
print(data)
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
return data["chatcompletebot"]["reportusage"]["transactionid"]
return ChatCompleteReportUsageResponse(point_cost=int(data["chatcompletebot"]["reportusage"]["pointcost"]),
transaction_id=data["chatcompletebot"]["reportusage"]["transactionid"])
async def chat_complete_end_transaction(self, transaction_id: str, tokens: Optional[int] = None):
await self.refresh_login()

@ -1,6 +1,6 @@
from __future__ import annotations
import json
from typing import TypedDict
from typing import Callable, Optional, TypedDict
import aiohttp
import config
@ -23,23 +23,100 @@ class ChatCompleteResponse(TypedDict):
class OpenAIApi:
@staticmethod
def create():
return OpenAIApi(config.OPENAI_API or "https://api.openai.com", config.OPENAI_TOKEN)
return OpenAIApi()
def __init__(self, api_url: str, token: str):
self.api_url = api_url
self.token = token
def __init__(self):
if config.OPENAI_API_TYPE == "azure":
self.api_url = config.AZURE_OPENAI_ENDPOINT
self.api_key = config.AZURE_OPENAI_KEY
else:
self.api_url = config.OPENAI_API or "https://api.openai.com"
self.api_key = config.OPENAI_TOKEN
def build_header(self):
    """Return the HTTP request headers for the configured API flavour.

    Azure authenticates with an ``api-key`` header, every other API type
    with a bearer token in ``Authorization``.
    """
    json_headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }

    if config.OPENAI_API_TYPE != "azure":
        return {"Authorization": f"Bearer {self.api_key}", **json_headers}

    return {**json_headers, "api-key": self.api_key}
def get_url(self, method: str):
    """Build the endpoint URL for an API method.

    Args:
        method: ``"completions"`` (or ``"chat/completions"``) for chat
            completion requests, ``"embeddings"`` for embedding requests.
            For non-Azure API types any other value is appended to the
            ``/v1/`` prefix unchanged.

    Returns:
        The full URL. For Azure the URL addresses the deployment
        configured for the given method.

    Raises:
        ValueError: if the Azure API type is configured and no deployment
            is known for ``method``.
    """
    # Chat requests ask for "completions" but send a chat ``messages``
    # payload, which both providers serve under ".../chat/completions".
    path = "chat/completions" if method == "completions" else method

    if config.OPENAI_API_TYPE == "azure":
        if path == "chat/completions":
            deployment = config.AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME
        elif path == "embeddings":
            deployment = config.AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME
        else:
            # Fail loudly instead of implicitly returning None, which would
            # only surface later as an obscure error in the HTTP client.
            raise ValueError("Unsupported Azure OpenAI method: %s" % method)
        return self.api_url + "/openai/deployments/" + deployment + "/" + path

    return self.api_url + "/v1/" + path
async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
text_list = []
regex = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]"
for doc in doc_list:
text: str = doc["text"]
text = text.replace("\r\n", "\n").replace("\r", "\n")
if "\n" in text:
lines = text.split("\n")
new_lines = []
for line in lines:
line = line.strip()
# Add a dot at the end of the line if it doesn't end with a punctuation mark
if regex.find(line[-1]) == -1:
line += "."
new_lines.append(line)
text = " ".join(new_lines)
text_list.append(text)
async def get_embeddings(self, doc_list: list):
token_usage = 0
async with aiohttp.ClientSession() as session:
text_list = [doc["text"] for doc in doc_list]
params = {
"model": "text-embedding-ada-002",
url = self.get_url("embeddings")
params = {}
post_data = {
"input": text_list,
}
async with session.post(self.api_url + "/v1/embeddings",
headers={"Authorization": f"Bearer {self.token}"},
json=params,
if config.OPENAI_API_TYPE == "azure":
params["api-version"] = "2023-05-15"
else:
post_data["model"] = "text-embedding-ada-002"
if config.OPENAI_API_TYPE == "azure":
# Azure api does not support batch
for index, text in enumerate(text_list):
async with session.post(url,
headers=self.build_header(),
params=params,
json={"input": text},
timeout=30,
proxy=config.REQUEST_PROXY) as resp:
data = await resp.json()
one_data = data["data"]
if len(one_data) > 0:
embedding = one_data[0]["embedding"]
if embedding is not None:
embedding = np.array(embedding)
doc_list[index]["embedding"] = embedding
token_usage += int(data["usage"]["total_tokens"])
if on_index_progress is not None:
await on_index_progress(index, len(text_list))
else:
async with session.post(url,
headers=self.build_header(),
params=params,
json=post_data,
timeout=30,
proxy=config.REQUEST_PROXY) as resp:
@ -56,6 +133,8 @@ class OpenAIApi:
token_usage = int(data["usage"]["total_tokens"])
await on_index_progress(index, len(text_list))
return (doc_list, token_usage)
async def make_message_list(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = []) -> list[ChatCompleteMessageLog]:
@ -79,17 +158,26 @@ class OpenAIApi:
async def chat_complete(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], user = None):
messageList = await self.make_message_list(question, system_prompt, conversation)
params = {
"model": "gpt-3.5-turbo",
url = self.get_url("completions")
params = {}
post_data = {
"messages": messageList,
"user": user,
}
params = {k: v for k, v in params.items() if v is not None}
if config.OPENAI_API_TYPE == "azure":
params["api-version"] = "2023-05-15"
else:
post_data["model"] = "gpt-3.5-turbo"
post_data = {k: v for k, v in post_data.items() if v is not None}
async with aiohttp.ClientSession() as session:
async with session.post(self.api_url + "/v1/chat/completions",
headers={"Authorization": f"Bearer {self.token}"},
json=params,
async with session.post(url,
headers=self.build_header,
params=params,
json=post_data,
timeout=30,
proxy=config.REQUEST_PROXY) as resp:
@ -138,7 +226,7 @@ class OpenAIApi:
option={
"method": "POST"
},
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": f"Bearer {self.api_key}"},
json=params,
proxy=config.REQUEST_PROXY
) as session:

@ -1,6 +1,4 @@
import asyncio
import config
import asyncpg
import local
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService
@ -9,16 +7,20 @@ async def main():
dbs = await DatabaseService.create()
async with EmbeddingSearchService(dbs, "代号:曙光的世界/黄昏的阿瓦隆") as embedding_search:
await embedding_search.prepare_update_index()
async def on_index_progress(current, length):
print("索引进度:%.1f%%" % (current / length * 100))
print("\r索引进度:%.1f%%" % (current / length * 100), end="", flush=True)
print("")
await embedding_search.update_page_index(on_index_progress)
print("")
while True:
query = input("请输入要搜索的问题 (.exit 退出)")
if query == ".exit":
break
res = await embedding_search.search(query, 5)
res, token_usage = await embedding_search.search(query, 5)
total_length = 0
if res:
for one in res:
@ -29,5 +31,7 @@ async def main():
print("总长度:%d" % total_length)
await local.noawait.end()
if __name__ == '__main__':
asyncio.run(main())
local.loop.run_until_complete(main())
Loading…
Cancel
Save