更改流式输出模式

master
落雨楓 2 years ago
parent 6a74050b56
commit 68d1647d5d

@ -1,19 +1,19 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import json
import sys import sys
import time import time
import traceback import traceback
from api.model.toolkit_ui.conversation import ConversationHelper
from local import noawait from local import noawait
from typing import Optional, Callable, TypedDict from typing import Optional, Callable, TypedDict
from aiohttp import web from aiohttp import web
from sqlalchemy import select from sqlalchemy import select
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool from noawait import NoAwaitPool
from service.chat_complete import ChatCompleteService from service.chat_complete import ChatCompleteService, ChatCompleteServiceResponse
from service.database import DatabaseService from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
import utils.web import utils.web
@ -23,6 +23,7 @@ class ChatCompleteTask:
def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False): def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
self.task_id = utils.web.generate_uuid() self.task_id = utils.web.generate_uuid()
self.on_message: list[Callable] = [] self.on_message: list[Callable] = []
self.on_finished: list[Callable] = []
self.on_error: list[Callable] = [] self.on_error: list[Callable] = []
self.chunks: list[str] = [] self.chunks: list[str] = []
@ -37,7 +38,12 @@ class ChatCompleteTask:
self.transatcion_id: Optional[str] = None self.transatcion_id: Optional[str] = None
self.point_cost: int = 0 self.point_cost: int = 0
async def init(self, question: str, conversation_id: Optional[str] = None, self.is_finished = False
self.finished_time: Optional[float] = None
self.result: Optional[ChatCompleteServiceResponse] = None
self.error: Optional[Exception] = None
async def init(self, question: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None): embedding_search: Optional[EmbeddingSearchArgs] = None):
self.tiktoken = await TikTokenService.create() self.tiktoken = await TikTokenService.create()
@ -54,34 +60,58 @@ class ChatCompleteTask:
self.transatcion_id: Optional[str] = None self.transatcion_id: Optional[str] = None
self.point_cost: int = 0 self.point_cost: int = 0
if not self.is_system: if not self.is_system:
usage_res = await self.mwapi.chat_complete_start_transaction(self.user_id, "chatcomplete", question_tokens, extract_limit) usage_res = await self.mwapi.chat_complete_start_transaction(self.user_id, "chatcomplete",
self.transatcion_id = usage_res.get("transaction_id") question_tokens, extract_limit)
self.point_cost = usage_res.get("point_cost") self.transatcion_id = usage_res["transaction_id"]
self.point_cost = usage_res["point_cost"]
chat_res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
user_id=self.user_id, embedding_search=embedding_search)
return chat_res res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
return res
else: else:
await self._exit() await self._exit()
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title) raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
async def _on_message(self, delta_message: str): async def _on_message(self, delta_message: str):
self.chunks.append(delta_message)
for callback in self.on_message: for callback in self.on_message:
await callback(delta_message) try:
await callback(delta_message)
except Exception as e:
print("Error while processing on_message callback: %s" % e, file=sys.stderr)
traceback.print_exc()
async def _on_finished(self):
for callback in self.on_finished:
try:
await callback(self.result)
except Exception as e:
print("Error while processing on_finished callback: %s" % e, file=sys.stderr)
traceback.print_exc()
async def _on_error(self, err: Exception): async def _on_error(self, err: Exception):
self.error = err
for callback in self.on_error: for callback in self.on_error:
await callback(err) try:
await callback(err)
except Exception as e:
print("Error while processing on_error callback: %s" % e, file=sys.stderr)
traceback.print_exc()
async def run(self): async def run(self):
try: try:
chat_res = self.chat_complete.finish_chat_complete(self._on_message) chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
await self.chat_complete.set_latest_point_cost(self.point_cost) await self.chat_complete.set_latest_point_cost(self.point_cost)
self.result = chat_res
if self.transatcion_id: if self.transatcion_id:
result = await self.mwapi.chat_complete_end_transaction(self.transatcion_id, chat_res["total_tokens"]) await self.mwapi.chat_complete_end_transaction(self.transatcion_id, chat_res["total_tokens"])
await self._on_finished()
except Exception as e: except Exception as e:
err_msg = f"Error while processing chat complete request: {e}" err_msg = f"Error while processing chat complete request: {e}"
@ -89,7 +119,7 @@ class ChatCompleteTask:
traceback.print_exc() traceback.print_exc()
if self.transatcion_id: if self.transatcion_id:
result = await self.mwapi.chat_complete_cancel_transaction(self.transatcion_id, error=err_msg) await self.mwapi.chat_complete_cancel_transaction(self.transatcion_id, error=err_msg)
await self._on_error(e) await self._on_error(e)
finally: finally:
@ -98,10 +128,19 @@ class ChatCompleteTask:
async def _exit(self): async def _exit(self):
await self.chat_complete_service.__aexit__(None, None, None) await self.chat_complete_service.__aexit__(None, None, None)
del chat_complete_tasks[self.task_id] del chat_complete_tasks[self.task_id]
self.is_finished = True
self.finished_time = time.time()
# Seconds a finished task is kept in ``chat_complete_tasks`` before the
# periodic GC below removes it.
TASK_EXPIRE_TIME = 60 * 10


async def chat_complete_task_gc():
    """Remove finished tasks whose retention period has elapsed.

    A task is dropped from ``chat_complete_tasks`` once it is marked
    finished, has a ``finished_time``, and that time lies more than
    ``TASK_EXPIRE_TIME`` seconds in the past.
    """
    now = time.time()

    # Collect the expired ids first: deleting entries while iterating the
    # dict (or its key view) raises "dictionary changed size during iteration".
    expired_ids = [
        task_id
        for task_id, task in chat_complete_tasks.items()
        if task.is_finished
        and task.finished_time is not None
        and now > task.finished_time + TASK_EXPIRE_TIME
    ]

    for task_id in expired_ids:
        chat_complete_tasks.pop(task_id, None)


noawait.add_timer(chat_complete_task_gc, 60)
class ChatComplete: class ChatComplete:
@staticmethod @staticmethod
@ -112,7 +151,7 @@ class ChatComplete:
"required": False, "required": False,
"type": int "type": int
}, },
"conversation_id": { "id": {
"required": True, "required": True,
"type": int "type": int
} }
@ -123,40 +162,31 @@ class ChatComplete:
else: else:
user_id = params.get("user_id") user_id = params.get("user_id")
conversation_id = params.get("conversation_id") conversation_id = params.get("id")
db = await DatabaseService.create(request.app) db = await DatabaseService.create(request.app)
async with db.create_session() as session: async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
stmt = select(ConversationModel).where( conversation_info = await conversation_helper.find_by_id(conversation_id)
ConversationModel.id == conversation_id)
conversation_data = await session.scalar(stmt)
if conversation_data is None: if conversation_info is None:
return await utils.web.api_response(-1, error={ return await utils.web.api_response(-1, error={
"code": "conversation-not-found", "code": "conversation-not-found",
"message": "Conversation not found." "message": "Conversation not found."
}, http_status=404, request=request) }, http_status=404, request=request)
if conversation_data.user_id != user_id: if conversation_info.user_id != user_id:
return await utils.web.api_response(-1, error={ return await utils.web.api_response(-1, error={
"code": "permission-denied", "code": "permission-denied",
"message": "Permission denied." "message": "Permission denied."
}, http_status=403, request=request) }, http_status=403, request=request)
stmt = select(ConversationChunkModel).with_only_columns([ConversationChunkModel.id, ConversationChunkModel.updated_at]) \ conversation_chunk_result = await conversation_chunk_helper.get_chunk_id_list(conversation_id)
.where(ConversationChunkModel.conversation_id == conversation_id).order_by(ConversationChunkModel.id.asc())
conversation_chunk_result = await session.scalars(stmt)
conversation_chunk_list = [] conversation_chunk_list = []
for result in conversation_chunk_result: for result in conversation_chunk_result:
conversation_chunk_list.append({ conversation_chunk_list.append(result)
"id": result.id,
"updated_at": result.updated_at
})
return await utils.web.api_response(1, conversation_chunk_list, request=request) return await utils.web.api_response(1, conversation_chunk_list, request=request)
@ -181,26 +211,32 @@ class ChatComplete:
chunk_id = params.get("chunk_id") chunk_id = params.get("chunk_id")
dbs = await DatabaseService.create(request.app) db = await DatabaseService.create(request.app)
async with dbs.create_session() as session: async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
stmt = select(ConversationChunkModel).where( chunk_info = await conversation_chunk_helper.find_by_id(chunk_id)
ConversationChunkModel.id == chunk_id) if chunk_info is None:
conversation_data = await session.scalar(stmt)
if conversation_data is None:
return await utils.web.api_response(-1, error={ return await utils.web.api_response(-1, error={
"code": "conversation-chunk-not-found", "code": "conversation-chunk-not-found",
"message": "Conversation chunk not found." "message": "Conversation chunk not found."
}, http_status=404, request=request) }, http_status=404, request=request)
if conversation_data.conversation.user_id != user_id: conversation_info = await conversation_helper.find_by_id(chunk_info.conversation_id)
if conversation_info is not None and conversation_info.user_id != user_id:
return await utils.web.api_response(-1, error={ return await utils.web.api_response(-1, error={
"code": "permission-denied", "code": "permission-denied",
"message": "Permission denied." "message": "Permission denied."
}, http_status=403, request=request) }, http_status=403, request=request)
return await utils.web.api_response(1, conversation_data.__dict__, request=request) chunk_dict = {
"id": chunk_info.id,
"conversation_id": chunk_info.conversation_id,
"message_data": chunk_info.message_data,
"tokens": chunk_info.tokens,
"updated_at": chunk_info.updated_at,
}
return await utils.web.api_response(1, chunk_dict, request=request)
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth
@ -219,6 +255,47 @@ class ChatComplete:
return await utils.web.api_response(1, {"tokens": tokens}, request=request) return await utils.web.api_response(1, {"tokens": tokens}, request=request)
@staticmethod
@utils.web.token_auth
async def get_point_cost(request: web.Request):
    """Report how many points a chat-complete question would cost.

    Request parameters:
        question (str, required): the text whose tokens are counted.
        extract_limit (int, optional, default 10): forwarded unchanged to
            ``chat_complete_get_point_cost``.

    Responds with ``point_cost`` and ``tokens`` on success. Any exception
    raised while querying the cost is logged and turned into a
    ``chat-complete-error`` response with HTTP status 500.
    """
    params = await utils.web.get_param(request, {
        "question": {
            "type": str,
            "required": True,
        },
        "extract_limit": {
            "type": int,
            "required": False,
            "default": 10,
        },
    })

    user_id = request.get("user")
    question = params.get("question")
    extract_limit = params.get("extract_limit")

    tiktoken = await TikTokenService.create()
    mwapi = MediaWikiApi.create()

    tokens = await tiktoken.get_tokens(question)

    try:
        res = await mwapi.chat_complete_get_point_cost(user_id, "chatcomplete", tokens, extract_limit)
        return await utils.web.api_response(1, {
            "point_cost": res["point_cost"],
            "tokens": tokens,
        }, request=request)
    except Exception as e:
        err_msg = f"Error while get chat complete point cost: {e}"
        traceback.print_exc()

        return await utils.web.api_response(-1, error={
            "code": "chat-complete-error",
            "message": err_msg
        }, http_status=500, request=request)
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth
async def start_chat_complete(request: web.Request): async def start_chat_complete(request: web.Request):
@ -245,6 +322,10 @@ class ChatComplete:
"required": False, "required": False,
"default": False, "default": False,
}, },
"edit_message_id": {
"type": str,
"required": False,
},
}) })
user_id = request.get("user") user_id = request.get("user")
@ -257,20 +338,25 @@ class ChatComplete:
extract_limit = params.get("extract_limit") extract_limit = params.get("extract_limit")
in_collection = params.get("in_collection") in_collection = params.get("in_collection")
edit_message_id = params.get("edit_message_id")
dbs = await DatabaseService.create(request.app) dbs = await DatabaseService.create(request.app)
try: try:
chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user") chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user")
init_res = await chat_complete_task.init(question, conversation_id=conversation_id, embedding_search={ init_res = await chat_complete_task.init(question, conversation_id=conversation_id, edit_message_id=edit_message_id,
embedding_search={
"limit": extract_limit, "limit": extract_limit,
"in_collection": in_collection, "in_collection": in_collection,
}) })
chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task
chat_complete_task.start() noawait.add_task(chat_complete_task.run())
return utils.web.api_response(1, data={ return await utils.web.api_response(1, data={
"conversation_id": init_res["conversation_id"],
"chunk_id": init_res["chunk_id"],
"question_tokens": init_res["question_tokens"], "question_tokens": init_res["question_tokens"],
"extract_doc": init_res["extract_doc"], "extract_doc": init_res["extract_doc"],
"task_id": chat_complete_task.task_id, "task_id": chat_complete_task.task_id,
@ -282,6 +368,12 @@ class ChatComplete:
"title": page_title, "title": page_title,
"message": error_msg "message": error_msg
}, http_status=404, request=request) }, http_status=404, request=request)
except MediaWikiUserNoEnoughPointsException as e:
error_msg = "Does not have enough points." % user_id
return await utils.web.api_response(-1, error={
"code": "no-enough-points",
"message": error_msg
}, http_status=403, request=request)
except Exception as e: except Exception as e:
err_msg = f"Error while processing chat complete request: {e}" err_msg = f"Error while processing chat complete request: {e}"
traceback.print_exc() traceback.print_exc()
@ -294,4 +386,123 @@ class ChatComplete:
@staticmethod
@utils.web.token_auth
async def chat_complete_stream(request: web.Request):
    """Stream the output of a chat-complete task over a WebSocket.

    Request parameters:
        task_id (str, required): key of the task in ``chat_complete_tasks``.

    Non-WebSocket requests get a ``websocket-required`` API response (HTTP
    400). Otherwise the socket is prepared and the client receives:

    * an ``error`` JSON event, followed by a close, when the task is
      unknown, belongs to another user (only checked when the caller is
      ``"user"``), or failed;
    * a ``connected`` JSON event carrying the text produced so far;
    * each later delta as a text frame prefixed with ``"+"``;
    * a ``finished`` JSON event, followed by a close.

    The ``finished`` payload has the same shape whether the task ends while
    the client is connected or had already ended: the task result without
    its ``"message"`` key, plus ``point_cost``.

    Returns the prepared ``WebSocketResponse``.
    """
    if not utils.web.is_websocket(request):
        return await utils.web.api_response(-1, error={
            "code": "websocket-required",
            "message": "This API only accept websocket connection."
        }, http_status=400, request=request)

    params = await utils.web.get_param(request, {
        "task_id": {
            "type": str,
            "required": True,
        }
    })

    ws = web.WebSocketResponse()
    await ws.prepare(request)

    async def send_error_and_close(code: str, message: str):
        # Every error path ends the stream, so the socket is always closed here.
        await ws.send_json({
            'event': 'error',
            'status': -1,
            'message': message,
            'error': {
                'code': code,
                'info': message,
            },
        })
        await ws.close()

    task_id = params.get("task_id")
    task = chat_complete_tasks.get(task_id)
    if task is None:
        await send_error_and_close("task-not-found", "Task not found.")
        return ws

    if request.get("caller") == "user":
        user_id = request.get("user")
        if task.user_id != user_id:
            await send_error_and_close("permission-denied", "Permission denied.")
            return ws

    def build_finished_result(result: ChatCompleteServiceResponse) -> dict:
        # "message" is left out of the payload; the generated text itself is
        # delivered through the deltas and ``outputed_message``.
        ignored_keys = ["message"]
        response_result = {
            "point_cost": task.point_cost,
        }
        for key, value in result.items():
            if key not in ignored_keys:
                response_result[key] = value
        return response_result

    if task.is_finished:
        if task.error is not None:
            await send_error_and_close("internal-error", str(task.error))
        elif task.result is not None:
            await ws.send_json({
                'event': 'connected',
                'status': 1,
                'outputed_message': "".join(task.chunks),
            })
            await ws.send_json({
                'event': 'finished',
                'status': 1,
                'result': build_finished_result(task.result),
            })
            await ws.close()
        else:
            # Finished without a result or an error: nothing to replay, but
            # the socket must not be left open.
            await ws.close()
        return ws

    async def on_message(delta_message: str):
        await ws.send_str("+" + delta_message)

    async def on_finished(result: ChatCompleteServiceResponse):
        await ws.send_json({
            'event': 'finished',
            'status': 1,
            'result': build_finished_result(result),
        })
        await ws.close()

    async def on_error(err: Exception):
        await send_error_and_close("internal-error", str(err))

    task.on_message.append(on_message)
    task.on_finished.append(on_finished)
    task.on_error.append(on_error)

    try:
        # Send the text produced before this client subscribed.
        await ws.send_json({
            'event': 'connected',
            'status': 1,
            'outputed_message': "".join(task.chunks),
        })

        # Keep the handler alive until the socket is closed.
        while not ws.closed:
            await asyncio.sleep(0.1)
    finally:
        # Always unregister, including when sending fails or the handler is
        # cancelled, so the task does not keep calling into a dead socket.
        for listeners, callback in (
            (task.on_message, on_message),
            (task.on_finished, on_finished),
            (task.on_error, on_error),
        ):
            if callback in listeners:
                listeners.remove(callback)

    return ws

@ -2,7 +2,7 @@ import sys
import traceback import traceback
from aiohttp import web from aiohttp import web
from service.database import DatabaseService from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
import utils.web import utils.web
@ -87,6 +87,18 @@ class EmbeddingSearch:
}) })
if transatcion_id: if transatcion_id:
await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg) await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg)
except EmbeddingRunningException:
error_msg = "Page index is running now"
await ws.send_json({
'event': 'error',
'status': -4,
'message': error_msg,
'error': {
'code': 'page_index_running',
},
})
if transatcion_id:
await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg)
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
print(error_msg, file=sys.stderr) print(error_msg, file=sys.stderr)
@ -94,7 +106,10 @@ class EmbeddingSearch:
await ws.send_json({ await ws.send_json({
'event': 'error', 'event': 'error',
'status': -1, 'status': -1,
'message': error_msg 'message': error_msg,
'error': {
'code': 'internal_server_error',
}
}) })
if transatcion_id: if transatcion_id:
await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg) await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg)
@ -145,6 +160,16 @@ class EmbeddingSearch:
"info": e.info, "info": e.info,
"message": error_msg "message": error_msg
}, http_status=500) }, http_status=500)
except EmbeddingRunningException:
error_msg = "Page index is running now"
if transatcion_id:
await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg)
return await utils.web.api_response(-4, error={
"code": "page-index-running",
"message": error_msg
}, http_status=429)
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)

@ -2,7 +2,6 @@ import sys
import time import time
import traceback import traceback
from aiohttp import web from aiohttp import web
from sqlalchemy import select
from api.model.toolkit_ui.conversation import ConversationHelper from api.model.toolkit_ui.conversation import ConversationHelper
from api.model.toolkit_ui.page_title import PageTitleHelper from api.model.toolkit_ui.page_title import PageTitleHelper
from service.database import DatabaseService from service.database import DatabaseService
@ -27,7 +26,7 @@ class Index:
db = await DatabaseService.create(request.app) db = await DatabaseService.create(request.app)
async with PageTitleHelper(db) as page_title_helper: async with PageTitleHelper(db) as page_title_helper:
title_info = await page_title_helper.find_by_title(title) title_info = await page_title_helper.find_by_title(title)
if title_info is not None and time.time() - title_info.updated_at < 60: if title_info is not None and time.time() - title_info.updated_at < 60:
return await utils.web.api_response(1, { return await utils.web.api_response(1, {
"cached": True, "cached": True,
@ -123,6 +122,7 @@ class Index:
"id": result.id, "id": result.id,
"module": result.module, "module": result.module,
"title": result.title, "title": result.title,
"description": result.description,
"thumbnail": result.thumbnail, "thumbnail": result.thumbnail,
"rev_id": result.rev_id, "rev_id": result.rev_id,
"updated_at": result.updated_at, "updated_at": result.updated_at,
@ -166,6 +166,7 @@ class Index:
"id": conversation_info.id, "id": conversation_info.id,
"module": conversation_info.module, "module": conversation_info.module,
"title": conversation_info.title, "title": conversation_info.title,
"description": conversation_info.description,
"thumbnail": conversation_info.thumbnail, "thumbnail": conversation_info.thumbnail,
"rev_id": conversation_info.rev_id, "rev_id": conversation_info.rev_id,
"updated_at": conversation_info.updated_at, "updated_at": conversation_info.updated_at,
@ -178,14 +179,68 @@ class Index:
@staticmethod
@utils.web.token_auth
async def remove_conversation(request: web.Request):
    """Delete one or several conversations.

    Request parameters (at least one is required):
        id (int): a single conversation id.
        ids (str): comma-separated conversation ids; used only when ``id``
            is absent.

    When the caller is ``"user"`` the ids are first narrowed with
    ``filter_user_owned_ids``. If any ids remain they are removed and a
    ``conversation/removed`` event is emitted so other modules can react.

    Responds with ``count``, the number of ids that remained after
    filtering. Missing parameters or an ``ids`` value that is not a list of
    integers produce an ``invalid-params`` response with HTTP status 400.
    """
    params = await utils.web.get_param(request, {
        "id": {
            "type": int
        },
        "ids": {
            "type": str
        }
    })

    conversation_id = params.get("id")
    raw_ids = params.get("ids")

    if conversation_id is None and raw_ids is None:
        return await utils.web.api_response(-2, error={
            "code": "invalid-params",
            "message": "Invalid params."
        }, request=request, http_status=400)

    if conversation_id is not None:
        conversation_ids = [conversation_id]
    else:
        try:
            conversation_ids = [int(part) for part in raw_ids.split(",")]
        except ValueError:
            # A malformed list is a client error, not a server failure.
            return await utils.web.api_response(-2, error={
                "code": "invalid-params",
                "message": "Invalid params."
            }, request=request, http_status=400)

    db = await DatabaseService.create(request.app)
    async with ConversationHelper(db) as conversation_helper:
        if request.get("caller") == "user":
            user_id = int(request.get("user"))
            conversation_ids = await conversation_helper.filter_user_owned_ids(conversation_ids, user_id=user_id)

        if conversation_ids:
            await conversation_helper.remove_multiple(conversation_ids)

            # Notify other modules so they can delete their related data.
            events = EventService.create()
            events.emit("conversation/removed", {
                "ids": conversation_ids,
                "dbs": db,
                "app": request.app,
            })

    return await utils.web.api_response(1, data={
        "count": len(conversation_ids)
    }, request=request)
@staticmethod
@utils.web.token_auth
async def set_conversation_pinned(request: web.Request):
params = await utils.web.get_param(request, { params = await utils.web.get_param(request, {
"id": { "id": {
"required": True, "required": True,
"type": int "type": int
},
"pinned": {
"required": True,
"type": bool
} }
}) })
conversation_id = params.get("id") conversation_id = params.get("id")
pinned = params.get("pinned")
db = await DatabaseService.create(request.app) db = await DatabaseService.create(request.app)
async with ConversationHelper(db) as conversation_helper: async with ConversationHelper(db) as conversation_helper:
@ -203,39 +258,27 @@ class Index:
"message": "Permission denied." "message": "Permission denied."
}, request=request, http_status=403) }, request=request, http_status=403)
await conversation_helper.remove(conversation_info) conversation_info.pinned = pinned
await conversation_helper.update(conversation_info)
# 通知其他模块删除
events = EventService.create()
events.emit("conversation/removed", {
"conversation": conversation_info,
"dbs": db,
"app": request.app,
})
events.emit("conversation/removed/" + conversation_info.module, {
"conversation": conversation_info,
"dbs": db,
"app": request.app,
})
return await utils.web.api_response(1, request=request) return await utils.web.api_response(1, request=request)
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth
async def set_conversation_pinned(request: web.Request): async def set_conversation_title(request: web.Request):
params = await utils.web.get_param(request, { params = await utils.web.get_param(request, {
"id": { "id": {
"required": True, "required": True,
"type": int "type": int
}, },
"pinned": { "new_title": {
"required": True, "required": True,
"type": bool "type": str
} }
}) })
conversation_id = params.get("id") conversation_id = params.get("id")
pinned = params.get("pinned") new_title = params.get("new_title")
db = await DatabaseService.create(request.app) db = await DatabaseService.create(request.app)
async with ConversationHelper(db) as conversation_helper: async with ConversationHelper(db) as conversation_helper:
@ -253,7 +296,7 @@ class Index:
"message": "Permission denied." "message": "Permission denied."
}, request=request, http_status=403) }, request=request, http_status=403)
conversation_info.pinned = pinned conversation_info.title = new_title
await conversation_helper.update(conversation_info) await conversation_helper.update(conversation_info)
return await utils.web.api_response(1, request=request) return await utils.web.api_response(1, request=request)

@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
import time
import sqlalchemy import sqlalchemy
from sqlalchemy import update from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, Mapped from sqlalchemy.orm import mapped_column, relationship, Mapped
from api.model.base import BaseModel from api.model.base import BaseModel
@ -13,10 +14,11 @@ class ConversationChunkModel(BaseModel):
__tablename__ = "chat_complete_conversation_chunk" __tablename__ = "chat_complete_conversation_chunk"
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
conversation_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(ConversationModel.id), index=True) conversation_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey(ConversationModel.id, ondelete="CASCADE", onupdate="CASCADE"), index=True)
message_data: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True) message_data: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True)
tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0) tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0)
updated_at: Mapped[int] = mapped_column(sqlalchemy.TIMESTAMP, index=True) updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)
class ConversationChunkHelper: class ConversationChunkHelper:
def __init__(self, dbs: DatabaseService): def __init__(self, dbs: DatabaseService):
@ -36,52 +38,58 @@ class ConversationChunkHelper:
await self.session.__aexit__(exc_type, exc, tb) await self.session.__aexit__(exc_type, exc, tb)
pass pass
async def add(self, conversation_id: int, message_data: list, tokens: int): async def add(self, obj: ConversationChunkModel):
async with self.create_session() as session: obj.updated_at = int(time.time())
chunk = ConversationChunkModel( self.session.add(obj)
conversation_id=conversation_id, await self.session.commit()
message_data=message_data, await self.session.refresh(obj)
tokens=tokens, return obj
updated_at=sqlalchemy.func.current_timestamp()
)
session.add(chunk)
await session.commit()
await session.refresh(chunk)
return chunk
async def update(self, chunk: ConversationChunkModel): async def update(self, obj: ConversationChunkModel):
chunk.updated_at = sqlalchemy.func.current_timestamp() obj.updated_at = int(time.time())
chunk = await self.session.merge(chunk) obj = await self.session.merge(obj)
await self.session.commit() await self.session.commit()
return chunk return obj
async def update_message_log(self, chunk_id: int, message_data: list, tokens: int): async def update_message_log(self, chunk_id: int, message_data: list, tokens: int):
stmt = update(ConversationChunkModel).where(ConversationChunkModel.id == chunk_id) \ stmt = update(ConversationChunkModel).where(ConversationChunkModel.id == chunk_id) \
.values(message_data=message_data, tokens=tokens, updated_at=sqlalchemy.func.current_timestamp()) .values(message_data=message_data, tokens=tokens, updated_at=int(time.time()))
await self.session.execute(stmt) await self.session.execute(stmt)
await self.session.commit() await self.session.commit()
async def get_newest_chunk(self, conversation_id: int): async def get_newest_chunk(self, conversation_id: int):
stmt = sqlalchemy.select(ConversationChunkModel) \ stmt = select(ConversationChunkModel) \
.where(ConversationChunkModel.conversation_id == conversation_id) \ .where(ConversationChunkModel.conversation_id == conversation_id) \
.order_by(ConversationChunkModel.id.desc()) \ .order_by(ConversationChunkModel.id.desc()) \
.limit(1) .limit(1)
return await self.session.scalar(stmt) return await self.session.scalar(stmt)
async def get_chunk_id_list(self, conversation_id: int):
    """Select the ids of a conversation's chunks in ascending id order.

    Returns whatever ``self.session.scalars`` yields for that query.
    """
    in_conversation = ConversationChunkModel.conversation_id == conversation_id
    id_query = (
        select(ConversationChunkModel.id)
        .where(in_conversation)
        .order_by(ConversationChunkModel.id.asc())
    )
    return await self.session.scalars(id_query)
async def find_by_id(self, id: int):
    """Look up a single chunk by its primary key.

    Returns the result of ``self.session.scalar`` for the query; callers
    compare it against ``None`` to detect a missing chunk.
    """
    has_requested_id = ConversationChunkModel.id == id
    chunk_query = select(ConversationChunkModel).where(has_requested_id)
    return await self.session.scalar(chunk_query)
async def remove(self, id: int): async def remove(self, id: int | list[int]):
stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id == id) if isinstance(id, list):
stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id.in_(id))
else:
stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id == id)
await self.session.execute(stmt) await self.session.execute(stmt)
await self.session.commit() await self.session.commit()
async def remove_by_conversation_id(self, conversation_id: int): async def remove_by_conversation_ids(self, ids: list[int]):
stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.conversation_id == conversation_id) stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.conversation_id.in_(ids))
await self.session.execute(stmt) await self.session.execute(stmt)
await self.session.commit() await self.session.commit()
async def on_conversation_removed(event): async def on_conversation_removed(event):
if "conversation" in event: if "ids" in event:
conversation_info = event["conversation"] conversation_ids = event["ids"]
conversation_id = conversation_info["id"] async with ConversationChunkHelper(event["dbs"]) as chunk_helper:
await ConversationChunkHelper(event["dbs"]).remove_by_conversation_id(conversation_id) await chunk_helper.remove_by_conversation_ids(conversation_ids)
EventService.create().add_listener("conversation/removed/chatcomplete", on_conversation_removed) EventService.create().add_listener("conversation/removed", on_conversation_removed)

@ -1,4 +1,5 @@
from __future__ import annotations from __future__ import annotations
import time
from typing import List, Optional from typing import List, Optional
import sqlalchemy import sqlalchemy
@ -17,11 +18,11 @@ class ConversationModel(BaseModel):
module: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True) module: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True) title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True)
thumbnail: Mapped[str] = mapped_column(sqlalchemy.Text(), nullable=True) thumbnail: Mapped[str] = mapped_column(sqlalchemy.Text(), nullable=True)
description: Mapped[str] = mapped_column(sqlalchemy.Text(), nullable=True)
page_id: Mapped[int] = mapped_column( page_id: Mapped[int] = mapped_column(
sqlalchemy.Integer, index=True, nullable=True) sqlalchemy.Integer, index=True, nullable=True)
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True) rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
updated_at: Mapped[int] = mapped_column( updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)
sqlalchemy.TIMESTAMP, index=True, server_default=sqlalchemy.func.now())
pinned: Mapped[bool] = mapped_column( pinned: Mapped[bool] = mapped_column(
sqlalchemy.Boolean, default=False, index=True) sqlalchemy.Boolean, default=False, index=True)
extra: Mapped[dict] = mapped_column(sqlalchemy.JSON, default={}) extra: Mapped[dict] = mapped_column(sqlalchemy.JSON, default={})
@ -46,13 +47,8 @@ class ConversationHelper:
await self.session.__aexit__(exc_type, exc, tb) await self.session.__aexit__(exc_type, exc, tb)
pass pass
async def add(self, user_id: int, module: str, title: Optional[str] = None, page_id: Optional[int] = None, rev_id: Optional[int] = None, extra: Optional[dict] = None): async def add(self, obj: ConversationModel):
obj = ConversationModel(user_id=user_id, module=module, title=title, obj.updated_at = int(time.time())
page_id=page_id, rev_id=rev_id, updated_at=sqlalchemy.func.current_timestamp())
if extra is not None:
obj.extra = extra
self.session.add(obj) self.session.add(obj)
await self.session.commit() await self.session.commit()
await self.session.refresh(obj) await self.session.refresh(obj)
@ -60,11 +56,12 @@ class ConversationHelper:
async def refresh_updated_at(self, conversation_id: int): async def refresh_updated_at(self, conversation_id: int):
stmt = update(ConversationModel).where(ConversationModel.id == stmt = update(ConversationModel).where(ConversationModel.id ==
conversation_id).values(updated_at=sqlalchemy.func.current_timestamp()) conversation_id).values(updated_at=int(time.time()))
await self.session.execute(stmt) await self.session.execute(stmt)
await self.session.commit() await self.session.commit()
async def update(self, obj: ConversationModel): async def update(self, obj: ConversationModel):
obj.updated_at = int(time.time())
await self.session.merge(obj) await self.session.merge(obj)
await self.session.commit() await self.session.commit()
await self.session.refresh(obj) await self.session.refresh(obj)
@ -85,14 +82,25 @@ class ConversationHelper:
return await self.session.scalars(stmt) return await self.session.scalars(stmt)
async def find_by_id(self, conversation_id: int): async def find_by_id(self, id: int):
async with self.create_session() as session: stmt = sqlalchemy.select(ConversationModel).where(
stmt = sqlalchemy.select(ConversationModel).where( ConversationModel.id == id)
ConversationModel.id == conversation_id) return await self.session.scalar(stmt)
return await session.scalar(stmt)
async def filter_user_owned_ids(self, ids: list[int], user_id: int) -> list[int]:
    """Return the ids from ``ids`` whose conversation rows have ``user_id`` as owner.

    Runs a single SELECT of ``ConversationModel.id`` on the helper's session,
    restricted to the requested ids and to the given owner, and materialises
    the scalar result as a list.
    """
    requested = ConversationModel.id.in_(ids)
    owned = ConversationModel.user_id == user_id
    query = sqlalchemy.select(ConversationModel.id).where(requested).where(owned)
    rows = await self.session.scalars(query)
    return list(rows)
async def remove(self, conversation_id: int): async def remove(self, conversation_id: int):
stmt = sqlalchemy.delete(ConversationModel).where( stmt = sqlalchemy.delete(ConversationModel).where(
ConversationModel.id == conversation_id) ConversationModel.id == conversation_id)
await self.session.execute(stmt) await self.session.execute(stmt)
await self.session.commit()
async def remove_multiple(self, ids: list[int]):
stmt = sqlalchemy.delete(ConversationModel) \
.where(ConversationModel.id.in_(ids))
await self.session.execute(stmt)
await self.session.commit() await self.session.commit()

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
import datetime import time
from typing import Optional from typing import Optional
import sqlalchemy import sqlalchemy
@ -17,8 +17,7 @@ class PageTitleModel(BaseModel):
sqlalchemy.Integer, primary_key=True, autoincrement=True) sqlalchemy.Integer, primary_key=True, autoincrement=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True) title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True)
updated_at: Mapped[int] = mapped_column( updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)
sqlalchemy.TIMESTAMP, index=True, server_default=sqlalchemy.func.now())
class PageTitleHelper: class PageTitleHelper:
@ -58,11 +57,11 @@ class PageTitleHelper:
title_info = await self.find_by_title(title) title_info = await self.find_by_title(title)
if title_info is None: if title_info is None:
return True return True
if title_info.updated_at < (datetime.now() - datetime.timedelta(days=7)): if time.time() - title_info.updated_at > 60:
return True return True
async def add(self, page_id: int, title: Optional[str] = None): async def add(self, page_id: int, title: Optional[str] = None):
obj = PageTitleModel(page_id=page_id, title=title, updated_at=sqlalchemy.func.current_timestamp()) obj = PageTitleModel(page_id=page_id, title=title, updated_at=int(time.time()))
self.session.add(obj) self.session.add(obj)
await self.session.commit() await self.session.commit()
@ -71,14 +70,14 @@ class PageTitleHelper:
async def set_title(self, page_id: int, title: Optional[str] = None): async def set_title(self, page_id: int, title: Optional[str] = None):
stmt = update(PageTitleModel).where( stmt = update(PageTitleModel).where(
PageTitleModel.page_id == page_id).values(title=title, updated_at=sqlalchemy.func.current_timestamp()) PageTitleModel.page_id == page_id).values(title=title, updated_at=int(time.time()))
await self.session.execute(stmt) await self.session.execute(stmt)
await self.session.commit() await self.session.commit()
async def update(self, obj: PageTitleModel, ignore_updated_at: bool = False): async def update(self, obj: PageTitleModel, ignore_updated_at: bool = False):
self.session.merge(obj) await self.session.merge(obj)
if not ignore_updated_at: if not ignore_updated_at:
obj.updated_at = sqlalchemy.func.current_timestamp() obj.updated_at = int(time.time())
await self.session.commit() await self.session.commit()
await self.session.refresh(obj) await self.session.refresh(obj)
return obj return obj

@ -18,17 +18,19 @@ def init(app: web.Application):
web.route('*', '/title/info', Index.update_title_info), web.route('*', '/title/info', Index.update_title_info),
web.route('*', '/user/info', Index.get_user_info), web.route('*', '/user/info', Index.get_user_info),
web.route('*', '/conversations', Index.get_conversation_list), web.route('*', '/conversation/list', Index.get_conversation_list),
web.route('*', '/conversation/info', Index.get_conversation_info), web.route('*', '/conversation/info', Index.get_conversation_info),
web.route('POST', '/conversation/remove', Index.remove_conversation), web.route('POST', '/conversation/remove', Index.remove_conversation),
web.route('DELETE', '/conversation/remove', Index.remove_conversation), web.route('DELETE', '/conversation/remove', Index.remove_conversation),
web.route('POST', '/conversation/set_pinned', Index.set_conversation_pinned), web.route('POST', '/conversation/set_pinned', Index.set_conversation_pinned),
web.route('POST', '/conversation/set_title', Index.set_conversation_title),
web.route('*', '/embedding_search/index_page', EmbeddingSearch.index_page), web.route('*', '/embedding_search/index_page', EmbeddingSearch.index_page),
web.route('*', '/embedding_search/search', EmbeddingSearch.search), web.route('*', '/embedding_search/search', EmbeddingSearch.search),
web.route('*', '/chatcomplete/conversation_chunks', ChatComplete.get_conversation_chunk_list), web.route('*', '/chatcomplete/conversation_chunk/list', ChatComplete.get_conversation_chunk_list),
web.route('*', '/chatcomplete/conversation_chunk/{id:^\d+}', ChatComplete.get_conversation_chunk), web.route('*', '/chatcomplete/conversation_chunk/info', ChatComplete.get_conversation_chunk),
web.route('*', '/chatcomplete/message', ChatComplete.start_chat_complete), web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete),
web.route('*', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream), web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream),
web.route('POST', '/chatcomplete/get_point_cost', ChatComplete.get_point_cost),
]) ])

@ -39,8 +39,6 @@ CHATCOMPLETE_OUTPUT_REPLACE = {
"人工智能程式": "虛擬人物程序", "人工智能程式": "虛擬人物程序",
} }
CHATCOMPLETE_DEFAULT_CONVERSATION_TITLE = "无标题"
CHATCOMPLETE_BOT_NAME = "寫作助手" CHATCOMPLETE_BOT_NAME = "寫作助手"
PROMPTS = { PROMPTS = {

@ -17,7 +17,7 @@ from api.model.embedding_search.title_index import TitleIndexModel as _
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
async def index(request: web.Request): async def index(request: web.Request):
return utils.web.api_response(1, data={"message": "Isekai toolkit API"}, request=request) return await utils.web.api_response(1, data={"message": "Isekai toolkit API"}, request=request)
async def init_mw_api(app: web.Application): async def init_mw_api(app: web.Application):
mw_api = MediaWikiApi.create() mw_api = MediaWikiApi.create()

@ -1,34 +1,98 @@
from __future__ import annotations from __future__ import annotations
from asyncio import AbstractEventLoop, Task from asyncio import AbstractEventLoop, Task
import asyncio import asyncio
import atexit
from functools import wraps from functools import wraps
import random
import sys import sys
import traceback import traceback
from typing import Callable, Coroutine from typing import Callable, Coroutine, Optional, TypedDict
# Record kept for each repeating timer registered on the pool.
TimerInfo = TypedDict(
    "TimerInfo",
    {
        # Key under which the record is stored in the pool's timer map.
        "id": int,
        # Called with no arguments each time the timer fires.
        "callback": Callable,
        # Added to the loop clock to compute the following due time.
        "interval": float,
        # Loop-clock time at which the timer is next due.
        "next_time": float,
    },
)
class NoAwaitPool: class NoAwaitPool:
def __init__(self, loop: AbstractEventLoop): def __init__(self, loop: AbstractEventLoop):
self.task_list: list[Task] = [] self.task_list: list[Task] = []
self.timer_map: dict[int, TimerInfo] = {}
self.loop = loop self.loop = loop
self.running = True self.running = True
self.should_refresh_task = False
self.next_timer_time: Optional[float] = None
self.on_error: list[Callable] = [] self.on_error: list[Callable] = []
self.gc_task = loop.create_task(self._run_gc()) self.gc_task = loop.create_task(self._run_gc())
self.timer_task = loop.create_task(self._run_timer())
atexit.register(self.end_task)
async def end(self): async def end(self):
print("Stopping NoAwait Tasks...") if self.running:
self.running = False print("Stopping NoAwait Tasks...")
for task in self.task_list: self.running = False
await self._finish_task(task) for task in self.task_list:
await self._finish_task(task)
await self.gc_task
await self.gc_task
await self.timer_task
def end_task(self):
    """Synchronous shutdown hook: run ``end()`` to completion on the pool's loop.

    Nothing happens when the pool is no longer running or when its event
    loop has already been closed.
    """
    if not self.running:
        return
    if self.loop.is_closed():
        return
    shutdown = self.end()
    self.loop.run_until_complete(shutdown)
async def _wrap_task(self, task: Task):
try:
await task
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
await handler_ret
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
finally:
self.should_refresh_task = True
def add_task(self, coroutine: Coroutine): def add_task(self, coroutine: Coroutine):
task = self.loop.create_task(coroutine) task = self.loop.create_task(coroutine)
self.task_list.append(task) self.task_list.append(task)
def add_timer(self, callback: Callable, interval: float) -> int:
    """Register ``callback`` as a repeating timer and return the new timer id.

    The id is a random integer not yet present in ``timer_map``. The first
    due time is the loop's current time plus ``interval``; the pool's
    ``next_timer_time`` is pulled forward when this timer is due earlier
    than anything already scheduled.
    """
    # Keep drawing until the id does not collide with an existing timer.
    while True:
        timer_id = random.randint(0, 1000000000)
        if timer_id not in self.timer_map:
            break

    first_due = self.loop.time() + interval
    self.timer_map[timer_id] = dict(
        id=timer_id,
        callback=callback,
        interval=interval,
        next_time=first_due,
    )

    pending = self.next_timer_time
    if pending is None or first_due < pending:
        self.next_timer_time = first_due

    return timer_id
def remove_timer(self, id: int):
    """Drop the timer registered under ``id``; ids that are not registered are ignored."""
    self.timer_map.pop(id, None)
def wrap(self, f): def wrap(self, f):
@wraps(f) @wraps(f)
def decorated_function(*args, **kwargs): def decorated_function(*args, **kwargs):
@ -47,8 +111,7 @@ class NoAwaitPool:
for handler in self.on_error: for handler in self.on_error:
try: try:
handler_ret = handler(e) handler_ret = handler(e)
if handler_ret is Coroutine: await handler_ret
await handler_ret
handled = True handled = True
except Exception as handler_err: except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr) print("Exception on error handler: " + str(handler_err), file=sys.stderr)
@ -57,16 +120,46 @@ class NoAwaitPool:
if not handled: if not handled:
print(e, file=sys.stderr) print(e, file=sys.stderr)
traceback.print_exc() traceback.print_exc()
async def _run_gc(self): async def _run_gc(self):
while self.running: while self.running:
should_remove = [] if self.should_refresh_task:
for task in self.task_list: should_remove = []
if task.done(): for task in self.task_list:
await self._finish_task(task) if task.done():
should_remove.append(task) await self._finish_task(task)
for task in should_remove: should_remove.append(task)
self.task_list.remove(task) for task in should_remove:
self.task_list.remove(task)
await asyncio.sleep(0.1)
async def _run_timer(self):
while self.running:
now = self.loop.time()
if self.next_timer_time is not None and now >= self.next_timer_time:
self.next_timer_time = None
for timer in self.timer_map.values():
if now >= timer["next_time"]:
timer["next_time"] = now + timer["interval"]
try:
result = timer["callback"]()
self.add_task(result)
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
self.add_task(handler_ret)
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
if self.next_timer_time is None or timer["next_time"] < self.next_timer_time:
self.next_timer_time = timer["next_time"]
await asyncio.sleep(0.1) await asyncio.sleep(0.1)

@ -0,0 +1,5 @@
transformers
--index-url https://download.pytorch.org/whl/cpu
torch
torchvision
torchaudio

@ -0,0 +1,143 @@
from __future__ import annotations
import time
import config
import asyncio
import random
import threading
from typing import Optional, TypedDict
import torch
from transformers import pipeline
from local import loop
from service.tiktoken import TikTokenService
# Idle period after which the worker thread in BERTEmbeddingQueue.run() stops
# itself. It is compared against time.time() differences, i.e. seconds.
BERT_EMBEDDING_QUEUE_TIMEOUT = 1
class BERTEmbeddingQueueTaskInfo(TypedDict):
    """Bookkeeping record for one embedding request waiting in BERTEmbeddingQueue."""

    # Random key under which the record is stored in the queue's task map.
    task_id: int
    # Input text handed to the embedding model.
    text: str
    # Result slot: created as None and filled in by the worker thread.
    # NOTE(review): annotated as a tensor, but the value stored is whatever
    # the pipeline returns at index [0][1] — confirm the actual type.
    embedding: Optional[torch.Tensor]
class BERTEmbeddingQueue:
    """Runs embedding-model calls one at a time on a lazily started worker thread.

    Coroutines enqueue text through ``get_embeddings`` and poll for the
    result. The worker thread is started on demand and stops itself after
    ``BERT_EMBEDDING_QUEUE_TIMEOUT`` seconds without work. All queue state,
    including the decision to start or stop the worker, is guarded by
    ``self.lock`` so a task enqueued while the worker is shutting down is
    never left without a thread to process it.
    """

    def init(self):
        """Build the feature-extraction pipeline and reset the queue state.

        This is a regular method rather than ``__init__``, so construction
        and the pipeline load can be triggered separately.
        """
        self.embedding_model = pipeline("feature-extraction", model="bert-base-chinese")
        self.task_map: dict[int, BERTEmbeddingQueueTaskInfo] = {}
        self.task_list: list[BERTEmbeddingQueueTaskInfo] = []
        self.lock = threading.Lock()
        self.thread: Optional[threading.Thread] = None
        self.running = False

    async def get_embeddings(self, text: str):
        """Queue ``text`` for embedding and wait (polling every 10 ms) for the result.

        Returns the value the worker stored for this task. If the model call
        failed in the worker thread, the exception it raised is re-raised
        here instead of leaving the caller polling forever.
        """
        text = "[CLS]" + text + "[SEP]"
        with self.lock:
            # Draw random ids until one is not already in use.
            task_id = random.randint(0, 1000000000)
            while task_id in self.task_map:
                task_id = random.randint(0, 1000000000)

            task_info = {
                "task_id": task_id,
                "text": text,
                "embedding": None
            }
            self.task_map[task_id] = task_info
            self.task_list.append(task_info)

        self.start_queue()
        while True:
            task_info = self.pop_task(task_id)
            if task_info is not None:
                error = task_info.get("error")
                if error is not None:
                    raise error
                return task_info["embedding"]

            await asyncio.sleep(0.01)

    def pop_task(self, task_id):
        """Remove and return the task record once it is finished, else return None.

        A task counts as finished when its ``embedding`` has been filled in
        or when the worker recorded an ``error`` for it.
        """
        with self.lock:
            task_info = self.task_map.get(task_id)
            if task_info is None:
                return None
            if task_info["embedding"] is None and task_info.get("error") is None:
                return None
            del self.task_map[task_id]
            return task_info

    def run(self):
        """Worker loop: process queued tasks until idle for the timeout or stopped."""
        idle_since = time.time()
        while True:
            task = None
            with self.lock:
                if not self.running:
                    return
                if self.task_list:
                    task = self.task_list.pop(0)
                elif time.time() > idle_since + BERT_EMBEDDING_QUEUE_TIMEOUT:
                    # Shut down while holding the lock, so an enqueue either
                    # happens before this check (and is processed) or after
                    # it (and start_queue launches a fresh worker).
                    self.thread = None
                    self.running = False
                    return

            if task is None:
                time.sleep(0.01)
                continue

            result = None
            error = None
            try:
                embeddings = self.embedding_model(task["text"])
                result = embeddings[0][1]
            except Exception as exc:
                # Keep the worker alive and hand the failure to the waiter.
                error = exc

            with self.lock:
                task["embedding"] = result
                if error is not None:
                    task["error"] = error
            idle_since = time.time()

    def start_queue(self):
        """Start the worker thread unless one is already marked as running."""
        with self.lock:
            if not self.running:
                self.running = True
                self.thread = threading.Thread(target=self.run)
                self.thread.start()
# Module-level queue, created and initialised when this module is imported
# (init() builds the feature-extraction pipeline).
# NOTE(review): BERTEmbeddingService.init() constructs its own
# BERTEmbeddingQueue instead of using this one — confirm this instance is
# still needed.
bert_embedding_queue = BERTEmbeddingQueue()
bert_embedding_queue.init()
class BERTEmbeddingService:
    """Async facade that embeds documents through a BERTEmbeddingQueue.

    Use ``await BERTEmbeddingService.create()`` to obtain the shared,
    initialised instance.
    """

    # Shared instance, set by create() once initialisation has finished.
    instance = None

    @staticmethod
    async def create() -> "BERTEmbeddingService":
        """Return the shared service, creating and initialising it on first use."""
        if BERTEmbeddingService.instance is None:
            # Publish only after init() completes, so a concurrent caller
            # can never receive a half-initialised service.
            service = BERTEmbeddingService()
            await service.init()
            BERTEmbeddingService.instance = service

        return BERTEmbeddingService.instance

    async def init(self):
        """Create the token counter and the embedding queue.

        The queue's ``init`` is run in the loop's default executor.
        """
        self.tiktoken = await TikTokenService.create()
        self.embedding_queue = BERTEmbeddingQueue()
        await loop.run_in_executor(None, self.embedding_queue.init)

    async def get_embeddings(self, docs, on_progress=None):
        """Embed each document and count its tokens.

        Args:
            docs: sequence of dicts with an ``"id"`` key and, optionally, a
                ``"text"`` key.
            on_progress: optional async callback awaited as
                ``on_progress(done, total)`` before the first document and
                after the last one.

        Returns:
            ``(embeddings, token_usage)``. ``embeddings`` has one dict per
            input document with ``id``, ``text``, ``embedding`` and
            ``tokens``; documents without text get ``text=None``,
            ``embedding=None`` and ``tokens=0``. ``token_usage`` is the sum
            of all token counts.
        """
        total = len(docs)
        if total == 0:
            return ([], 0)

        if on_progress is not None:
            await on_progress(0, total)

        embeddings = []
        token_usage = 0

        for doc in docs:
            if "text" in doc:
                tokens = await self.tiktoken.get_tokens(doc["text"])
                token_usage += tokens
                embeddings.append({
                    "id": doc["id"],
                    "text": doc["text"],
                    # The queue is the only embedding backend this service
                    # sets up in init().
                    "embedding": await self.embedding_queue.get_embeddings(doc["text"]),
                    "tokens": tokens
                })
            else:
                embeddings.append({
                    "id": doc["id"],
                    "text": None,
                    "embedding": None,
                    "tokens": 0
                })

        if on_progress is not None:
            await on_progress(total, total)

        return (embeddings, token_usage)

@ -1,12 +1,18 @@
from __future__ import annotations from __future__ import annotations
import time
import traceback import traceback
from typing import Optional, Tuple, TypedDict from typing import Optional, Tuple, TypedDict
from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationChunkModel
import sqlalchemy
from api.model.chat_complete.conversation import (
ConversationChunkHelper,
ConversationChunkModel,
)
import sys import sys
from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel
import config import config
import utils.config import utils.config, utils.web
from aiohttp import web from aiohttp import web
from api.model.embedding_search.title_collection import TitleCollectionModel from api.model.embedding_search.title_collection import TitleCollectionModel
@ -18,21 +24,21 @@ from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
class ChatCompleteServicePrepareResponse(TypedDict): class ChatCompleteServicePrepareResponse(TypedDict):
extract_doc: list extract_doc: list
question_tokens: int question_tokens: int
conversation_id: int
chunk_id: int
class ChatCompleteServiceResponse(TypedDict): class ChatCompleteServiceResponse(TypedDict):
message: str message: str
message_tokens: int message_tokens: int
total_tokens: int total_tokens: int
finish_reason: str finish_reason: str
conversation_id: int question_message_id: str
response_message_id: str
delta_data: dict delta_data: dict
class ChatCompleteService: class ChatCompleteService:
def __init__(self, dbs: DatabaseService, title: str): def __init__(self, dbs: DatabaseService, title: str):
self.dbs = dbs self.dbs = dbs
@ -58,6 +64,7 @@ class ChatCompleteService:
self.question = "" self.question = ""
self.question_tokens: Optional[int] = None self.question_tokens: Optional[int] = None
self.conversation_id: Optional[int] = None self.conversation_id: Optional[int] = None
self.conversation_start_time: Optional[int] = None
self.delta_data = {} self.delta_data = {}
@ -81,22 +88,31 @@ class ChatCompleteService:
async def get_question_tokens(self, question: str): async def get_question_tokens(self, question: str):
return await self.tiktoken.get_tokens(question) return await self.tiktoken.get_tokens(question)
async def prepare_chat_complete(self, question: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None, async def prepare_chat_complete(
question_tokens: Optional[int] = None, self,
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse: question: str,
conversation_id: Optional[str] = None,
user_id: Optional[int] = None,
question_tokens: Optional[int] = None,
edit_message_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None,
) -> ChatCompleteServicePrepareResponse:
if user_id is not None: if user_id is not None:
user_id = int(user_id) user_id = int(user_id)
self.user_id = user_id self.user_id = user_id
self.question = question self.question = question
self.conversation_start_time = int(time.time())
self.conversation_info = None self.conversation_info = None
if conversation_id is not None: if conversation_id is not None:
self.conversation_id = int(conversation_id) self.conversation_id = int(conversation_id)
self.conversation_info = await self.conversation_helper.find_by_id(self.conversation_id) self.conversation_info = await self.conversation_helper.find_by_id(
self.conversation_id
)
else: else:
self.conversation_id = None self.conversation_id = None
if self.conversation_info is not None: if self.conversation_info is not None:
if self.conversation_info.user_id != user_id: if self.conversation_info.user_id != user_id:
raise web.HTTPUnauthorized() raise web.HTTPUnauthorized()
@ -106,97 +122,201 @@ class ChatCompleteService:
else: else:
self.question_tokens = question_tokens self.question_tokens = question_tokens
if (len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS and if (
self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS): len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS
and self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS
):
# If the question is too long, we need to truncate it # If the question is too long, we need to truncate it
raise web.HTTPRequestEntityTooLarge() raise web.HTTPRequestEntityTooLarge()
self.conversation_chunk = None
if self.conversation_info is not None:
chunk_id_list = await self.conversation_chunk_helper.get_chunk_id_list(self.conversation_id)
if edit_message_id and "," in edit_message_id:
(edit_chunk_id, edit_msg_id) = edit_message_id.split(",")
edit_chunk_id = int(edit_chunk_id)
# Remove overrided conversation chunks
start_overrided = False
should_remove_chunk_ids = []
for chunk_id in chunk_id_list:
if start_overrided:
should_remove_chunk_ids.append(chunk_id)
else:
if chunk_id == edit_chunk_id:
start_overrided = True
if len(should_remove_chunk_ids) > 0:
await self.conversation_chunk_helper.remove(should_remove_chunk_ids)
self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(
self.conversation_id
)
# Remove outdated message
edit_message_pos = None
old_tokens = 0
for i in range(0, len(self.conversation_chunk.message_data)):
msg_data = self.conversation_chunk.message_data[i]
if msg_data["id"] == edit_msg_id:
edit_message_pos = i
break
if "tokens" in msg_data and msg_data["tokens"]:
old_tokens += msg_data["tokens"]
if edit_message_pos:
self.conversation_chunk.message_data = self.conversation_chunk.message_data[0:edit_message_pos]
flag_modified(self.conversation_chunk, "message_data")
self.conversation_chunk.tokens = old_tokens
else:
self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(
self.conversation_id
)
# If the conversation is too long, we need to make a summary
if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS:
summary, tokens = await self.make_summary(
self.conversation_chunk.message_data
)
new_message_log = [
{
"role": "summary",
"content": summary,
"tokens": tokens,
"time": int(time.time()),
}
]
self.conversation_chunk = ConversationChunkModel(
conversation_id=self.conversation_id,
message_data=new_message_log,
tokens=tokens,
)
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_chunk)
else:
# 创建新对话
title_info = self.embedding_search.title_info
self.conversation_info = ConversationModel(
user_id=self.user_id,
module="chatcomplete",
page_id=title_info["page_id"],
rev_id=title_info["rev_id"],
)
self.conversation_info = await self.conversation_helper.add(
self.conversation_info,
)
self.conversation_chunk = ConversationChunkModel(
conversation_id=self.conversation_info.id,
message_data=[],
tokens=0,
)
self.conversation_chunk = await self.conversation_chunk_helper.add(
self.conversation_chunk
)
# Extract document from wiki page index # Extract document from wiki page index
self.extract_doc = None self.extract_doc = None
if embedding_search is not None: if embedding_search is not None:
self.extract_doc, token_usage = await self.embedding_search.search(question, **embedding_search) self.extract_doc, token_usage = await self.embedding_search.search(
question, **embedding_search
)
if self.extract_doc is not None: if self.extract_doc is not None:
self.question_tokens += token_usage self.question_tokens += token_usage
return ChatCompleteServicePrepareResponse( return ChatCompleteServicePrepareResponse(
extract_doc=self.extract_doc, extract_doc=self.extract_doc,
question_tokens=self.question_tokens question_tokens=self.question_tokens,
conversation_id=self.conversation_info.id,
chunk_id=self.conversation_chunk.id
) )
async def finish_chat_complete(self, on_message: Optional[callable] = None) -> ChatCompleteServiceResponse: async def finish_chat_complete(
self, on_message: Optional[callable] = None
) -> ChatCompleteServiceResponse:
delta_data = {} delta_data = {}
self.conversation_chunk = None
message_log = [] message_log = []
if self.conversation_info is not None: if self.conversation_chunk is not None:
self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(self.conversation_id)
# If the conversation is too long, we need to make a summary
if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS:
summary, tokens = await self.make_summary(self.conversation_chunk.message_data)
new_message_log = [
{"role": "summary", "content": summary, "tokens": tokens}
]
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_id, new_message_log, tokens)
self.delta_data["conversation_chunk_id"] = self.conversation_chunk.id
message_log = []
for message in self.conversation_chunk.message_data: for message in self.conversation_chunk.message_data:
message_log.append({ message_log.append(
"role": message["role"], {
"content": message["content"], "role": message["role"],
}) "content": message["content"],
}
)
if self.extract_doc is not None: if self.extract_doc is not None:
doc_prompt_content = "\n".join(["%d. %s" % ( doc_prompt_content = "\n".join(
i + 1, doc["markdown"] or doc["text"]) for i, doc in enumerate(self.extract_doc)]) [
"%d. %s" % (i + 1, doc["markdown"] or doc["text"])
for i, doc in enumerate(self.extract_doc)
]
)
doc_prompt = utils.config.get_prompt("extracted_doc", "prompt", { doc_prompt = utils.config.get_prompt(
"content": doc_prompt_content}) "extracted_doc", "prompt", {"content": doc_prompt_content}
)
message_log.append({"role": "user", "content": doc_prompt}) message_log.append({"role": "user", "content": doc_prompt})
system_prompt = utils.config.get_prompt("chat", "system_prompt") system_prompt = utils.config.get_prompt("chat", "system_prompt")
# Start chat complete # Start chat complete
if on_message is not None: if on_message is not None:
response = await self.openai_api.chat_complete_stream(self.question, system_prompt, message_log, on_message) response = await self.openai_api.chat_complete_stream(
self.question, system_prompt, message_log, on_message
)
else: else:
response = await self.openai_api.chat_complete(self.question, system_prompt, message_log) response = await self.openai_api.chat_complete(
self.question, system_prompt, message_log
if self.conversation_info is None: )
# Create a new conversation
message_log_list = [ description = response["message"][0:150]
{"role": "user", "content": self.question, "tokens": self.question_tokens},
{"role": "assistant", question_msg_id = utils.web.generate_uuid()
"content": response["message"], "tokens": response["message_tokens"]}, response_msg_id = utils.web.generate_uuid()
]
title = None new_message_data = [
try: {
title, token_usage = await self.make_title(message_log_list) "id": question_msg_id,
delta_data["title"] = title "role": "user",
except Exception as e: "content": self.question,
title = config.CHATCOMPLETE_DEFAULT_CONVERSATION_TITLE "tokens": self.question_tokens,
print(str(e), file=sys.stderr) "time": self.conversation_start_time,
traceback.print_exc(file=sys.stderr) },
{
"id": response_msg_id,
"role": "assistant",
"content": response["message"],
"tokens": response["message_tokens"],
"time": int(time.time()),
},
]
if self.conversation_info is not None:
total_token_usage = self.question_tokens + response["message_tokens"] total_token_usage = self.question_tokens + response["message_tokens"]
# Generate title if not exists
title_info = self.embedding_search.title_info if self.conversation_info.title is None:
self.conversation_info = await self.conversation_helper.add(self.user_id, "chatcomplete", page_id=title_info["page_id"], rev_id=title_info["rev_id"], title=title) title = None
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_info.id, message_log_list, total_token_usage) try:
else: title, token_usage = await self.make_title(new_message_data)
# Update the conversation chunk delta_data["title"] = title
await self.conversation_helper.refresh_updated_at(self.conversation_id) except Exception as e:
print(str(e), file=sys.stderr)
self.conversation_chunk.message_data.append( traceback.print_exc(file=sys.stderr)
{"role": "user", "content": self.question, "tokens": self.question_tokens})
self.conversation_chunk.message_data.append( self.conversation_info.title = title
{"role": "assistant", "content": response["message"], "tokens": response["message_tokens"]})
# Update conversation info
self.conversation_info.description = description
await self.conversation_helper.update(self.conversation_info)
# Update conversation chunk
self.conversation_chunk.message_data.extend(new_message_data)
flag_modified(self.conversation_chunk, "message_data") flag_modified(self.conversation_chunk, "message_data")
self.conversation_chunk.tokens += self.question_tokens + \ self.conversation_chunk.tokens += total_token_usage
response["message_tokens"]
await self.conversation_chunk_helper.update(self.conversation_chunk) await self.conversation_chunk_helper.update(self.conversation_chunk)
@ -205,17 +325,18 @@ class ChatCompleteService:
message_tokens=response["message_tokens"], message_tokens=response["message_tokens"],
total_tokens=response["total_tokens"], total_tokens=response["total_tokens"],
finish_reason=response["finish_reason"], finish_reason=response["finish_reason"],
conversation_id=self.conversation_info.id, question_message_id=question_msg_id,
delta_data=delta_data response_message_id=response_msg_id,
delta_data=delta_data,
) )
async def set_latest_point_cost(self, point_cost: int) -> bool: async def set_latest_point_cost(self, point_cost: int) -> bool:
if self.conversation_chunk is None: if self.conversation_chunk is None:
return False return False
if len(self.conversation_chunk.message_data) == 0: if len(self.conversation_chunk.message_data) == 0:
return False return False
for i in range(len(self.conversation_chunk.message_data) - 1, -1, -1): for i in range(len(self.conversation_chunk.message_data) - 1, -1, -1):
if self.conversation_chunk.message_data[i]["role"] == "assistant": if self.conversation_chunk.message_data[i]["role"] == "assistant":
self.conversation_chunk.message_data[i]["point_cost"] = point_cost self.conversation_chunk.message_data[i]["point_cost"] = point_cost
@ -224,44 +345,50 @@ class ChatCompleteService:
return True return True
async def make_summary(self, message_log_list: list) -> tuple[str, int]: async def make_summary(self, message_log_list: list) -> tuple[str, int]:
chat_log: list[str] = [] chat_log: list[str] = []
for message_data in message_log_list: for message_data in message_log_list:
if message_data["role"] == 'summary': if message_data["role"] == "summary":
chat_log.append(message_data["content"]) chat_log.append(message_data["content"])
elif message_data["role"] == 'assistant': elif message_data["role"] == "assistant":
chat_log.append( chat_log.append(
f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}') f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}'
)
else: else:
chat_log.append(f'User: {message_data["content"]}') chat_log.append(f'User: {message_data["content"]}')
chat_log_str = '\n'.join(chat_log) chat_log_str = "\n".join(chat_log)
summary_system_prompt = utils.config.get_prompt( summary_system_prompt = utils.config.get_prompt("summary", "system_prompt")
"summary", "system_prompt")
summary_prompt = utils.config.get_prompt( summary_prompt = utils.config.get_prompt(
"summary", "prompt", {"content": chat_log_str}) "summary", "prompt", {"content": chat_log_str}
)
response = await self.openai_api.chat_complete(summary_prompt, summary_system_prompt) response = await self.openai_api.chat_complete(
summary_prompt, summary_system_prompt
)
return response["message"], response["message_tokens"] return response["message"], response["message_tokens"]
async def make_title(self, message_log_list: list) -> tuple[str, int]: async def make_title(self, message_log_list: list) -> tuple[str, int]:
chat_log: list[str] = [] chat_log: list[str] = []
for message_data in message_log_list: for message_data in message_log_list:
if message_data["role"] == 'assistant': if message_data["role"] == "assistant":
chat_log.append( chat_log.append(
f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}') f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}'
elif message_data["role"] == 'user': )
elif message_data["role"] == "user":
chat_log.append(f'User: {message_data["content"]}') chat_log.append(f'User: {message_data["content"]}')
chat_log_str = '\n'.join(chat_log) chat_log_str = "\n".join(chat_log)
title_system_prompt = utils.config.get_prompt("title", "system_prompt") title_system_prompt = utils.config.get_prompt("title", "system_prompt")
title_prompt = utils.config.get_prompt( title_prompt = utils.config.get_prompt(
"title", "prompt", {"content": chat_log_str}) "title", "prompt", {"content": chat_log_str}
)
response = await self.openai_api.chat_complete(title_prompt, title_system_prompt) response = await self.openai_api.chat_complete(
title_prompt, title_system_prompt
)
return response["message"], response["message_tokens"] return response["message"], response["message_tokens"]

@ -9,12 +9,17 @@ from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
from utils.wiki import getWikiSentences from utils.wiki import getWikiSentences
class EmbeddingRunningException(Exception):
pass
class EmbeddingSearchArgs(TypedDict): class EmbeddingSearchArgs(TypedDict):
limit: Optional[int] limit: Optional[int]
in_collection: Optional[bool] in_collection: Optional[bool]
distance_limit: Optional[float] distance_limit: Optional[float]
class EmbeddingSearchService: class EmbeddingSearchService:
indexing_page_ids: list[int] = []
def __init__(self, dbs: DatabaseService, title: str): def __init__(self, dbs: DatabaseService, title: str):
self.dbs = dbs self.dbs = dbs
@ -92,6 +97,9 @@ class EmbeddingSearchService:
self.page_id = self.page_info["pageid"] self.page_id = self.page_info["pageid"]
if self.page_id in self.indexing_page_ids:
raise EmbeddingRunningException("Page index is running now")
# Create collection # Create collection
self.collection_info = await self.title_collection.find_by_title(self.base_title) self.collection_info = await self.title_collection.find_by_title(self.base_title)
if self.collection_info is None: if self.collection_info is None:
@ -129,7 +137,6 @@ class EmbeddingSearchService:
if self.unindexed_docs is None: if self.unindexed_docs is None:
return False return False
chunk_limit = 500 chunk_limit = 500
chunk_len = 0 chunk_len = 0

@ -1,3 +1,4 @@
from __future__ import annotations
import json import json
import sys import sys
import time import time
@ -15,24 +16,49 @@ class MediaWikiApiException(Exception):
def __str__(self) -> str: def __str__(self) -> str:
return self.info return self.info
class MediaWikiPageNotFoundException(MediaWikiApiException): class MediaWikiPageNotFoundException(Exception):
pass def __init__(self, info: str, code: Optional[str] = None) -> None:
super().__init__(info)
self.info = info
self.code = code
self.message = self.info
def __str__(self) -> str:
return self.info
class MediaWikiUserNoEnoughPointsException(Exception):
def __init__(self, info: str, code: Optional[str] = None) -> None:
super().__init__(info)
self.info = info
self.code = code
self.message = self.info
def __str__(self) -> str:
return self.info
class ChatCompleteGetPointUsageResponse(TypedDict):
point_cost: int
class ChatCompleteReportUsageResponse(TypedDict): class ChatCompleteReportUsageResponse(TypedDict):
point_cost: int point_cost: int
transaction_id: str transaction_id: str
class MediaWikiApi: class MediaWikiApi:
cookie_jar = aiohttp.CookieJar(unsafe=True) instance: MediaWikiApi = None
@staticmethod @staticmethod
def create(): def create():
return MediaWikiApi(config.MW_API) if MediaWikiApi.instance is None:
MediaWikiApi.instance = MediaWikiApi(config.MW_API)
return MediaWikiApi.instance
def __init__(self, api_url: str): def __init__(self, api_url: str):
self.api_url = api_url self.api_url = api_url
self.login_time = 0.0
self.cookie_jar = aiohttp.CookieJar(unsafe=True)
self.login_identity = None self.login_identity = None
self.login_time = 0.0
async def get_page_info(self, title: str): async def get_page_info(self, title: str):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
@ -50,7 +76,7 @@ class MediaWikiApi:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
if "missing" in data["query"]["pages"][0]: if "missing" in data["query"]["pages"][0]:
raise MediaWikiPageNotFoundException() raise MediaWikiPageNotFoundException(data["error"]["info"], data["error"]["code"])
return data["query"]["pages"][0] return data["query"]["pages"][0]
@ -98,7 +124,21 @@ class MediaWikiApi:
ret["user"] = data["query"]["userinfo"]["name"] ret["user"] = data["query"]["userinfo"]["name"]
return ret return ret
async def is_logged_in(self,):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = {
"action": "query",
"format": "json",
"formatversion": "2",
"meta": "userinfo"
}
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
return data["query"]["userinfo"]["id"] != 0
async def get_token(self, token_type: str): async def get_token(self, token_type: str):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
@ -145,8 +185,10 @@ class MediaWikiApi:
async def refresh_login(self): async def refresh_login(self):
if self.login_identity is None: if self.login_identity is None:
print("刷新MW机器人账号登录状态失败没有保存的用户")
return False return False
if time.time() - self.login_time > 30: if time.time() - self.login_time > 3600:
print("刷新MW机器人账号登录状态")
return await self.robot_login(self.login_identity["username"], self.login_identity["password"]) return await self.robot_login(self.login_identity["username"], self.login_identity["password"])
async def chat_complete_user_info(self, user_id: int): async def chat_complete_user_info(self, user_id: int):
@ -169,6 +211,32 @@ class MediaWikiApi:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
return data["chatcompletebot"]["userinfo"] return data["chatcompletebot"]["userinfo"]
async def chat_complete_get_point_cost(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteGetPointUsageResponse:
await self.refresh_login()
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
post_data = {
"action": "chatcompletebot",
"method": "reportusage",
"step": "check",
"userid": int(user_id) if user_id is not None else None,
"useraction": user_action,
"tokens": int(tokens) if tokens is not None else None,
"extractlines": int(extractlines) if extractlines is not None else None,
"format": "json",
"formatversion": "2",
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
data = await resp.json()
if "error" in data:
print(data)
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
point_cost = int(data["chatcompletebot"]["reportusage"]["pointcost"] or 0)
return ChatCompleteGetPointUsageResponse(point_cost=point_cost)
async def chat_complete_start_transaction(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteReportUsageResponse: async def chat_complete_start_transaction(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteReportUsageResponse:
await self.refresh_login() await self.refresh_login()
@ -178,10 +246,10 @@ class MediaWikiApi:
"action": "chatcompletebot", "action": "chatcompletebot",
"method": "reportusage", "method": "reportusage",
"step": "start", "step": "start",
"userid": int(user_id), "userid": int(user_id) if user_id is not None else None,
"useraction": user_action, "useraction": user_action,
"tokens": int(tokens), "tokens": int(tokens) if tokens is not None else None,
"extractlines": int(extractlines), "extractlines": int(extractlines) if extractlines is not None else None,
"format": "json", "format": "json",
"formatversion": "2", "formatversion": "2",
} }
@ -190,10 +258,13 @@ class MediaWikiApi:
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp: async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
data = await resp.json() data = await resp.json()
if "error" in data: if "error" in data:
print(data) if data["error"]["code"] == "noenoughpoints":
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) raise MediaWikiUserNoEnoughPointsException(data["error"]["info"], data["error"]["info"])
else:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
return ChatCompleteReportUsageResponse(point_cost=int(data["chatcompletebot"]["reportusage"]["pointcost"]), point_cost = int(data["chatcompletebot"]["reportusage"]["pointcost"] or 0)
return ChatCompleteReportUsageResponse(point_cost=point_cost,
transaction_id=data["chatcompletebot"]["reportusage"]["transactionid"]) transaction_id=data["chatcompletebot"]["reportusage"]["transactionid"])
async def chat_complete_end_transaction(self, transaction_id: str, tokens: Optional[int] = None): async def chat_complete_end_transaction(self, transaction_id: str, tokens: Optional[int] = None):
@ -237,7 +308,7 @@ class MediaWikiApi:
} }
# Filter out None values # Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None} post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.get(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp: async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
data = await resp.json() data = await resp.json()
if "error" in data: if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])

@ -1,4 +1,5 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import json import json
from typing import Callable, Optional, TypedDict from typing import Callable, Optional, TypedDict
@ -36,20 +37,20 @@ class OpenAIApi:
def build_header(self): def build_header(self):
if config.OPENAI_API_TYPE == "azure": if config.OPENAI_API_TYPE == "azure":
return { return {
"Content-Type": "application/json", "content-type": "application/json",
"Accept": "application/json", "accept": "application/json",
"api-key": self.api_key "api-key": self.api_key
} }
else: else:
return { return {
"Authorization": f"Bearer {self.api_key}", "authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json", "content-type": "application/json",
"Accept": "application/json", "accept": "application/json",
} }
def get_url(self, method: str): def get_url(self, method: str):
if config.OPENAI_API_TYPE == "azure": if config.OPENAI_API_TYPE == "azure":
if method == "completions": if method == "chat/completions":
return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME + "/" + method return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME + "/" + method
elif method == "embeddings": elif method == "embeddings":
return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME + "/" + method return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME + "/" + method
@ -92,26 +93,41 @@ class OpenAIApi:
if config.OPENAI_API_TYPE == "azure": if config.OPENAI_API_TYPE == "azure":
# Azure api does not support batch # Azure api does not support batch
for index, text in enumerate(text_list): for index, text in enumerate(text_list):
async with session.post(url, retry_num = 0
headers=self.build_header(), max_retry_num = 3
params=params, while retry_num < max_retry_num:
json={"input": text}, try:
timeout=30, async with session.post(url,
proxy=config.REQUEST_PROXY) as resp: headers=self.build_header(),
params=params,
data = await resp.json() json={"input": text},
timeout=30,
one_data = data["data"] proxy=config.REQUEST_PROXY) as resp:
if len(one_data) > 0:
embedding = one_data[0]["embedding"] data = await resp.json()
if embedding is not None:
embedding = np.array(embedding) one_data = data["data"]
doc_list[index]["embedding"] = embedding if len(one_data) > 0:
embedding = one_data[0]["embedding"]
token_usage += int(data["usage"]["total_tokens"]) if embedding is not None:
embedding = np.array(embedding)
if on_index_progress is not None: doc_list[index]["embedding"] = embedding
await on_index_progress(index, len(text_list))
token_usage += int(data["usage"]["total_tokens"])
if on_index_progress is not None:
await on_index_progress(index, len(text_list))
break
except Exception as e:
retry_num += 1
if retry_num >= max_retry_num:
raise e
print("Error: %s" % e)
print("Retrying...")
await asyncio.sleep(0.5)
else: else:
async with session.post(url, async with session.post(url,
headers=self.build_header(), headers=self.build_header(),
@ -158,7 +174,7 @@ class OpenAIApi:
async def chat_complete(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], user = None): async def chat_complete(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], user = None):
messageList = await self.make_message_list(question, system_prompt, conversation) messageList = await self.make_message_list(question, system_prompt, conversation)
url = self.get_url("completions") url = self.get_url("chat/completions")
params = {} params = {}
post_data = { post_data = {
@ -175,7 +191,7 @@ class OpenAIApi:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(url, async with session.post(url,
headers=self.build_header, headers=self.build_header(),
params=params, params=params,
json=post_data, json=post_data,
timeout=30, timeout=30,
@ -210,24 +226,38 @@ class OpenAIApi:
for message in messageList: for message in messageList:
prompt_tokens += await tiktoken.get_tokens(message["content"]) prompt_tokens += await tiktoken.get_tokens(message["content"])
params = { url = self.get_url("chat/completions")
"model": "gpt-3.5-turbo",
params = {}
post_data = {
"messages": messageList, "messages": messageList,
"stream": True,
"user": user, "user": user,
"stream": True,
"n": 1,
"max_tokens": 768,
"stop": None,
"temperature": 1,
"top_p": 0.95
} }
params = {k: v for k, v in params.items() if v is not None}
if config.OPENAI_API_TYPE == "azure":
params["api-version"] = "2023-05-15"
else:
post_data["model"] = "gpt-3.5-turbo"
post_data = {k: v for k, v in post_data.items() if v is not None}
res_message: list[str] = [] res_message: list[str] = []
finish_reason = None finish_reason = None
async with sse_client.EventSource( async with sse_client.EventSource(
self.api_url + "/v1/chat/completions", url,
option={ option={
"method": "POST" "method": "POST"
}, },
headers={"Authorization": f"Bearer {self.api_key}"}, headers=self.build_header(),
json=params, params=params,
json=post_data,
proxy=config.REQUEST_PROXY proxy=config.REQUEST_PROXY
) as session: ) as session:
async for event in session: async for event in session:
@ -238,11 +268,13 @@ class OpenAIApi:
[DONE] [DONE]
""" """
content_started = False content_started = False
if event.data == "[DONE]": event_data = event.data.strip()
if event_data == "[DONE]":
break break
elif event.data[0] == "{" and event.data[-1] == "}": elif event_data[0] == "{" and event_data[-1] == "}":
data = json.loads(event.data) data = json.loads(event_data)
if "choices" in data and len(data["choices"]) > 0: if "choices" in data and len(data["choices"]) > 0:
choice = data["choices"][0] choice = data["choices"][0]
@ -262,11 +294,14 @@ class OpenAIApi:
res_message.append(delta_message) res_message.append(delta_message)
if config.DEBUG: # if config.DEBUG:
print(delta_message, end="", flush=True) # print(delta_message, end="", flush=True)
if on_message is not None: if on_message is not None:
await on_message(delta_message) await on_message(delta_message)
if finish_reason is not None:
break
res_message_str = "".join(res_message) res_message_str = "".join(res_message)
message_tokens = await tiktoken.get_tokens(res_message_str) message_tokens = await tiktoken.get_tokens(res_message_str)

@ -1,50 +0,0 @@
from __future__ import annotations
import asyncpg
class SimpleQueryBuilder:
def __init__(self):
self._table_name = ""
self._select = ["*"]
self._where = []
self._having = []
self._order_by = None
self._order_by_desc = False
def table(self, table_name: str):
self._table_name = table_name
return self
def fields(self, fields: list[str]):
self.select = fields
return self
def where(self, where: str, condition: str, param):
self._where.append((where, condition, param))
return self
def having(self, having: str, condition: str, param):
self._having.append((having, condition, param))
return self
def build(self):
sql = "SELECT %s FROM %s" % (", ".join(self._select), self._table_name)
params = []
paramsLen = 0
if len(self._where) > 0:
sql += " WHERE "
for where, condition, param in self._where:
params.append(param)
paramsLen += 1
sql += "%s %s $%d AND " % (where, condition, paramsLen)
if self._order_by is not None:
sql += " ORDER BY %s %s" % (self._order_by, "DESC" if self._order_by_desc else "ASC")
if len(self._having) > 0:
sql += " HAVING "
for having, condition, param in self._having:
params.append(param)
paramsLen += 1
sql += "%s %s $%d AND " % (having, condition, paramsLen)

@ -0,0 +1,4 @@
import sys
import pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent))

@ -0,0 +1,27 @@
import asyncio
import time
import base
from local import loop, noawait
from service.bert_embedding import bert_embedding_queue
async def main():
embedding_list = []
start_time = time.time()
queue = []
with open("test/test.md", "r", encoding="utf-8") as fp:
text = fp.read()
lines = text.split("\n")
for line in lines:
line = line.strip()
if line == "":
continue
queue.append(bert_embedding_queue.get_embeddings(line))
embedding_list = await asyncio.gather(*queue)
end_time = time.time()
print("time cost: %.4f" % (end_time - start_time))
print("dimensions: %d" % len(embedding_list[0]))
await noawait.end()
if __name__ == '__main__':
loop.run_until_complete(main())

@ -0,0 +1,37 @@
import traceback
import base
import local
from service.chat_complete import ChatCompleteService
from service.database import DatabaseService
from service.tiktoken import TikTokenService
async def main():
dbs = await DatabaseService.create()
tiktoken = await TikTokenService.create()
async with ChatCompleteService(dbs, "代号:曙光的世界/黄昏的阿瓦隆") as chat_complete:
question = "你是谁?"
question_tokens = await tiktoken.get_tokens(question)
try:
prepare_res = await chat_complete.prepare_chat_complete(question, None, 1, question_tokens, {
"distance_limit": 0.6,
"limit": 10
})
print(prepare_res)
async def on_message(message: str):
# print(message)
pass
res = await chat_complete.finish_chat_complete(on_message)
print(res)
except Exception as err:
print(err)
traceback.print_exc()
await local.noawait.end()
if __name__ == '__main__':
local.loop.run_until_complete(main())

@ -1,37 +1,38 @@
import local import base
from service.database import DatabaseService import local
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService
from service.embedding_search import EmbeddingSearchService
async def main():
dbs = await DatabaseService.create() async def main():
dbs = await DatabaseService.create()
async with EmbeddingSearchService(dbs, "代号:曙光的世界/黄昏的阿瓦隆") as embedding_search:
await embedding_search.prepare_update_index() async with EmbeddingSearchService(dbs, "代号:曙光的世界/黄昏的阿瓦隆") as embedding_search:
await embedding_search.prepare_update_index()
async def on_index_progress(current, length):
print("\r索引进度:%.1f%%" % (current / length * 100), end="", flush=True) async def on_index_progress(current, length):
print("\r索引进度:%.1f%%" % (current / length * 100), end="", flush=True)
print("")
await embedding_search.update_page_index(on_index_progress) print("")
print("") await embedding_search.update_page_index(on_index_progress)
print("")
while True:
query = input("请输入要搜索的问题 (.exit 退出)") while True:
if query == ".exit": query = input("请输入要搜索的问题 (.exit 退出)")
break if query == ".exit":
res, token_usage = await embedding_search.search(query, 5) break
total_length = 0 res, token_usage = await embedding_search.search(query, 5)
if res: total_length = 0
for one in res: if res:
total_length += len(one["markdown"]) for one in res:
print("%s, distance=%.4f" % (one["markdown"], one["distance"])) total_length += len(one["markdown"])
else: print("%s, distance=%.4f" % (one["markdown"], one["distance"]))
print("未搜索到相关内容") else:
print("未搜索到相关内容")
print("总长度:%d" % total_length)
print("总长度:%d" % total_length)
await local.noawait.end()
await local.noawait.end()
if __name__ == '__main__':
if __name__ == '__main__':
local.loop.run_until_complete(main()) local.loop.run_until_complete(main())

@ -0,0 +1,20 @@
import asyncio
import base
from local import loop, noawait
async def test_timer1():
print("timer1")
async def test_timer2():
print("timer2")
async def main():
timer_id = noawait.add_timer(test_timer1, 1)
timer_id = noawait.add_timer(test_timer2, 2)
print("Timer id: %d" % timer_id)
while True:
await asyncio.sleep(1)
await noawait.end()
if __name__ == '__main__':
loop.run_until_complete(main())

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from functools import wraps from functools import wraps
import json
from typing import Any, Optional, Dict from typing import Any, Optional, Dict
from aiohttp import web from aiohttp import web
import jwt import jwt
@ -8,22 +9,32 @@ import uuid
ParamRule = Dict[str, Any] ParamRule = Dict[str, Any]
class ParamInvalidException(Exception): class ParamInvalidException(web.HTTPBadRequest):
def __init__(self, param_list: list[str], rules: dict[str, ParamRule]): def __init__(self, param_list: list[str], rules: dict[str, ParamRule]):
self.code = "param_invalid" self.code = "param_invalid"
self.param_list = param_list self.param_list = param_list
self.rules = rules self.rules = rules
param_list_str = "'" + ("', '".join(param_list)) + "'" param_list_str = "'" + ("', '".join(param_list)) + "'"
super().__init__(f"Param invalid: {param_list_str}") super().__init__(f"Param invalid: {param_list_str}",
content_type="application/json",
body=json.dumps({
"status": -1,
"error": {
"code": self.code,
"message": f"Param invalid: {param_list_str}"
}
}))
async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None): async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None):
params: dict[str, Any] = {} params: dict[str, Any] = {}
for key, value in request.match_info.items():
params[key] = value
for key, value in request.query.items(): for key, value in request.query.items():
params[key] = value params[key] = value
if request.method == 'POST': if request.method == 'POST':
if request.headers.get('content-type') == 'application/json': if request.headers.get('content-type') == 'application/json':
data = await request.json() data = await request.json()
if data is not None and data is dict: if data is not None and isinstance(data, dict):
for key, value in data.items(): for key, value in data.items():
params[key] = value params[key] = value
else: else:
@ -34,7 +45,7 @@ async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]]
if rules is not None: if rules is not None:
invalid_params: list[str] = [] invalid_params: list[str] = []
for key, rule in rules.items(): for key, rule in rules.items():
if "required" in rule and rule["required"] and params[key] is None: if "required" in rule and rule["required"] and key not in params.keys():
invalid_params.append(key) invalid_params.append(key)
continue continue
@ -50,9 +61,20 @@ async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]]
elif rule["type"] == float: elif rule["type"] == float:
params[key] = float(params[key]) params[key] = float(params[key])
elif rule["type"] == bool: elif rule["type"] == bool:
val = params[key].lower() val = params[key]
if val == "false" or val == "0": if isinstance(val, bool):
params[key] = False params[key] = val
elif isinstance(val, str):
val = val.lower()
if val.lower() == "false" or val == "0":
params[key] = False
else:
params[key] = True
elif isinstance(val, int):
if val == 0:
params[key] = False
else:
params[key] = True
else: else:
params[key] = True params[key] = True
except ValueError: except ValueError:

Loading…
Cancel
Save