更改流式输出模式

master
落雨楓 2 years ago
parent 6a74050b56
commit 68d1647d5d

@ -1,19 +1,19 @@
from __future__ import annotations
import asyncio
import json
import sys
import time
import traceback
from api.model.toolkit_ui.conversation import ConversationHelper
from local import noawait
from typing import Optional, Callable, TypedDict
from aiohttp import web
from sqlalchemy import select
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool
from service.chat_complete import ChatCompleteService
from service.chat_complete import ChatCompleteService, ChatCompleteServiceResponse
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
from service.tiktoken import TikTokenService
import utils.web
@ -23,6 +23,7 @@ class ChatCompleteTask:
def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
self.task_id = utils.web.generate_uuid()
self.on_message: list[Callable] = []
self.on_finished: list[Callable] = []
self.on_error: list[Callable] = []
self.chunks: list[str] = []
@ -37,7 +38,12 @@ class ChatCompleteTask:
self.transatcion_id: Optional[str] = None
self.point_cost: int = 0
async def init(self, question: str, conversation_id: Optional[str] = None,
self.is_finished = False
self.finished_time: Optional[float] = None
self.result: Optional[ChatCompleteServiceResponse] = None
self.error: Optional[Exception] = None
async def init(self, question: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None):
self.tiktoken = await TikTokenService.create()
@ -54,34 +60,58 @@ class ChatCompleteTask:
self.transatcion_id: Optional[str] = None
self.point_cost: int = 0
if not self.is_system:
usage_res = await self.mwapi.chat_complete_start_transaction(self.user_id, "chatcomplete", question_tokens, extract_limit)
self.transatcion_id = usage_res.get("transaction_id")
self.point_cost = usage_res.get("point_cost")
usage_res = await self.mwapi.chat_complete_start_transaction(self.user_id, "chatcomplete",
question_tokens, extract_limit)
self.transatcion_id = usage_res["transaction_id"]
self.point_cost = usage_res["point_cost"]
chat_res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
user_id=self.user_id, embedding_search=embedding_search)
res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
return chat_res
return res
else:
await self._exit()
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
async def _on_message(self, delta_message: str):
self.chunks.append(delta_message)
for callback in self.on_message:
try:
await callback(delta_message)
except Exception as e:
print("Error while processing on_message callback: %s" % e, file=sys.stderr)
traceback.print_exc()
async def _on_finished(self):
for callback in self.on_finished:
try:
await callback(self.result)
except Exception as e:
print("Error while processing on_finished callback: %s" % e, file=sys.stderr)
traceback.print_exc()
async def _on_error(self, err: Exception):
self.error = err
for callback in self.on_error:
try:
await callback(err)
except Exception as e:
print("Error while processing on_error callback: %s" % e, file=sys.stderr)
traceback.print_exc()
async def run(self):
try:
chat_res = self.chat_complete.finish_chat_complete(self._on_message)
chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
await self.chat_complete.set_latest_point_cost(self.point_cost)
self.result = chat_res
if self.transatcion_id:
result = await self.mwapi.chat_complete_end_transaction(self.transatcion_id, chat_res["total_tokens"])
await self.mwapi.chat_complete_end_transaction(self.transatcion_id, chat_res["total_tokens"])
await self._on_finished()
except Exception as e:
err_msg = f"Error while processing chat complete request: {e}"
@ -89,7 +119,7 @@ class ChatCompleteTask:
traceback.print_exc()
if self.transatcion_id:
result = await self.mwapi.chat_complete_cancel_transaction(self.transatcion_id, error=err_msg)
await self.mwapi.chat_complete_cancel_transaction(self.transatcion_id, error=err_msg)
await self._on_error(e)
finally:
@ -98,10 +128,19 @@ class ChatCompleteTask:
async def _exit(self):
await self.chat_complete_service.__aexit__(None, None, None)
del chat_complete_tasks[self.task_id]
self.is_finished = True
self.finished_time = time.time()
TASK_EXPIRE_TIME = 60 * 10
@noawait.wrap
async def start(self):
await self.run()
async def chat_complete_task_gc():
now = time.time()
for task_id in chat_complete_tasks.keys():
task = chat_complete_tasks[task_id]
if task.is_finished and task.finished_time is not None and now > task.finished_time + TASK_EXPIRE_TIME:
del chat_complete_tasks[task_id]
noawait.add_timer(chat_complete_task_gc, 60)
class ChatComplete:
@staticmethod
@ -112,7 +151,7 @@ class ChatComplete:
"required": False,
"type": int
},
"conversation_id": {
"id": {
"required": True,
"type": int
}
@ -123,40 +162,31 @@ class ChatComplete:
else:
user_id = params.get("user_id")
conversation_id = params.get("conversation_id")
conversation_id = params.get("id")
db = await DatabaseService.create(request.app)
async with db.create_session() as session:
stmt = select(ConversationModel).where(
ConversationModel.id == conversation_id)
conversation_data = await session.scalar(stmt)
async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
conversation_info = await conversation_helper.find_by_id(conversation_id)
if conversation_data is None:
if conversation_info is None:
return await utils.web.api_response(-1, error={
"code": "conversation-not-found",
"message": "Conversation not found."
}, http_status=404, request=request)
if conversation_data.user_id != user_id:
if conversation_info.user_id != user_id:
return await utils.web.api_response(-1, error={
"code": "permission-denied",
"message": "Permission denied."
}, http_status=403, request=request)
stmt = select(ConversationChunkModel).with_only_columns([ConversationChunkModel.id, ConversationChunkModel.updated_at]) \
.where(ConversationChunkModel.conversation_id == conversation_id).order_by(ConversationChunkModel.id.asc())
conversation_chunk_result = await session.scalars(stmt)
conversation_chunk_result = await conversation_chunk_helper.get_chunk_id_list(conversation_id)
conversation_chunk_list = []
for result in conversation_chunk_result:
conversation_chunk_list.append({
"id": result.id,
"updated_at": result.updated_at
})
conversation_chunk_list.append(result)
return await utils.web.api_response(1, conversation_chunk_list, request=request)
@ -181,26 +211,32 @@ class ChatComplete:
chunk_id = params.get("chunk_id")
dbs = await DatabaseService.create(request.app)
async with dbs.create_session() as session:
stmt = select(ConversationChunkModel).where(
ConversationChunkModel.id == chunk_id)
conversation_data = await session.scalar(stmt)
if conversation_data is None:
db = await DatabaseService.create(request.app)
async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
chunk_info = await conversation_chunk_helper.find_by_id(chunk_id)
if chunk_info is None:
return await utils.web.api_response(-1, error={
"code": "conversation-chunk-not-found",
"message": "Conversation chunk not found."
}, http_status=404, request=request)
if conversation_data.conversation.user_id != user_id:
conversation_info = await conversation_helper.find_by_id(chunk_info.conversation_id)
if conversation_info is not None and conversation_info.user_id != user_id:
return await utils.web.api_response(-1, error={
"code": "permission-denied",
"message": "Permission denied."
}, http_status=403, request=request)
return await utils.web.api_response(1, conversation_data.__dict__, request=request)
chunk_dict = {
"id": chunk_info.id,
"conversation_id": chunk_info.conversation_id,
"message_data": chunk_info.message_data,
"tokens": chunk_info.tokens,
"updated_at": chunk_info.updated_at,
}
return await utils.web.api_response(1, chunk_dict, request=request)
@staticmethod
@utils.web.token_auth
@ -219,6 +255,47 @@ class ChatComplete:
return await utils.web.api_response(1, {"tokens": tokens}, request=request)
@staticmethod
@utils.web.token_auth
async def get_point_cost(request: web.Request):
params = await utils.web.get_param(request, {
"question": {
"type": str,
"required": True,
},
"extract_limit": {
"type": int,
"required": False,
"default": 10,
},
})
user_id = request.get("user")
caller = request.get("caller")
question = params.get("question")
extract_limit = params.get("extract_limit")
tiktoken = await TikTokenService.create()
mwapi = MediaWikiApi.create()
tokens = await tiktoken.get_tokens(question)
try:
res = await mwapi.chat_complete_get_point_cost(user_id, "chatcomplete", tokens, extract_limit)
return await utils.web.api_response(1, {
"point_cost": res["point_cost"],
"tokens": tokens,
}, request=request)
except Exception as e:
err_msg = f"Error while get chat complete point cost: {e}"
traceback.print_exc()
return await utils.web.api_response(-1, error={
"code": "chat-complete-error",
"message": err_msg
}, http_status=500, request=request)
@staticmethod
@utils.web.token_auth
async def start_chat_complete(request: web.Request):
@ -245,6 +322,10 @@ class ChatComplete:
"required": False,
"default": False,
},
"edit_message_id": {
"type": str,
"required": False,
},
})
user_id = request.get("user")
@ -257,20 +338,25 @@ class ChatComplete:
extract_limit = params.get("extract_limit")
in_collection = params.get("in_collection")
edit_message_id = params.get("edit_message_id")
dbs = await DatabaseService.create(request.app)
try:
chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user")
init_res = await chat_complete_task.init(question, conversation_id=conversation_id, embedding_search={
init_res = await chat_complete_task.init(question, conversation_id=conversation_id, edit_message_id=edit_message_id,
embedding_search={
"limit": extract_limit,
"in_collection": in_collection,
})
chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task
chat_complete_task.start()
noawait.add_task(chat_complete_task.run())
return utils.web.api_response(1, data={
return await utils.web.api_response(1, data={
"conversation_id": init_res["conversation_id"],
"chunk_id": init_res["chunk_id"],
"question_tokens": init_res["question_tokens"],
"extract_doc": init_res["extract_doc"],
"task_id": chat_complete_task.task_id,
@ -282,6 +368,12 @@ class ChatComplete:
"title": page_title,
"message": error_msg
}, http_status=404, request=request)
except MediaWikiUserNoEnoughPointsException as e:
error_msg = "Does not have enough points." % user_id
return await utils.web.api_response(-1, error={
"code": "no-enough-points",
"message": error_msg
}, http_status=403, request=request)
except Exception as e:
err_msg = f"Error while processing chat complete request: {e}"
traceback.print_exc()
@ -294,4 +386,123 @@ class ChatComplete:
@staticmethod
@utils.web.token_auth
async def chat_complete_stream(request: web.Request):
pass
if not utils.web.is_websocket(request):
return await utils.web.api_response(-1, error={
"code": "websocket-required",
"message": "This API only accept websocket connection."
}, http_status=400, request=request)
params = await utils.web.get_param(request, {
"task_id": {
"type": str,
"required": True,
}
})
ws = web.WebSocketResponse()
await ws.prepare(request)
task_id = params.get("task_id")
task = chat_complete_tasks.get(task_id)
if task is None:
await ws.send_json({
'event': 'error',
'status': -1,
'message': "Task not found.",
'error': {
'code': "task-not-found",
'info': "Task not found.",
},
})
return
if request.get("caller") == "user":
user_id = request.get("user")
if task.user_id != user_id:
await ws.send_json({
'event': 'error',
'status': -1,
'message': "Permission denied.",
'error': {
'code': "permission-denied",
'info': "Permission denied.",
},
})
return
if task.is_finished:
if task.error is not None:
await ws.send_json({
'event': 'error',
'status': -1,
'message': str(task.error),
'error': {
'code': "internal-error",
'info': str(task.error),
},
})
await ws.close()
elif task.result is not None:
await ws.send_json({
'event': 'connected',
'status': 1,
'outputed_message': "".join(task.chunks),
})
await ws.send_json({
'event': 'finished',
'status': 1,
'result': task.result
})
await ws.close()
else:
async def on_message(delta_message: str):
await ws.send_str("+" + delta_message)
async def on_finished(result: ChatCompleteServiceResponse):
ignored_keys = ["message"]
response_result = {
"point_cost": task.point_cost,
}
for k, v in result.items():
if k not in ignored_keys:
response_result[k] = v
await ws.send_json({
'event': 'finished',
'status': 1,
'result': response_result
})
await ws.close()
async def on_error(err: Exception):
await ws.send_json({
'event': 'error',
'status': -1,
'message': str(err),
'error': {
'code': "internal-error",
'info': str(err),
},
})
await ws.close()
task.on_message.append(on_message)
task.on_finished.append(on_finished)
task.on_error.append(on_error)
# Send received message
await ws.send_json({
'event': 'connected',
'status': 1,
'outputed_message': "".join(task.chunks),
})
while True:
if ws.closed:
task.on_message.remove(on_message)
task.on_finished.remove(on_finished)
task.on_error.remove(on_error)
break
await asyncio.sleep(0.1)

@ -2,7 +2,7 @@ import sys
import traceback
from aiohttp import web
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService
from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
import utils.web
@ -87,6 +87,18 @@ class EmbeddingSearch:
})
if transatcion_id:
await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg)
except EmbeddingRunningException:
error_msg = "Page index is running now"
await ws.send_json({
'event': 'error',
'status': -4,
'message': error_msg,
'error': {
'code': 'page_index_running',
},
})
if transatcion_id:
await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg)
except Exception as e:
error_msg = str(e)
print(error_msg, file=sys.stderr)
@ -94,7 +106,10 @@ class EmbeddingSearch:
await ws.send_json({
'event': 'error',
'status': -1,
'message': error_msg
'message': error_msg,
'error': {
'code': 'internal_server_error',
}
})
if transatcion_id:
await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg)
@ -145,6 +160,16 @@ class EmbeddingSearch:
"info": e.info,
"message": error_msg
}, http_status=500)
except EmbeddingRunningException:
error_msg = "Page index is running now"
if transatcion_id:
await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg)
return await utils.web.api_response(-4, error={
"code": "page-index-running",
"message": error_msg
}, http_status=429)
except Exception as e:
error_msg = str(e)

@ -2,7 +2,6 @@ import sys
import time
import traceback
from aiohttp import web
from sqlalchemy import select
from api.model.toolkit_ui.conversation import ConversationHelper
from api.model.toolkit_ui.page_title import PageTitleHelper
from service.database import DatabaseService
@ -123,6 +122,7 @@ class Index:
"id": result.id,
"module": result.module,
"title": result.title,
"description": result.description,
"thumbnail": result.thumbnail,
"rev_id": result.rev_id,
"updated_at": result.updated_at,
@ -166,6 +166,7 @@ class Index:
"id": conversation_info.id,
"module": conversation_info.module,
"title": conversation_info.title,
"description": conversation_info.description,
"thumbnail": conversation_info.thumbnail,
"rev_id": conversation_info.rev_id,
"updated_at": conversation_info.updated_at,
@ -178,14 +179,68 @@ class Index:
@staticmethod
@utils.web.token_auth
async def remove_conversation(request: web.Request):
params = await utils.web.get_param(request, {
"id": {
"type": int
},
"ids": {
"type": str
}
})
conversation_id = params.get("id")
conversation_ids = params.get("ids")
if conversation_id is None and conversation_ids is None:
return await utils.web.api_response(-2, error={
"code": "invalid-params",
"message": "Invalid params."
}, request=request, http_status=400)
if conversation_id is not None:
conversation_ids = [conversation_id]
else:
conversation_ids = conversation_ids.split(",")
conversation_ids = [int(id) for id in conversation_ids]
db = await DatabaseService.create(request.app)
async with ConversationHelper(db) as conversation_helper:
user_id = None
if request.get("caller") == "user":
user_id = int(request.get("user"))
conversation_ids = await conversation_helper.filter_user_owned_ids(conversation_ids, user_id=user_id)
if len(conversation_ids) > 0:
await conversation_helper.remove_multiple(conversation_ids)
# 通知其他模块删除
events = EventService.create()
events.emit("conversation/removed", {
"ids": conversation_ids,
"dbs": db,
"app": request.app,
})
return await utils.web.api_response(1, data={
"count": len(conversation_ids)
}, request=request)
@staticmethod
@utils.web.token_auth
async def set_conversation_pinned(request: web.Request):
params = await utils.web.get_param(request, {
"id": {
"required": True,
"type": int
},
"pinned": {
"required": True,
"type": bool
}
})
conversation_id = params.get("id")
pinned = params.get("pinned")
db = await DatabaseService.create(request.app)
async with ConversationHelper(db) as conversation_helper:
@ -203,39 +258,27 @@ class Index:
"message": "Permission denied."
}, request=request, http_status=403)
await conversation_helper.remove(conversation_info)
# 通知其他模块删除
events = EventService.create()
events.emit("conversation/removed", {
"conversation": conversation_info,
"dbs": db,
"app": request.app,
})
events.emit("conversation/removed/" + conversation_info.module, {
"conversation": conversation_info,
"dbs": db,
"app": request.app,
})
conversation_info.pinned = pinned
await conversation_helper.update(conversation_info)
return await utils.web.api_response(1, request=request)
@staticmethod
@utils.web.token_auth
async def set_conversation_pinned(request: web.Request):
async def set_conversation_title(request: web.Request):
params = await utils.web.get_param(request, {
"id": {
"required": True,
"type": int
},
"pinned": {
"new_title": {
"required": True,
"type": bool
"type": str
}
})
conversation_id = params.get("id")
pinned = params.get("pinned")
new_title = params.get("new_title")
db = await DatabaseService.create(request.app)
async with ConversationHelper(db) as conversation_helper:
@ -253,7 +296,7 @@ class Index:
"message": "Permission denied."
}, request=request, http_status=403)
conversation_info.pinned = pinned
conversation_info.title = new_title
await conversation_helper.update(conversation_info)
return await utils.web.api_response(1, request=request)

@ -1,7 +1,8 @@
from __future__ import annotations
import time
import sqlalchemy
from sqlalchemy import update
from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, Mapped
from api.model.base import BaseModel
@ -13,10 +14,11 @@ class ConversationChunkModel(BaseModel):
__tablename__ = "chat_complete_conversation_chunk"
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
conversation_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(ConversationModel.id), index=True)
conversation_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey(ConversationModel.id, ondelete="CASCADE", onupdate="CASCADE"), index=True)
message_data: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True)
tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0)
updated_at: Mapped[int] = mapped_column(sqlalchemy.TIMESTAMP, index=True)
updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)
class ConversationChunkHelper:
def __init__(self, dbs: DatabaseService):
@ -36,52 +38,58 @@ class ConversationChunkHelper:
await self.session.__aexit__(exc_type, exc, tb)
pass
async def add(self, conversation_id: int, message_data: list, tokens: int):
async with self.create_session() as session:
chunk = ConversationChunkModel(
conversation_id=conversation_id,
message_data=message_data,
tokens=tokens,
updated_at=sqlalchemy.func.current_timestamp()
)
session.add(chunk)
await session.commit()
await session.refresh(chunk)
return chunk
async def update(self, chunk: ConversationChunkModel):
chunk.updated_at = sqlalchemy.func.current_timestamp()
chunk = await self.session.merge(chunk)
async def add(self, obj: ConversationChunkModel):
obj.updated_at = int(time.time())
self.session.add(obj)
await self.session.commit()
return chunk
await self.session.refresh(obj)
return obj
async def update(self, obj: ConversationChunkModel):
obj.updated_at = int(time.time())
obj = await self.session.merge(obj)
await self.session.commit()
return obj
async def update_message_log(self, chunk_id: int, message_data: list, tokens: int):
stmt = update(ConversationChunkModel).where(ConversationChunkModel.id == chunk_id) \
.values(message_data=message_data, tokens=tokens, updated_at=sqlalchemy.func.current_timestamp())
.values(message_data=message_data, tokens=tokens, updated_at=int(time.time()))
await self.session.execute(stmt)
await self.session.commit()
async def get_newest_chunk(self, conversation_id: int):
stmt = sqlalchemy.select(ConversationChunkModel) \
stmt = select(ConversationChunkModel) \
.where(ConversationChunkModel.conversation_id == conversation_id) \
.order_by(ConversationChunkModel.id.desc()) \
.limit(1)
return await self.session.scalar(stmt)
async def remove(self, id: int):
async def get_chunk_id_list(self, conversation_id: int):
stmt = select(ConversationChunkModel.id) \
.where(ConversationChunkModel.conversation_id == conversation_id).order_by(ConversationChunkModel.id.asc())
return await self.session.scalars(stmt)
async def find_by_id(self, id: int):
stmt = select(ConversationChunkModel).where(ConversationChunkModel.id == id)
return await self.session.scalar(stmt)
async def remove(self, id: int | list[int]):
if isinstance(id, list):
stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id.in_(id))
else:
stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id == id)
await self.session.execute(stmt)
await self.session.commit()
async def remove_by_conversation_id(self, conversation_id: int):
stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.conversation_id == conversation_id)
async def remove_by_conversation_ids(self, ids: list[int]):
stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.conversation_id.in_(ids))
await self.session.execute(stmt)
await self.session.commit()
async def on_conversation_removed(event):
if "conversation" in event:
conversation_info = event["conversation"]
conversation_id = conversation_info["id"]
await ConversationChunkHelper(event["dbs"]).remove_by_conversation_id(conversation_id)
if "ids" in event:
conversation_ids = event["ids"]
async with ConversationChunkHelper(event["dbs"]) as chunk_helper:
await chunk_helper.remove_by_conversation_ids(conversation_ids)
EventService.create().add_listener("conversation/removed/chatcomplete", on_conversation_removed)
EventService.create().add_listener("conversation/removed", on_conversation_removed)

@ -1,4 +1,5 @@
from __future__ import annotations
import time
from typing import List, Optional
import sqlalchemy
@ -17,11 +18,11 @@ class ConversationModel(BaseModel):
module: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True)
thumbnail: Mapped[str] = mapped_column(sqlalchemy.Text(), nullable=True)
description: Mapped[str] = mapped_column(sqlalchemy.Text(), nullable=True)
page_id: Mapped[int] = mapped_column(
sqlalchemy.Integer, index=True, nullable=True)
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
updated_at: Mapped[int] = mapped_column(
sqlalchemy.TIMESTAMP, index=True, server_default=sqlalchemy.func.now())
updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)
pinned: Mapped[bool] = mapped_column(
sqlalchemy.Boolean, default=False, index=True)
extra: Mapped[dict] = mapped_column(sqlalchemy.JSON, default={})
@ -46,13 +47,8 @@ class ConversationHelper:
await self.session.__aexit__(exc_type, exc, tb)
pass
async def add(self, user_id: int, module: str, title: Optional[str] = None, page_id: Optional[int] = None, rev_id: Optional[int] = None, extra: Optional[dict] = None):
obj = ConversationModel(user_id=user_id, module=module, title=title,
page_id=page_id, rev_id=rev_id, updated_at=sqlalchemy.func.current_timestamp())
if extra is not None:
obj.extra = extra
async def add(self, obj: ConversationModel):
obj.updated_at = int(time.time())
self.session.add(obj)
await self.session.commit()
await self.session.refresh(obj)
@ -60,11 +56,12 @@ class ConversationHelper:
async def refresh_updated_at(self, conversation_id: int):
stmt = update(ConversationModel).where(ConversationModel.id ==
conversation_id).values(updated_at=sqlalchemy.func.current_timestamp())
conversation_id).values(updated_at=int(time.time()))
await self.session.execute(stmt)
await self.session.commit()
async def update(self, obj: ConversationModel):
obj.updated_at = int(time.time())
await self.session.merge(obj)
await self.session.commit()
await self.session.refresh(obj)
@ -85,14 +82,25 @@ class ConversationHelper:
return await self.session.scalars(stmt)
async def find_by_id(self, conversation_id: int):
async with self.create_session() as session:
async def find_by_id(self, id: int):
stmt = sqlalchemy.select(ConversationModel).where(
ConversationModel.id == conversation_id)
return await session.scalar(stmt)
ConversationModel.id == id)
return await self.session.scalar(stmt)
async def filter_user_owned_ids(self, ids: list[int], user_id: int) -> list[int]:
stmt = sqlalchemy.select(ConversationModel.id) \
.where(ConversationModel.id.in_(ids)).where(ConversationModel.user_id == user_id)
return list(await self.session.scalars(stmt))
async def remove(self, conversation_id: int):
stmt = sqlalchemy.delete(ConversationModel).where(
ConversationModel.id == conversation_id)
await self.session.execute(stmt)
await self.session.commit()
async def remove_multiple(self, ids: list[int]):
stmt = sqlalchemy.delete(ConversationModel) \
.where(ConversationModel.id.in_(ids))
await self.session.execute(stmt)
await self.session.commit()

@ -1,5 +1,5 @@
from __future__ import annotations
import datetime
import time
from typing import Optional
import sqlalchemy
@ -17,8 +17,7 @@ class PageTitleModel(BaseModel):
sqlalchemy.Integer, primary_key=True, autoincrement=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True)
updated_at: Mapped[int] = mapped_column(
sqlalchemy.TIMESTAMP, index=True, server_default=sqlalchemy.func.now())
updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)
class PageTitleHelper:
@ -58,11 +57,11 @@ class PageTitleHelper:
title_info = await self.find_by_title(title)
if title_info is None:
return True
if title_info.updated_at < (datetime.now() - datetime.timedelta(days=7)):
if time.time() - title_info.updated_at > 60:
return True
async def add(self, page_id: int, title: Optional[str] = None):
obj = PageTitleModel(page_id=page_id, title=title, updated_at=sqlalchemy.func.current_timestamp())
obj = PageTitleModel(page_id=page_id, title=title, updated_at=int(time.time()))
self.session.add(obj)
await self.session.commit()
@ -71,14 +70,14 @@ class PageTitleHelper:
async def set_title(self, page_id: int, title: Optional[str] = None):
stmt = update(PageTitleModel).where(
PageTitleModel.page_id == page_id).values(title=title, updated_at=sqlalchemy.func.current_timestamp())
PageTitleModel.page_id == page_id).values(title=title, updated_at=int(time.time()))
await self.session.execute(stmt)
await self.session.commit()
async def update(self, obj: PageTitleModel, ignore_updated_at: bool = False):
self.session.merge(obj)
await self.session.merge(obj)
if not ignore_updated_at:
obj.updated_at = sqlalchemy.func.current_timestamp()
obj.updated_at = int(time.time())
await self.session.commit()
await self.session.refresh(obj)
return obj

@ -18,17 +18,19 @@ def init(app: web.Application):
web.route('*', '/title/info', Index.update_title_info),
web.route('*', '/user/info', Index.get_user_info),
web.route('*', '/conversations', Index.get_conversation_list),
web.route('*', '/conversation/list', Index.get_conversation_list),
web.route('*', '/conversation/info', Index.get_conversation_info),
web.route('POST', '/conversation/remove', Index.remove_conversation),
web.route('DELETE', '/conversation/remove', Index.remove_conversation),
web.route('POST', '/conversation/set_pinned', Index.set_conversation_pinned),
web.route('POST', '/conversation/set_title', Index.set_conversation_title),
web.route('*', '/embedding_search/index_page', EmbeddingSearch.index_page),
web.route('*', '/embedding_search/search', EmbeddingSearch.search),
web.route('*', '/chatcomplete/conversation_chunks', ChatComplete.get_conversation_chunk_list),
web.route('*', '/chatcomplete/conversation_chunk/{id:^\d+}', ChatComplete.get_conversation_chunk),
web.route('*', '/chatcomplete/message', ChatComplete.start_chat_complete),
web.route('*', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream),
web.route('*', '/chatcomplete/conversation_chunk/list', ChatComplete.get_conversation_chunk_list),
web.route('*', '/chatcomplete/conversation_chunk/info', ChatComplete.get_conversation_chunk),
web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete),
web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream),
web.route('POST', '/chatcomplete/get_point_cost', ChatComplete.get_point_cost),
])

@ -39,8 +39,6 @@ CHATCOMPLETE_OUTPUT_REPLACE = {
"人工智能程式": "虛擬人物程序",
}
CHATCOMPLETE_DEFAULT_CONVERSATION_TITLE = "无标题"
CHATCOMPLETE_BOT_NAME = "寫作助手"
PROMPTS = {

@ -17,7 +17,7 @@ from api.model.embedding_search.title_index import TitleIndexModel as _
from service.tiktoken import TikTokenService
async def index(request: web.Request):
return utils.web.api_response(1, data={"message": "Isekai toolkit API"}, request=request)
return await utils.web.api_response(1, data={"message": "Isekai toolkit API"}, request=request)
async def init_mw_api(app: web.Application):
mw_api = MediaWikiApi.create()

@ -1,34 +1,98 @@
from __future__ import annotations
from asyncio import AbstractEventLoop, Task
import asyncio
import atexit
from functools import wraps
import random
import sys
import traceback
from typing import Callable, Coroutine
from typing import Callable, Coroutine, Optional, TypedDict
class TimerInfo(TypedDict):
id: int
callback: Callable
interval: float
next_time: float
class NoAwaitPool:
def __init__(self, loop: AbstractEventLoop):
self.task_list: list[Task] = []
self.timer_map: dict[int, TimerInfo] = {}
self.loop = loop
self.running = True
self.should_refresh_task = False
self.next_timer_time: Optional[float] = None
self.on_error: list[Callable] = []
self.gc_task = loop.create_task(self._run_gc())
self.timer_task = loop.create_task(self._run_timer())
atexit.register(self.end_task)
async def end(self):
if self.running:
print("Stopping NoAwait Tasks...")
self.running = False
for task in self.task_list:
await self._finish_task(task)
await self.gc_task
await self.timer_task
def end_task(self):
if self.running and not self.loop.is_closed():
self.loop.run_until_complete(self.end())
async def _wrap_task(self, task: Task):
try:
await task
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
await handler_ret
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
finally:
self.should_refresh_task = True
def add_task(self, coroutine: Coroutine):
task = self.loop.create_task(coroutine)
self.task_list.append(task)
def add_timer(self, callback: Callable, interval: float) -> int:
id = random.randint(0, 1000000000)
while id in self.timer_map:
id = random.randint(0, 1000000000)
now = self.loop.time()
next_time = now + interval
self.timer_map[id] = {
"id": id,
"callback": callback,
"interval": interval,
"next_time": next_time
}
if self.next_timer_time is None or next_time < self.next_timer_time:
self.next_timer_time = next_time
return id
def remove_timer(self, id: int):
if id in self.timer_map:
del self.timer_map[id]
def wrap(self, f):
@wraps(f)
def decorated_function(*args, **kwargs):
@ -47,7 +111,6 @@ class NoAwaitPool:
for handler in self.on_error:
try:
handler_ret = handler(e)
if handler_ret is Coroutine:
await handler_ret
handled = True
except Exception as handler_err:
@ -58,9 +121,9 @@ class NoAwaitPool:
print(e, file=sys.stderr)
traceback.print_exc()
async def _run_gc(self):
while self.running:
if self.should_refresh_task:
should_remove = []
for task in self.task_list:
if task.done():
@ -70,3 +133,33 @@ class NoAwaitPool:
self.task_list.remove(task)
await asyncio.sleep(0.1)
async def _run_timer(self):
while self.running:
now = self.loop.time()
if self.next_timer_time is not None and now >= self.next_timer_time:
self.next_timer_time = None
for timer in self.timer_map.values():
if now >= timer["next_time"]:
timer["next_time"] = now + timer["interval"]
try:
result = timer["callback"]()
self.add_task(result)
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
self.add_task(handler_ret)
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
if self.next_timer_time is None or timer["next_time"] < self.next_timer_time:
self.next_timer_time = timer["next_time"]
await asyncio.sleep(0.1)

@ -0,0 +1,5 @@
transformers
--index-url https://download.pytorch.org/whl/cpu
torch
torchvision
torchaudio

@ -0,0 +1,143 @@
from __future__ import annotations
import time
import config
import asyncio
import random
import threading
from typing import Optional, TypedDict
import torch
from transformers import pipeline
from local import loop
from service.tiktoken import TikTokenService
BERT_EMBEDDING_QUEUE_TIMEOUT = 1
class BERTEmbeddingQueueTaskInfo(TypedDict):
task_id: int
text: str
embedding: torch.Tensor
class BERTEmbeddingQueue:
def init(self):
self.embedding_model = pipeline("feature-extraction", model="bert-base-chinese")
self.task_map: dict[int, BERTEmbeddingQueueTaskInfo] = {}
self.task_list: list[BERTEmbeddingQueueTaskInfo] = []
self.lock = threading.Lock()
self.thread: Optional[threading.Thread] = None
self.running = False
async def get_embeddings(self, text: str):
text = "[CLS]" + text + "[SEP]"
task_id = random.randint(0, 1000000000)
with self.lock:
while task_id in self.task_map:
task_id = random.randint(0, 1000000000)
task_info = {
"task_id": task_id,
"text": text,
"embedding": None
}
self.task_map[task_id] = task_info
self.task_list.append(task_info)
self.start_queue()
while True:
task_info = self.pop_task(task_id)
if task_info is not None:
return task_info["embedding"]
await asyncio.sleep(0.01)
def pop_task(self, task_id):
with self.lock:
if task_id in self.task_map:
task_info = self.task_map[task_id]
if task_info["embedding"] is not None:
del self.task_map[task_id]
return task_info
return None
def run(self):
running = True
last_task_time = None
while running and self.running:
current_time = time.time()
task = None
with self.lock:
if len(self.task_list) > 0:
task = self.task_list.pop(0)
if task is not None:
embeddings = self.embedding_model(task["text"])
with self.lock:
task["embedding"] = embeddings[0][1]
last_task_time = time.time()
elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
self.thread = None
self.running = False
running = False
else:
time.sleep(0.01)
def start_queue(self):
if not self.running:
self.running = True
self.thread = threading.Thread(target=self.run)
self.thread.start()
bert_embedding_queue = BERTEmbeddingQueue()
bert_embedding_queue.init()
class BERTEmbeddingService:
instance = None
@staticmethod
async def create() -> BERTEmbeddingService:
if BERTEmbeddingService.instance is None:
BERTEmbeddingService.instance = BERTEmbeddingService()
await BERTEmbeddingService.instance.init()
return BERTEmbeddingService.instance
async def init(self):
self.tiktoken = await TikTokenService.create()
self.embedding_queue = BERTEmbeddingQueue()
await loop.run_in_executor(None, self.embedding_queue.init)
async def get_embeddings(self, docs, on_progress=None):
if len(docs) == 0:
return ([], 0)
if on_progress is not None:
await on_progress(0, len(docs))
embeddings = []
token_usage = 0
for doc in docs:
if "text" in doc:
tokens = await self.tiktoken.get_tokens(doc["text"])
token_usage += tokens
embeddings.append({
"id": doc["id"],
"text": doc["text"],
"embedding": self.model.encode(doc["text"]),
"tokens": tokens
})
else:
embeddings.append({
"id": doc["id"],
"text": doc["text"],
"embedding": None,
"tokens": 0
})
if on_progress is not None:
await on_progress(1, len(docs))
return (embeddings, token_usage)

@ -1,12 +1,18 @@
from __future__ import annotations
import time
import traceback
from typing import Optional, Tuple, TypedDict
from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationChunkModel
import sqlalchemy
from api.model.chat_complete.conversation import (
ConversationChunkHelper,
ConversationChunkModel,
)
import sys
from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel
import config
import utils.config
import utils.config, utils.web
from aiohttp import web
from api.model.embedding_search.title_collection import TitleCollectionModel
@ -18,21 +24,21 @@ from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService
class ChatCompleteServicePrepareResponse(TypedDict):
extract_doc: list
question_tokens: int
conversation_id: int
chunk_id: int
class ChatCompleteServiceResponse(TypedDict):
message: str
message_tokens: int
total_tokens: int
finish_reason: str
conversation_id: int
question_message_id: str
response_message_id: str
delta_data: dict
class ChatCompleteService:
def __init__(self, dbs: DatabaseService, title: str):
self.dbs = dbs
@ -58,6 +64,7 @@ class ChatCompleteService:
self.question = ""
self.question_tokens: Optional[int] = None
self.conversation_id: Optional[int] = None
self.conversation_start_time: Optional[int] = None
self.delta_data = {}
@ -81,19 +88,28 @@ class ChatCompleteService:
async def get_question_tokens(self, question: str):
return await self.tiktoken.get_tokens(question)
async def prepare_chat_complete(self, question: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None,
async def prepare_chat_complete(
self,
question: str,
conversation_id: Optional[str] = None,
user_id: Optional[int] = None,
question_tokens: Optional[int] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse:
edit_message_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None,
) -> ChatCompleteServicePrepareResponse:
if user_id is not None:
user_id = int(user_id)
self.user_id = user_id
self.question = question
self.conversation_start_time = int(time.time())
self.conversation_info = None
if conversation_id is not None:
self.conversation_id = int(conversation_id)
self.conversation_info = await self.conversation_helper.find_by_id(self.conversation_id)
self.conversation_info = await self.conversation_helper.find_by_id(
self.conversation_id
)
else:
self.conversation_id = None
@ -106,97 +122,201 @@ class ChatCompleteService:
else:
self.question_tokens = question_tokens
if (len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS and
self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS):
if (
len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS
and self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS
):
# If the question is too long, we need to truncate it
raise web.HTTPRequestEntityTooLarge()
self.conversation_chunk = None
if self.conversation_info is not None:
chunk_id_list = await self.conversation_chunk_helper.get_chunk_id_list(self.conversation_id)
if edit_message_id and "," in edit_message_id:
(edit_chunk_id, edit_msg_id) = edit_message_id.split(",")
edit_chunk_id = int(edit_chunk_id)
# Remove overrided conversation chunks
start_overrided = False
should_remove_chunk_ids = []
for chunk_id in chunk_id_list:
if start_overrided:
should_remove_chunk_ids.append(chunk_id)
else:
if chunk_id == edit_chunk_id:
start_overrided = True
if len(should_remove_chunk_ids) > 0:
await self.conversation_chunk_helper.remove(should_remove_chunk_ids)
self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(
self.conversation_id
)
# Remove outdated message
edit_message_pos = None
old_tokens = 0
for i in range(0, len(self.conversation_chunk.message_data)):
msg_data = self.conversation_chunk.message_data[i]
if msg_data["id"] == edit_msg_id:
edit_message_pos = i
break
if "tokens" in msg_data and msg_data["tokens"]:
old_tokens += msg_data["tokens"]
if edit_message_pos:
self.conversation_chunk.message_data = self.conversation_chunk.message_data[0:edit_message_pos]
flag_modified(self.conversation_chunk, "message_data")
self.conversation_chunk.tokens = old_tokens
else:
self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(
self.conversation_id
)
# If the conversation is too long, we need to make a summary
if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS:
summary, tokens = await self.make_summary(
self.conversation_chunk.message_data
)
new_message_log = [
{
"role": "summary",
"content": summary,
"tokens": tokens,
"time": int(time.time()),
}
]
self.conversation_chunk = ConversationChunkModel(
conversation_id=self.conversation_id,
message_data=new_message_log,
tokens=tokens,
)
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_chunk)
else:
# 创建新对话
title_info = self.embedding_search.title_info
self.conversation_info = ConversationModel(
user_id=self.user_id,
module="chatcomplete",
page_id=title_info["page_id"],
rev_id=title_info["rev_id"],
)
self.conversation_info = await self.conversation_helper.add(
self.conversation_info,
)
self.conversation_chunk = ConversationChunkModel(
conversation_id=self.conversation_info.id,
message_data=[],
tokens=0,
)
self.conversation_chunk = await self.conversation_chunk_helper.add(
self.conversation_chunk
)
# Extract document from wiki page index
self.extract_doc = None
if embedding_search is not None:
self.extract_doc, token_usage = await self.embedding_search.search(question, **embedding_search)
self.extract_doc, token_usage = await self.embedding_search.search(
question, **embedding_search
)
if self.extract_doc is not None:
self.question_tokens += token_usage
return ChatCompleteServicePrepareResponse(
extract_doc=self.extract_doc,
question_tokens=self.question_tokens
question_tokens=self.question_tokens,
conversation_id=self.conversation_info.id,
chunk_id=self.conversation_chunk.id
)
async def finish_chat_complete(self, on_message: Optional[callable] = None) -> ChatCompleteServiceResponse:
async def finish_chat_complete(
self, on_message: Optional[callable] = None
) -> ChatCompleteServiceResponse:
delta_data = {}
self.conversation_chunk = None
message_log = []
if self.conversation_info is not None:
self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(self.conversation_id)
# If the conversation is too long, we need to make a summary
if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS:
summary, tokens = await self.make_summary(self.conversation_chunk.message_data)
new_message_log = [
{"role": "summary", "content": summary, "tokens": tokens}
]
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_id, new_message_log, tokens)
self.delta_data["conversation_chunk_id"] = self.conversation_chunk.id
message_log = []
if self.conversation_chunk is not None:
for message in self.conversation_chunk.message_data:
message_log.append({
message_log.append(
{
"role": message["role"],
"content": message["content"],
})
}
)
if self.extract_doc is not None:
doc_prompt_content = "\n".join(["%d. %s" % (
i + 1, doc["markdown"] or doc["text"]) for i, doc in enumerate(self.extract_doc)])
doc_prompt_content = "\n".join(
[
"%d. %s" % (i + 1, doc["markdown"] or doc["text"])
for i, doc in enumerate(self.extract_doc)
]
)
doc_prompt = utils.config.get_prompt("extracted_doc", "prompt", {
"content": doc_prompt_content})
doc_prompt = utils.config.get_prompt(
"extracted_doc", "prompt", {"content": doc_prompt_content}
)
message_log.append({"role": "user", "content": doc_prompt})
system_prompt = utils.config.get_prompt("chat", "system_prompt")
# Start chat complete
if on_message is not None:
response = await self.openai_api.chat_complete_stream(self.question, system_prompt, message_log, on_message)
response = await self.openai_api.chat_complete_stream(
self.question, system_prompt, message_log, on_message
)
else:
response = await self.openai_api.chat_complete(self.question, system_prompt, message_log)
if self.conversation_info is None:
# Create a new conversation
message_log_list = [
{"role": "user", "content": self.question, "tokens": self.question_tokens},
{"role": "assistant",
"content": response["message"], "tokens": response["message_tokens"]},
response = await self.openai_api.chat_complete(
self.question, system_prompt, message_log
)
description = response["message"][0:150]
question_msg_id = utils.web.generate_uuid()
response_msg_id = utils.web.generate_uuid()
new_message_data = [
{
"id": question_msg_id,
"role": "user",
"content": self.question,
"tokens": self.question_tokens,
"time": self.conversation_start_time,
},
{
"id": response_msg_id,
"role": "assistant",
"content": response["message"],
"tokens": response["message_tokens"],
"time": int(time.time()),
},
]
if self.conversation_info is not None:
total_token_usage = self.question_tokens + response["message_tokens"]
# Generate title if not exists
if self.conversation_info.title is None:
title = None
try:
title, token_usage = await self.make_title(message_log_list)
title, token_usage = await self.make_title(new_message_data)
delta_data["title"] = title
except Exception as e:
title = config.CHATCOMPLETE_DEFAULT_CONVERSATION_TITLE
print(str(e), file=sys.stderr)
traceback.print_exc(file=sys.stderr)
total_token_usage = self.question_tokens + response["message_tokens"]
self.conversation_info.title = title
title_info = self.embedding_search.title_info
self.conversation_info = await self.conversation_helper.add(self.user_id, "chatcomplete", page_id=title_info["page_id"], rev_id=title_info["rev_id"], title=title)
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_info.id, message_log_list, total_token_usage)
else:
# Update the conversation chunk
await self.conversation_helper.refresh_updated_at(self.conversation_id)
# Update conversation info
self.conversation_info.description = description
await self.conversation_helper.update(self.conversation_info)
self.conversation_chunk.message_data.append(
{"role": "user", "content": self.question, "tokens": self.question_tokens})
self.conversation_chunk.message_data.append(
{"role": "assistant", "content": response["message"], "tokens": response["message_tokens"]})
# Update conversation chunk
self.conversation_chunk.message_data.extend(new_message_data)
flag_modified(self.conversation_chunk, "message_data")
self.conversation_chunk.tokens += self.question_tokens + \
response["message_tokens"]
self.conversation_chunk.tokens += total_token_usage
await self.conversation_chunk_helper.update(self.conversation_chunk)
@ -205,8 +325,9 @@ class ChatCompleteService:
message_tokens=response["message_tokens"],
total_tokens=response["total_tokens"],
finish_reason=response["finish_reason"],
conversation_id=self.conversation_info.id,
delta_data=delta_data
question_message_id=question_msg_id,
response_message_id=response_msg_id,
delta_data=delta_data,
)
async def set_latest_point_cost(self, point_cost: int) -> bool:
@ -224,44 +345,50 @@ class ChatCompleteService:
return True
async def make_summary(self, message_log_list: list) -> tuple[str, int]:
chat_log: list[str] = []
for message_data in message_log_list:
if message_data["role"] == 'summary':
if message_data["role"] == "summary":
chat_log.append(message_data["content"])
elif message_data["role"] == 'assistant':
elif message_data["role"] == "assistant":
chat_log.append(
f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}')
f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}'
)
else:
chat_log.append(f'User: {message_data["content"]}')
chat_log_str = '\n'.join(chat_log)
chat_log_str = "\n".join(chat_log)
summary_system_prompt = utils.config.get_prompt(
"summary", "system_prompt")
summary_system_prompt = utils.config.get_prompt("summary", "system_prompt")
summary_prompt = utils.config.get_prompt(
"summary", "prompt", {"content": chat_log_str})
"summary", "prompt", {"content": chat_log_str}
)
response = await self.openai_api.chat_complete(summary_prompt, summary_system_prompt)
response = await self.openai_api.chat_complete(
summary_prompt, summary_system_prompt
)
return response["message"], response["message_tokens"]
async def make_title(self, message_log_list: list) -> tuple[str, int]:
chat_log: list[str] = []
for message_data in message_log_list:
if message_data["role"] == 'assistant':
if message_data["role"] == "assistant":
chat_log.append(
f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}')
elif message_data["role"] == 'user':
f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}'
)
elif message_data["role"] == "user":
chat_log.append(f'User: {message_data["content"]}')
chat_log_str = '\n'.join(chat_log)
chat_log_str = "\n".join(chat_log)
title_system_prompt = utils.config.get_prompt("title", "system_prompt")
title_prompt = utils.config.get_prompt(
"title", "prompt", {"content": chat_log_str})
"title", "prompt", {"content": chat_log_str}
)
response = await self.openai_api.chat_complete(title_prompt, title_system_prompt)
response = await self.openai_api.chat_complete(
title_prompt, title_system_prompt
)
return response["message"], response["message_tokens"]

@ -9,12 +9,17 @@ from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService
from utils.wiki import getWikiSentences
class EmbeddingRunningException(Exception):
pass
class EmbeddingSearchArgs(TypedDict):
limit: Optional[int]
in_collection: Optional[bool]
distance_limit: Optional[float]
class EmbeddingSearchService:
indexing_page_ids: list[int] = []
def __init__(self, dbs: DatabaseService, title: str):
self.dbs = dbs
@ -92,6 +97,9 @@ class EmbeddingSearchService:
self.page_id = self.page_info["pageid"]
if self.page_id in self.indexing_page_ids:
raise EmbeddingRunningException("Page index is running now")
# Create collection
self.collection_info = await self.title_collection.find_by_title(self.base_title)
if self.collection_info is None:
@ -129,7 +137,6 @@ class EmbeddingSearchService:
if self.unindexed_docs is None:
return False
chunk_limit = 500
chunk_len = 0

@ -1,3 +1,4 @@
from __future__ import annotations
import json
import sys
import time
@ -15,24 +16,49 @@ class MediaWikiApiException(Exception):
def __str__(self) -> str:
return self.info
class MediaWikiPageNotFoundException(MediaWikiApiException):
pass
class MediaWikiPageNotFoundException(Exception):
def __init__(self, info: str, code: Optional[str] = None) -> None:
super().__init__(info)
self.info = info
self.code = code
self.message = self.info
def __str__(self) -> str:
return self.info
class MediaWikiUserNoEnoughPointsException(Exception):
def __init__(self, info: str, code: Optional[str] = None) -> None:
super().__init__(info)
self.info = info
self.code = code
self.message = self.info
def __str__(self) -> str:
return self.info
class ChatCompleteGetPointUsageResponse(TypedDict):
point_cost: int
class ChatCompleteReportUsageResponse(TypedDict):
point_cost: int
transaction_id: str
class MediaWikiApi:
cookie_jar = aiohttp.CookieJar(unsafe=True)
instance: MediaWikiApi = None
@staticmethod
def create():
return MediaWikiApi(config.MW_API)
if MediaWikiApi.instance is None:
MediaWikiApi.instance = MediaWikiApi(config.MW_API)
return MediaWikiApi.instance
def __init__(self, api_url: str):
self.api_url = api_url
self.login_time = 0.0
self.cookie_jar = aiohttp.CookieJar(unsafe=True)
self.login_identity = None
self.login_time = 0.0
async def get_page_info(self, title: str):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
@ -50,7 +76,7 @@ class MediaWikiApi:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
if "missing" in data["query"]["pages"][0]:
raise MediaWikiPageNotFoundException()
raise MediaWikiPageNotFoundException(data["error"]["info"], data["error"]["code"])
return data["query"]["pages"][0]
@ -99,6 +125,20 @@ class MediaWikiApi:
return ret
async def is_logged_in(self,):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = {
"action": "query",
"format": "json",
"formatversion": "2",
"meta": "userinfo"
}
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
return data["query"]["userinfo"]["id"] != 0
async def get_token(self, token_type: str):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
@ -145,8 +185,10 @@ class MediaWikiApi:
async def refresh_login(self):
if self.login_identity is None:
print("刷新MW机器人账号登录状态失败没有保存的用户")
return False
if time.time() - self.login_time > 30:
if time.time() - self.login_time > 3600:
print("刷新MW机器人账号登录状态")
return await self.robot_login(self.login_identity["username"], self.login_identity["password"])
async def chat_complete_user_info(self, user_id: int):
@ -170,6 +212,32 @@ class MediaWikiApi:
return data["chatcompletebot"]["userinfo"]
async def chat_complete_get_point_cost(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteGetPointUsageResponse:
await self.refresh_login()
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
post_data = {
"action": "chatcompletebot",
"method": "reportusage",
"step": "check",
"userid": int(user_id) if user_id is not None else None,
"useraction": user_action,
"tokens": int(tokens) if tokens is not None else None,
"extractlines": int(extractlines) if extractlines is not None else None,
"format": "json",
"formatversion": "2",
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
data = await resp.json()
if "error" in data:
print(data)
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
point_cost = int(data["chatcompletebot"]["reportusage"]["pointcost"] or 0)
return ChatCompleteGetPointUsageResponse(point_cost=point_cost)
async def chat_complete_start_transaction(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteReportUsageResponse:
await self.refresh_login()
@ -178,10 +246,10 @@ class MediaWikiApi:
"action": "chatcompletebot",
"method": "reportusage",
"step": "start",
"userid": int(user_id),
"userid": int(user_id) if user_id is not None else None,
"useraction": user_action,
"tokens": int(tokens),
"extractlines": int(extractlines),
"tokens": int(tokens) if tokens is not None else None,
"extractlines": int(extractlines) if extractlines is not None else None,
"format": "json",
"formatversion": "2",
}
@ -190,10 +258,13 @@ class MediaWikiApi:
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
data = await resp.json()
if "error" in data:
print(data)
if data["error"]["code"] == "noenoughpoints":
raise MediaWikiUserNoEnoughPointsException(data["error"]["info"], data["error"]["info"])
else:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
return ChatCompleteReportUsageResponse(point_cost=int(data["chatcompletebot"]["reportusage"]["pointcost"]),
point_cost = int(data["chatcompletebot"]["reportusage"]["pointcost"] or 0)
return ChatCompleteReportUsageResponse(point_cost=point_cost,
transaction_id=data["chatcompletebot"]["reportusage"]["transactionid"])
async def chat_complete_end_transaction(self, transaction_id: str, tokens: Optional[int] = None):
@ -237,7 +308,7 @@ class MediaWikiApi:
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.get(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])

@ -1,4 +1,5 @@
from __future__ import annotations
import asyncio
import json
from typing import Callable, Optional, TypedDict
@ -36,20 +37,20 @@ class OpenAIApi:
def build_header(self):
if config.OPENAI_API_TYPE == "azure":
return {
"Content-Type": "application/json",
"Accept": "application/json",
"content-type": "application/json",
"accept": "application/json",
"api-key": self.api_key
}
else:
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
"authorization": f"Bearer {self.api_key}",
"content-type": "application/json",
"accept": "application/json",
}
def get_url(self, method: str):
if config.OPENAI_API_TYPE == "azure":
if method == "completions":
if method == "chat/completions":
return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME + "/" + method
elif method == "embeddings":
return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME + "/" + method
@ -92,6 +93,10 @@ class OpenAIApi:
if config.OPENAI_API_TYPE == "azure":
# Azure api does not support batch
for index, text in enumerate(text_list):
retry_num = 0
max_retry_num = 3
while retry_num < max_retry_num:
try:
async with session.post(url,
headers=self.build_header(),
params=params,
@ -112,6 +117,17 @@ class OpenAIApi:
if on_index_progress is not None:
await on_index_progress(index, len(text_list))
break
except Exception as e:
retry_num += 1
if retry_num >= max_retry_num:
raise e
print("Error: %s" % e)
print("Retrying...")
await asyncio.sleep(0.5)
else:
async with session.post(url,
headers=self.build_header(),
@ -158,7 +174,7 @@ class OpenAIApi:
async def chat_complete(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], user = None):
messageList = await self.make_message_list(question, system_prompt, conversation)
url = self.get_url("completions")
url = self.get_url("chat/completions")
params = {}
post_data = {
@ -175,7 +191,7 @@ class OpenAIApi:
async with aiohttp.ClientSession() as session:
async with session.post(url,
headers=self.build_header,
headers=self.build_header(),
params=params,
json=post_data,
timeout=30,
@ -210,24 +226,38 @@ class OpenAIApi:
for message in messageList:
prompt_tokens += await tiktoken.get_tokens(message["content"])
params = {
"model": "gpt-3.5-turbo",
url = self.get_url("chat/completions")
params = {}
post_data = {
"messages": messageList,
"stream": True,
"user": user,
"stream": True,
"n": 1,
"max_tokens": 768,
"stop": None,
"temperature": 1,
"top_p": 0.95
}
params = {k: v for k, v in params.items() if v is not None}
if config.OPENAI_API_TYPE == "azure":
params["api-version"] = "2023-05-15"
else:
post_data["model"] = "gpt-3.5-turbo"
post_data = {k: v for k, v in post_data.items() if v is not None}
res_message: list[str] = []
finish_reason = None
async with sse_client.EventSource(
self.api_url + "/v1/chat/completions",
url,
option={
"method": "POST"
},
headers={"Authorization": f"Bearer {self.api_key}"},
json=params,
headers=self.build_header(),
params=params,
json=post_data,
proxy=config.REQUEST_PROXY
) as session:
async for event in session:
@ -239,10 +269,12 @@ class OpenAIApi:
"""
content_started = False
if event.data == "[DONE]":
event_data = event.data.strip()
if event_data == "[DONE]":
break
elif event.data[0] == "{" and event.data[-1] == "}":
data = json.loads(event.data)
elif event_data[0] == "{" and event_data[-1] == "}":
data = json.loads(event_data)
if "choices" in data and len(data["choices"]) > 0:
choice = data["choices"][0]
@ -262,12 +294,15 @@ class OpenAIApi:
res_message.append(delta_message)
if config.DEBUG:
print(delta_message, end="", flush=True)
# if config.DEBUG:
# print(delta_message, end="", flush=True)
if on_message is not None:
await on_message(delta_message)
if finish_reason is not None:
break
res_message_str = "".join(res_message)
message_tokens = await tiktoken.get_tokens(res_message_str)
total_tokens = prompt_tokens + message_tokens

@ -1,50 +0,0 @@
from __future__ import annotations
import asyncpg
class SimpleQueryBuilder:
def __init__(self):
self._table_name = ""
self._select = ["*"]
self._where = []
self._having = []
self._order_by = None
self._order_by_desc = False
def table(self, table_name: str):
self._table_name = table_name
return self
def fields(self, fields: list[str]):
self.select = fields
return self
def where(self, where: str, condition: str, param):
self._where.append((where, condition, param))
return self
def having(self, having: str, condition: str, param):
self._having.append((having, condition, param))
return self
def build(self):
sql = "SELECT %s FROM %s" % (", ".join(self._select), self._table_name)
params = []
paramsLen = 0
if len(self._where) > 0:
sql += " WHERE "
for where, condition, param in self._where:
params.append(param)
paramsLen += 1
sql += "%s %s $%d AND " % (where, condition, paramsLen)
if self._order_by is not None:
sql += " ORDER BY %s %s" % (self._order_by, "DESC" if self._order_by_desc else "ASC")
if len(self._having) > 0:
sql += " HAVING "
for having, condition, param in self._having:
params.append(param)
paramsLen += 1
sql += "%s %s $%d AND " % (having, condition, paramsLen)

@ -0,0 +1,4 @@
import sys
import pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent))

@ -0,0 +1,27 @@
import asyncio
import time
import base
from local import loop, noawait
from service.bert_embedding import bert_embedding_queue
async def main():
embedding_list = []
start_time = time.time()
queue = []
with open("test/test.md", "r", encoding="utf-8") as fp:
text = fp.read()
lines = text.split("\n")
for line in lines:
line = line.strip()
if line == "":
continue
queue.append(bert_embedding_queue.get_embeddings(line))
embedding_list = await asyncio.gather(*queue)
end_time = time.time()
print("time cost: %.4f" % (end_time - start_time))
print("dimensions: %d" % len(embedding_list[0]))
await noawait.end()
if __name__ == '__main__':
loop.run_until_complete(main())

@ -0,0 +1,37 @@
import traceback
import base
import local
from service.chat_complete import ChatCompleteService
from service.database import DatabaseService
from service.tiktoken import TikTokenService
async def main():
dbs = await DatabaseService.create()
tiktoken = await TikTokenService.create()
async with ChatCompleteService(dbs, "代号:曙光的世界/黄昏的阿瓦隆") as chat_complete:
question = "你是谁?"
question_tokens = await tiktoken.get_tokens(question)
try:
prepare_res = await chat_complete.prepare_chat_complete(question, None, 1, question_tokens, {
"distance_limit": 0.6,
"limit": 10
})
print(prepare_res)
async def on_message(message: str):
# print(message)
pass
res = await chat_complete.finish_chat_complete(on_message)
print(res)
except Exception as err:
print(err)
traceback.print_exc()
await local.noawait.end()
if __name__ == '__main__':
local.loop.run_until_complete(main())

@ -1,3 +1,4 @@
import base
import local
from service.database import DatabaseService

@ -0,0 +1,20 @@
import asyncio
import base
from local import loop, noawait
async def test_timer1():
print("timer1")
async def test_timer2():
print("timer2")
async def main():
timer_id = noawait.add_timer(test_timer1, 1)
timer_id = noawait.add_timer(test_timer2, 2)
print("Timer id: %d" % timer_id)
while True:
await asyncio.sleep(1)
await noawait.end()
if __name__ == '__main__':
loop.run_until_complete(main())

@ -1,5 +1,6 @@
from __future__ import annotations
from functools import wraps
import json
from typing import Any, Optional, Dict
from aiohttp import web
import jwt
@ -8,22 +9,32 @@ import uuid
ParamRule = Dict[str, Any]
class ParamInvalidException(Exception):
class ParamInvalidException(web.HTTPBadRequest):
def __init__(self, param_list: list[str], rules: dict[str, ParamRule]):
self.code = "param_invalid"
self.param_list = param_list
self.rules = rules
param_list_str = "'" + ("', '".join(param_list)) + "'"
super().__init__(f"Param invalid: {param_list_str}")
super().__init__(f"Param invalid: {param_list_str}",
content_type="application/json",
body=json.dumps({
"status": -1,
"error": {
"code": self.code,
"message": f"Param invalid: {param_list_str}"
}
}))
async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None):
params: dict[str, Any] = {}
for key, value in request.match_info.items():
params[key] = value
for key, value in request.query.items():
params[key] = value
if request.method == 'POST':
if request.headers.get('content-type') == 'application/json':
data = await request.json()
if data is not None and data is dict:
if data is not None and isinstance(data, dict):
for key, value in data.items():
params[key] = value
else:
@ -34,7 +45,7 @@ async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]]
if rules is not None:
invalid_params: list[str] = []
for key, rule in rules.items():
if "required" in rule and rule["required"] and params[key] is None:
if "required" in rule and rule["required"] and key not in params.keys():
invalid_params.append(key)
continue
@ -50,11 +61,22 @@ async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]]
elif rule["type"] == float:
params[key] = float(params[key])
elif rule["type"] == bool:
val = params[key].lower()
if val == "false" or val == "0":
val = params[key]
if isinstance(val, bool):
params[key] = val
elif isinstance(val, str):
val = val.lower()
if val.lower() == "false" or val == "0":
params[key] = False
else:
params[key] = True
elif isinstance(val, int):
if val == 0:
params[key] = False
else:
params[key] = True
else:
params[key] = True
except ValueError:
invalid_params.append(key)
continue

Loading…
Cancel
Save