diff --git a/api/controller/ChatComplete.py b/api/controller/ChatComplete.py
index 5685c12..150ac1e 100644
--- a/api/controller/ChatComplete.py
+++ b/api/controller/ChatComplete.py
@@ -1,19 +1,19 @@
 from __future__ import annotations
 import asyncio
-import json
 import sys
 import time
 import traceback
+from api.model.toolkit_ui.conversation import ConversationHelper
 from local import noawait
 from typing import Optional, Callable, TypedDict
 from aiohttp import web
 from sqlalchemy import select
-from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
+from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
 from noawait import NoAwaitPool
-from service.chat_complete import ChatCompleteService
+from service.chat_complete import ChatCompleteService, ChatCompleteServiceResponse
 from service.database import DatabaseService
 from service.embedding_search import EmbeddingSearchArgs
-from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
+from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
 from service.tiktoken import TikTokenService
 import utils.web
 
@@ -23,6 +23,7 @@ class ChatCompleteTask:
     def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
         self.task_id = utils.web.generate_uuid()
         self.on_message: list[Callable] = []
+        self.on_finished: list[Callable] = []
         self.on_error: list[Callable] = []
 
         self.chunks: list[str] = []
@@ -37,7 +38,12 @@ class ChatCompleteTask:
         self.transatcion_id: Optional[str] = None
         self.point_cost: int = 0
 
-    async def init(self, question: str, conversation_id: Optional[str] = None,
+        self.is_finished = False
+        self.finished_time: Optional[float] = None
+        self.result: Optional[ChatCompleteServiceResponse] = None
+        self.error: Optional[Exception] = None
+
+    async def init(self, question: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None,
                    embedding_search: Optional[EmbeddingSearchArgs] = None):
 
         self.tiktoken = await TikTokenService.create()
@@ -54,34 +60,58 @@ class ChatCompleteTask:
         self.transatcion_id: Optional[str] = None
         self.point_cost: int = 0
 
         if not self.is_system:
-            usage_res = await self.mwapi.chat_complete_start_transaction(self.user_id, "chatcomplete", question_tokens, extract_limit)
-            self.transatcion_id = usage_res.get("transaction_id")
-            self.point_cost = usage_res.get("point_cost")
-
-            chat_res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
-                                                                      user_id=self.user_id, embedding_search=embedding_search)
+            usage_res = await self.mwapi.chat_complete_start_transaction(self.user_id, "chatcomplete",
                                                                          question_tokens, extract_limit)
+            self.transatcion_id = usage_res["transaction_id"]
+            self.point_cost = usage_res["point_cost"]
 
-            return chat_res
+            res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
+                                                                 user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
+
+            return res
         else:
             await self._exit()
             raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
 
     async def _on_message(self, delta_message: str):
+        self.chunks.append(delta_message)
+
         for callback in self.on_message:
-            await callback(delta_message)
+            try:
+                await callback(delta_message)
+            except Exception as e:
+                print("Error while processing on_message callback: %s" % e, file=sys.stderr)
+                traceback.print_exc()
+
+    async def _on_finished(self):
+        for callback in self.on_finished:
+            try:
+                await callback(self.result)
+            except Exception as e:
+                print("Error while processing on_finished callback: %s" % e, file=sys.stderr)
+                traceback.print_exc()
 
     async def _on_error(self, err: Exception):
+        self.error = err
         for callback in self.on_error:
-            await callback(err)
+            try:
+                await callback(err)
+            except Exception as e:
+                print("Error while processing on_error callback: %s" % e, file=sys.stderr)
+                traceback.print_exc()
 
     async def run(self):
         try:
-            chat_res = self.chat_complete.finish_chat_complete(self._on_message)
+            chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
             await self.chat_complete.set_latest_point_cost(self.point_cost)
 
+            self.result = chat_res
+
             if self.transatcion_id:
-                result = await self.mwapi.chat_complete_end_transaction(self.transatcion_id, chat_res["total_tokens"])
+                await self.mwapi.chat_complete_end_transaction(self.transatcion_id, chat_res["total_tokens"])
+
+            await self._on_finished()
 
         except Exception as e:
             err_msg = f"Error while processing chat complete request: {e}"
@@ -89,7 +119,7 @@ class ChatCompleteTask:
             traceback.print_exc()
 
             if self.transatcion_id:
-                result = await self.mwapi.chat_complete_cancel_transaction(self.transatcion_id, error=err_msg)
+                await self.mwapi.chat_complete_cancel_transaction(self.transatcion_id, error=err_msg)
 
             await self._on_error(e)
         finally:
@@ -98,10 +128,18 @@ class ChatCompleteTask:
     async def _exit(self):
         await self.chat_complete_service.__aexit__(None, None, None)
-        del chat_complete_tasks[self.task_id]
+        self.is_finished = True
+        self.finished_time = time.time()
 
-    @noawait.wrap
-    async def start(self):
-        await self.run()
+TASK_EXPIRE_TIME = 60 * 10
+
+async def chat_complete_task_gc():
+    now = time.time()
+    for task_id in list(chat_complete_tasks.keys()):
+        task = chat_complete_tasks[task_id]
+        if task.is_finished and task.finished_time is not None and now > task.finished_time + TASK_EXPIRE_TIME:
+            del chat_complete_tasks[task_id]
+
+noawait.add_timer(chat_complete_task_gc, 60)
 
 class ChatComplete:
     @staticmethod
@@ -112,7 +151,7 @@ class ChatComplete:
                 "required": False,
                 "type": int
             },
-            "conversation_id": {
+            "id": {
                 "required": True,
                 "type": int
             }
@@ -123,40 +162,31 @@ class ChatComplete:
         else:
             user_id = params.get("user_id")
 
-        conversation_id = params.get("conversation_id")
+        conversation_id = params.get("id")
 
         db = await DatabaseService.create(request.app)
-        async with db.create_session() as session:
-            stmt = select(ConversationModel).where(
-                ConversationModel.id == conversation_id)
-
-            conversation_data = await session.scalar(stmt)
+        async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
+            conversation_info = await conversation_helper.find_by_id(conversation_id)
 
-            if conversation_data is None:
+            if conversation_info is None:
                 return await utils.web.api_response(-1, error={
                     "code": "conversation-not-found",
                     "message": "Conversation not found."
}, http_status=404, request=request) - if conversation_data.user_id != user_id: + if conversation_info.user_id != user_id: return await utils.web.api_response(-1, error={ "code": "permission-denied", "message": "Permission denied." }, http_status=403, request=request) - stmt = select(ConversationChunkModel).with_only_columns([ConversationChunkModel.id, ConversationChunkModel.updated_at]) \ - .where(ConversationChunkModel.conversation_id == conversation_id).order_by(ConversationChunkModel.id.asc()) - - conversation_chunk_result = await session.scalars(stmt) + conversation_chunk_result = await conversation_chunk_helper.get_chunk_id_list(conversation_id) conversation_chunk_list = [] for result in conversation_chunk_result: - conversation_chunk_list.append({ - "id": result.id, - "updated_at": result.updated_at - }) + conversation_chunk_list.append(result) return await utils.web.api_response(1, conversation_chunk_list, request=request) @@ -181,26 +211,32 @@ class ChatComplete: chunk_id = params.get("chunk_id") - dbs = await DatabaseService.create(request.app) - async with dbs.create_session() as session: - stmt = select(ConversationChunkModel).where( - ConversationChunkModel.id == chunk_id) - - conversation_data = await session.scalar(stmt) - - if conversation_data is None: + db = await DatabaseService.create(request.app) + async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper: + chunk_info = await conversation_chunk_helper.find_by_id(chunk_id) + if chunk_info is None: return await utils.web.api_response(-1, error={ "code": "conversation-chunk-not-found", "message": "Conversation chunk not found." }, http_status=404, request=request) - if conversation_data.conversation.user_id != user_id: + conversation_info = await conversation_helper.find_by_id(chunk_info.conversation_id) + + if conversation_info is not None and conversation_info.user_id != user_id: return await utils.web.api_response(-1, error={ "code": "permission-denied", "message": "Permission denied." 
}, http_status=403, request=request) - return await utils.web.api_response(1, conversation_data.__dict__, request=request) + chunk_dict = { + "id": chunk_info.id, + "conversation_id": chunk_info.conversation_id, + "message_data": chunk_info.message_data, + "tokens": chunk_info.tokens, + "updated_at": chunk_info.updated_at, + } + + return await utils.web.api_response(1, chunk_dict, request=request) @staticmethod @utils.web.token_auth @@ -219,6 +255,47 @@ class ChatComplete: return await utils.web.api_response(1, {"tokens": tokens}, request=request) + @staticmethod + @utils.web.token_auth + async def get_point_cost(request: web.Request): + params = await utils.web.get_param(request, { + "question": { + "type": str, + "required": True, + }, + "extract_limit": { + "type": int, + "required": False, + "default": 10, + }, + }) + + user_id = request.get("user") + caller = request.get("caller") + + question = params.get("question") + extract_limit = params.get("extract_limit") + + tiktoken = await TikTokenService.create() + mwapi = MediaWikiApi.create() + + tokens = await tiktoken.get_tokens(question) + + try: + res = await mwapi.chat_complete_get_point_cost(user_id, "chatcomplete", tokens, extract_limit) + return await utils.web.api_response(1, { + "point_cost": res["point_cost"], + "tokens": tokens, + }, request=request) + except Exception as e: + err_msg = f"Error while get chat complete point cost: {e}" + traceback.print_exc() + + return await utils.web.api_response(-1, error={ + "code": "chat-complete-error", + "message": err_msg + }, http_status=500, request=request) + @staticmethod @utils.web.token_auth async def start_chat_complete(request: web.Request): @@ -245,6 +322,10 @@ class ChatComplete: "required": False, "default": False, }, + "edit_message_id": { + "type": str, + "required": False, + }, }) user_id = request.get("user") @@ -257,20 +338,25 @@ class ChatComplete: extract_limit = params.get("extract_limit") in_collection = params.get("in_collection") + edit_message_id = params.get("edit_message_id") + dbs = await DatabaseService.create(request.app) try: chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user") - init_res = await chat_complete_task.init(question, conversation_id=conversation_id, embedding_search={ + init_res = await chat_complete_task.init(question, conversation_id=conversation_id, edit_message_id=edit_message_id, + embedding_search={ "limit": extract_limit, "in_collection": in_collection, }) chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task - chat_complete_task.start() + noawait.add_task(chat_complete_task.run()) - return utils.web.api_response(1, data={ + return await utils.web.api_response(1, data={ + "conversation_id": init_res["conversation_id"], + "chunk_id": init_res["chunk_id"], "question_tokens": init_res["question_tokens"], "extract_doc": init_res["extract_doc"], "task_id": chat_complete_task.task_id, @@ -282,6 +368,12 @@ class ChatComplete: "title": page_title, "message": error_msg }, http_status=404, request=request) + except MediaWikiUserNoEnoughPointsException as e: + error_msg = "Does not have enough points." 
+            return await utils.web.api_response(-1, error={
+                "code": "no-enough-points",
+                "message": error_msg
+            }, http_status=403, request=request)
         except Exception as e:
             err_msg = f"Error while processing chat complete request: {e}"
             traceback.print_exc()
@@ -294,4 +386,127 @@
     @staticmethod
     @utils.web.token_auth
     async def chat_complete_stream(request: web.Request):
-        pass
\ No newline at end of file
+        if not utils.web.is_websocket(request):
+            return await utils.web.api_response(-1, error={
+                "code": "websocket-required",
+                "message": "This API only accepts websocket connections."
+            }, http_status=400, request=request)
+
+        params = await utils.web.get_param(request, {
+            "task_id": {
+                "type": str,
+                "required": True,
+            }
+        })
+
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        task_id = params.get("task_id")
+
+        task = chat_complete_tasks.get(task_id)
+        if task is None:
+            await ws.send_json({
+                'event': 'error',
+                'status': -1,
+                'message': "Task not found.",
+                'error': {
+                    'code': "task-not-found",
+                    'info': "Task not found.",
+                },
+            })
+            await ws.close()
+            return ws
+
+        if request.get("caller") == "user":
+            user_id = request.get("user")
+            if task.user_id != user_id:
+                await ws.send_json({
+                    'event': 'error',
+                    'status': -1,
+                    'message': "Permission denied.",
+                    'error': {
+                        'code': "permission-denied",
+                        'info': "Permission denied.",
+                    },
+                })
+                await ws.close()
+                return ws
+
+        if task.is_finished:
+            if task.error is not None:
+                await ws.send_json({
+                    'event': 'error',
+                    'status': -1,
+                    'message': str(task.error),
+                    'error': {
+                        'code': "internal-error",
+                        'info': str(task.error),
+                    },
+                })
+                await ws.close()
+            elif task.result is not None:
+                await ws.send_json({
+                    'event': 'connected',
+                    'status': 1,
+                    'outputed_message': "".join(task.chunks),
+                })
+                await ws.send_json({
+                    'event': 'finished',
+                    'status': 1,
+                    'result': task.result
+                })
+                await ws.close()
+        else:
+            async def on_message(delta_message: str):
+                await ws.send_str("+" + delta_message)
+
+            async def on_finished(result: ChatCompleteServiceResponse):
+                ignored_keys = ["message"]
+                response_result = {
+                    "point_cost": task.point_cost,
+                }
+                for k, v in result.items():
+                    if k not in ignored_keys:
+                        response_result[k] = v
+                await ws.send_json({
+                    'event': 'finished',
+                    'status': 1,
+                    'result': response_result
+                })
+
+                await ws.close()
+
+            async def on_error(err: Exception):
+                await ws.send_json({
+                    'event': 'error',
+                    'status': -1,
+                    'message': str(err),
+                    'error': {
+                        'code': "internal-error",
+                        'info': str(err),
+                    },
+                })
+
+                await ws.close()
+
+            task.on_message.append(on_message)
+            task.on_finished.append(on_finished)
+            task.on_error.append(on_error)
+
+            # Send received message
+            await ws.send_json({
+                'event': 'connected',
+                'status': 1,
+                'outputed_message': "".join(task.chunks),
+            })
+
+            while True:
+                if ws.closed:
+                    task.on_message.remove(on_message)
+                    task.on_finished.remove(on_finished)
+                    task.on_error.remove(on_error)
+                    break
+                await asyncio.sleep(0.1)
+
+        return ws
diff --git a/api/controller/EmbeddingSearch.py b/api/controller/EmbeddingSearch.py
index 5873555..ff75a15 100644
--- a/api/controller/EmbeddingSearch.py
+++ b/api/controller/EmbeddingSearch.py
@@ -2,7 +2,7 @@ import sys
 import traceback
 from aiohttp import web
 from service.database import DatabaseService
-from service.embedding_search import EmbeddingSearchService
+from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
 from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
 import utils.web
 
@@ -87,6 +87,18
@@ class EmbeddingSearch: }) if transatcion_id: await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg) + except EmbeddingRunningException: + error_msg = "Page index is running now" + await ws.send_json({ + 'event': 'error', + 'status': -4, + 'message': error_msg, + 'error': { + 'code': 'page_index_running', + }, + }) + if transatcion_id: + await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg) except Exception as e: error_msg = str(e) print(error_msg, file=sys.stderr) @@ -94,7 +106,10 @@ class EmbeddingSearch: await ws.send_json({ 'event': 'error', 'status': -1, - 'message': error_msg + 'message': error_msg, + 'error': { + 'code': 'internal_server_error', + } }) if transatcion_id: await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg) @@ -145,6 +160,16 @@ class EmbeddingSearch: "info": e.info, "message": error_msg }, http_status=500) + except EmbeddingRunningException: + error_msg = "Page index is running now" + + if transatcion_id: + await mwapi.chat_complete_cancel_transaction(transatcion_id, error_msg) + + return await utils.web.api_response(-4, error={ + "code": "page-index-running", + "message": error_msg + }, http_status=429) except Exception as e: error_msg = str(e) diff --git a/api/controller/Index.py b/api/controller/Index.py index 32f13b7..54f8ced 100644 --- a/api/controller/Index.py +++ b/api/controller/Index.py @@ -2,7 +2,6 @@ import sys import time import traceback from aiohttp import web -from sqlalchemy import select from api.model.toolkit_ui.conversation import ConversationHelper from api.model.toolkit_ui.page_title import PageTitleHelper from service.database import DatabaseService @@ -27,7 +26,7 @@ class Index: db = await DatabaseService.create(request.app) async with PageTitleHelper(db) as page_title_helper: title_info = await page_title_helper.find_by_title(title) - + if title_info is not None and time.time() - title_info.updated_at < 60: return await utils.web.api_response(1, { "cached": True, @@ -123,6 +122,7 @@ class Index: "id": result.id, "module": result.module, "title": result.title, + "description": result.description, "thumbnail": result.thumbnail, "rev_id": result.rev_id, "updated_at": result.updated_at, @@ -166,6 +166,7 @@ class Index: "id": conversation_info.id, "module": conversation_info.module, "title": conversation_info.title, + "description": conversation_info.description, "thumbnail": conversation_info.thumbnail, "rev_id": conversation_info.rev_id, "updated_at": conversation_info.updated_at, @@ -178,14 +179,68 @@ class Index: @staticmethod @utils.web.token_auth async def remove_conversation(request: web.Request): + params = await utils.web.get_param(request, { + "id": { + "type": int + }, + "ids": { + "type": str + } + }) + + conversation_id = params.get("id") + conversation_ids = params.get("ids") + + if conversation_id is None and conversation_ids is None: + return await utils.web.api_response(-2, error={ + "code": "invalid-params", + "message": "Invalid params." 
+ }, request=request, http_status=400) + + if conversation_id is not None: + conversation_ids = [conversation_id] + else: + conversation_ids = conversation_ids.split(",") + conversation_ids = [int(id) for id in conversation_ids] + + db = await DatabaseService.create(request.app) + async with ConversationHelper(db) as conversation_helper: + user_id = None + if request.get("caller") == "user": + user_id = int(request.get("user")) + conversation_ids = await conversation_helper.filter_user_owned_ids(conversation_ids, user_id=user_id) + + if len(conversation_ids) > 0: + await conversation_helper.remove_multiple(conversation_ids) + + # 通知其他模块删除 + events = EventService.create() + events.emit("conversation/removed", { + "ids": conversation_ids, + "dbs": db, + "app": request.app, + }) + + return await utils.web.api_response(1, data={ + "count": len(conversation_ids) + }, request=request) + + @staticmethod + @utils.web.token_auth + async def set_conversation_pinned(request: web.Request): params = await utils.web.get_param(request, { "id": { "required": True, "type": int + }, + "pinned": { + "required": True, + "type": bool } }) conversation_id = params.get("id") + pinned = params.get("pinned") db = await DatabaseService.create(request.app) async with ConversationHelper(db) as conversation_helper: @@ -203,39 +258,27 @@ class Index: "message": "Permission denied." }, request=request, http_status=403) - await conversation_helper.remove(conversation_info) - - # 通知其他模块删除 - events = EventService.create() - events.emit("conversation/removed", { - "conversation": conversation_info, - "dbs": db, - "app": request.app, - }) - events.emit("conversation/removed/" + conversation_info.module, { - "conversation": conversation_info, - "dbs": db, - "app": request.app, - }) + conversation_info.pinned = pinned + await conversation_helper.update(conversation_info) return await utils.web.api_response(1, request=request) - + @staticmethod @utils.web.token_auth - async def set_conversation_pinned(request: web.Request): + async def set_conversation_title(request: web.Request): params = await utils.web.get_param(request, { "id": { "required": True, "type": int }, - "pinned": { + "new_title": { "required": True, - "type": bool + "type": str } }) conversation_id = params.get("id") - pinned = params.get("pinned") + new_title = params.get("new_title") db = await DatabaseService.create(request.app) async with ConversationHelper(db) as conversation_helper: @@ -253,7 +296,7 @@ class Index: "message": "Permission denied." 
}, request=request, http_status=403) - conversation_info.pinned = pinned + conversation_info.title = new_title await conversation_helper.update(conversation_info) return await utils.web.api_response(1, request=request) diff --git a/api/model/chat_complete/conversation.py b/api/model/chat_complete/conversation.py index 2bad188..f13c02f 100644 --- a/api/model/chat_complete/conversation.py +++ b/api/model/chat_complete/conversation.py @@ -1,7 +1,8 @@ from __future__ import annotations +import time import sqlalchemy -from sqlalchemy import update +from sqlalchemy import select, update from sqlalchemy.orm import mapped_column, relationship, Mapped from api.model.base import BaseModel @@ -13,10 +14,11 @@ class ConversationChunkModel(BaseModel): __tablename__ = "chat_complete_conversation_chunk" id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True) - conversation_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(ConversationModel.id), index=True) + conversation_id: Mapped[int] = mapped_column( + sqlalchemy.ForeignKey(ConversationModel.id, ondelete="CASCADE", onupdate="CASCADE"), index=True) message_data: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True) tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0) - updated_at: Mapped[int] = mapped_column(sqlalchemy.TIMESTAMP, index=True) + updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True) class ConversationChunkHelper: def __init__(self, dbs: DatabaseService): @@ -36,52 +38,58 @@ class ConversationChunkHelper: await self.session.__aexit__(exc_type, exc, tb) pass - async def add(self, conversation_id: int, message_data: list, tokens: int): - async with self.create_session() as session: - chunk = ConversationChunkModel( - conversation_id=conversation_id, - message_data=message_data, - tokens=tokens, - updated_at=sqlalchemy.func.current_timestamp() - ) - session.add(chunk) - await session.commit() - await session.refresh(chunk) - return chunk + async def add(self, obj: ConversationChunkModel): + obj.updated_at = int(time.time()) + self.session.add(obj) + await self.session.commit() + await self.session.refresh(obj) + return obj - async def update(self, chunk: ConversationChunkModel): - chunk.updated_at = sqlalchemy.func.current_timestamp() - chunk = await self.session.merge(chunk) + async def update(self, obj: ConversationChunkModel): + obj.updated_at = int(time.time()) + obj = await self.session.merge(obj) await self.session.commit() - return chunk + return obj async def update_message_log(self, chunk_id: int, message_data: list, tokens: int): stmt = update(ConversationChunkModel).where(ConversationChunkModel.id == chunk_id) \ - .values(message_data=message_data, tokens=tokens, updated_at=sqlalchemy.func.current_timestamp()) + .values(message_data=message_data, tokens=tokens, updated_at=int(time.time())) await self.session.execute(stmt) await self.session.commit() async def get_newest_chunk(self, conversation_id: int): - stmt = sqlalchemy.select(ConversationChunkModel) \ + stmt = select(ConversationChunkModel) \ .where(ConversationChunkModel.conversation_id == conversation_id) \ .order_by(ConversationChunkModel.id.desc()) \ .limit(1) return await self.session.scalar(stmt) + + async def get_chunk_id_list(self, conversation_id: int): + stmt = select(ConversationChunkModel.id) \ + .where(ConversationChunkModel.conversation_id == conversation_id).order_by(ConversationChunkModel.id.asc()) + return await self.session.scalars(stmt) + + async def find_by_id(self, id: int): + 
stmt = select(ConversationChunkModel).where(ConversationChunkModel.id == id) + return await self.session.scalar(stmt) - async def remove(self, id: int): - stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id == id) + async def remove(self, id: int | list[int]): + if isinstance(id, list): + stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id.in_(id)) + else: + stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id == id) await self.session.execute(stmt) await self.session.commit() - async def remove_by_conversation_id(self, conversation_id: int): - stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.conversation_id == conversation_id) + async def remove_by_conversation_ids(self, ids: list[int]): + stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.conversation_id.in_(ids)) await self.session.execute(stmt) await self.session.commit() async def on_conversation_removed(event): - if "conversation" in event: - conversation_info = event["conversation"] - conversation_id = conversation_info["id"] - await ConversationChunkHelper(event["dbs"]).remove_by_conversation_id(conversation_id) + if "ids" in event: + conversation_ids = event["ids"] + async with ConversationChunkHelper(event["dbs"]) as chunk_helper: + await chunk_helper.remove_by_conversation_ids(conversation_ids) -EventService.create().add_listener("conversation/removed/chatcomplete", on_conversation_removed) \ No newline at end of file +EventService.create().add_listener("conversation/removed", on_conversation_removed) \ No newline at end of file diff --git a/api/model/toolkit_ui/conversation.py b/api/model/toolkit_ui/conversation.py index 4dfdc54..9c262a2 100644 --- a/api/model/toolkit_ui/conversation.py +++ b/api/model/toolkit_ui/conversation.py @@ -1,4 +1,5 @@ from __future__ import annotations +import time from typing import List, Optional import sqlalchemy @@ -17,11 +18,11 @@ class ConversationModel(BaseModel): module: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True) title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True) thumbnail: Mapped[str] = mapped_column(sqlalchemy.Text(), nullable=True) + description: Mapped[str] = mapped_column(sqlalchemy.Text(), nullable=True) page_id: Mapped[int] = mapped_column( sqlalchemy.Integer, index=True, nullable=True) rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True) - updated_at: Mapped[int] = mapped_column( - sqlalchemy.TIMESTAMP, index=True, server_default=sqlalchemy.func.now()) + updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True) pinned: Mapped[bool] = mapped_column( sqlalchemy.Boolean, default=False, index=True) extra: Mapped[dict] = mapped_column(sqlalchemy.JSON, default={}) @@ -46,13 +47,8 @@ class ConversationHelper: await self.session.__aexit__(exc_type, exc, tb) pass - async def add(self, user_id: int, module: str, title: Optional[str] = None, page_id: Optional[int] = None, rev_id: Optional[int] = None, extra: Optional[dict] = None): - obj = ConversationModel(user_id=user_id, module=module, title=title, - page_id=page_id, rev_id=rev_id, updated_at=sqlalchemy.func.current_timestamp()) - - if extra is not None: - obj.extra = extra - + async def add(self, obj: ConversationModel): + obj.updated_at = int(time.time()) self.session.add(obj) await self.session.commit() await self.session.refresh(obj) @@ -60,11 +56,12 @@ class ConversationHelper: async def refresh_updated_at(self, 
conversation_id: int): stmt = update(ConversationModel).where(ConversationModel.id == - conversation_id).values(updated_at=sqlalchemy.func.current_timestamp()) + conversation_id).values(updated_at=int(time.time())) await self.session.execute(stmt) await self.session.commit() async def update(self, obj: ConversationModel): + obj.updated_at = int(time.time()) await self.session.merge(obj) await self.session.commit() await self.session.refresh(obj) @@ -85,14 +82,25 @@ class ConversationHelper: return await self.session.scalars(stmt) - async def find_by_id(self, conversation_id: int): - async with self.create_session() as session: - stmt = sqlalchemy.select(ConversationModel).where( - ConversationModel.id == conversation_id) - return await session.scalar(stmt) + async def find_by_id(self, id: int): + stmt = sqlalchemy.select(ConversationModel).where( + ConversationModel.id == id) + return await self.session.scalar(stmt) + + async def filter_user_owned_ids(self, ids: list[int], user_id: int) -> list[int]: + stmt = sqlalchemy.select(ConversationModel.id) \ + .where(ConversationModel.id.in_(ids)).where(ConversationModel.user_id == user_id) + return list(await self.session.scalars(stmt)) async def remove(self, conversation_id: int): stmt = sqlalchemy.delete(ConversationModel).where( ConversationModel.id == conversation_id) await self.session.execute(stmt) + await self.session.commit() + + async def remove_multiple(self, ids: list[int]): + stmt = sqlalchemy.delete(ConversationModel) \ + .where(ConversationModel.id.in_(ids)) + + await self.session.execute(stmt) await self.session.commit() \ No newline at end of file diff --git a/api/model/toolkit_ui/page_title.py b/api/model/toolkit_ui/page_title.py index 3ef0c3b..984b4e4 100644 --- a/api/model/toolkit_ui/page_title.py +++ b/api/model/toolkit_ui/page_title.py @@ -1,5 +1,5 @@ from __future__ import annotations -import datetime +import time from typing import Optional import sqlalchemy @@ -17,8 +17,7 @@ class PageTitleModel(BaseModel): sqlalchemy.Integer, primary_key=True, autoincrement=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True) - updated_at: Mapped[int] = mapped_column( - sqlalchemy.TIMESTAMP, index=True, server_default=sqlalchemy.func.now()) + updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True) class PageTitleHelper: @@ -58,11 +57,11 @@ class PageTitleHelper: title_info = await self.find_by_title(title) if title_info is None: return True - if title_info.updated_at < (datetime.now() - datetime.timedelta(days=7)): + if time.time() - title_info.updated_at > 60: return True async def add(self, page_id: int, title: Optional[str] = None): - obj = PageTitleModel(page_id=page_id, title=title, updated_at=sqlalchemy.func.current_timestamp()) + obj = PageTitleModel(page_id=page_id, title=title, updated_at=int(time.time())) self.session.add(obj) await self.session.commit() @@ -71,14 +70,14 @@ class PageTitleHelper: async def set_title(self, page_id: int, title: Optional[str] = None): stmt = update(PageTitleModel).where( - PageTitleModel.page_id == page_id).values(title=title, updated_at=sqlalchemy.func.current_timestamp()) + PageTitleModel.page_id == page_id).values(title=title, updated_at=int(time.time())) await self.session.execute(stmt) await self.session.commit() async def update(self, obj: PageTitleModel, ignore_updated_at: bool = False): - self.session.merge(obj) + await self.session.merge(obj) if not ignore_updated_at: - 
obj.updated_at = sqlalchemy.func.current_timestamp() + obj.updated_at = int(time.time()) await self.session.commit() await self.session.refresh(obj) return obj diff --git a/api/route.py b/api/route.py index e7d0bd0..07565ae 100644 --- a/api/route.py +++ b/api/route.py @@ -18,17 +18,19 @@ def init(app: web.Application): web.route('*', '/title/info', Index.update_title_info), web.route('*', '/user/info', Index.get_user_info), - web.route('*', '/conversations', Index.get_conversation_list), + web.route('*', '/conversation/list', Index.get_conversation_list), web.route('*', '/conversation/info', Index.get_conversation_info), web.route('POST', '/conversation/remove', Index.remove_conversation), web.route('DELETE', '/conversation/remove', Index.remove_conversation), web.route('POST', '/conversation/set_pinned', Index.set_conversation_pinned), + web.route('POST', '/conversation/set_title', Index.set_conversation_title), web.route('*', '/embedding_search/index_page', EmbeddingSearch.index_page), web.route('*', '/embedding_search/search', EmbeddingSearch.search), - web.route('*', '/chatcomplete/conversation_chunks', ChatComplete.get_conversation_chunk_list), - web.route('*', '/chatcomplete/conversation_chunk/{id:^\d+}', ChatComplete.get_conversation_chunk), - web.route('*', '/chatcomplete/message', ChatComplete.start_chat_complete), - web.route('*', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream), + web.route('*', '/chatcomplete/conversation_chunk/list', ChatComplete.get_conversation_chunk_list), + web.route('*', '/chatcomplete/conversation_chunk/info', ChatComplete.get_conversation_chunk), + web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete), + web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream), + web.route('POST', '/chatcomplete/get_point_cost', ChatComplete.get_point_cost), ]) diff --git a/config-example.py b/config-example.py index 06ea447..6aa04c6 100644 --- a/config-example.py +++ b/config-example.py @@ -39,8 +39,6 @@ CHATCOMPLETE_OUTPUT_REPLACE = { "人工智能程式": "虛擬人物程序", } -CHATCOMPLETE_DEFAULT_CONVERSATION_TITLE = "无标题" - CHATCOMPLETE_BOT_NAME = "寫作助手" PROMPTS = { diff --git a/main.py b/main.py index d9585ce..0cc4471 100644 --- a/main.py +++ b/main.py @@ -17,7 +17,7 @@ from api.model.embedding_search.title_index import TitleIndexModel as _ from service.tiktoken import TikTokenService async def index(request: web.Request): - return utils.web.api_response(1, data={"message": "Isekai toolkit API"}, request=request) + return await utils.web.api_response(1, data={"message": "Isekai toolkit API"}, request=request) async def init_mw_api(app: web.Application): mw_api = MediaWikiApi.create() diff --git a/noawait.py b/noawait.py index ce7eca5..e0e6c20 100644 --- a/noawait.py +++ b/noawait.py @@ -1,34 +1,98 @@ from __future__ import annotations from asyncio import AbstractEventLoop, Task import asyncio +import atexit from functools import wraps +import random import sys import traceback -from typing import Callable, Coroutine +from typing import Callable, Coroutine, Optional, TypedDict + +class TimerInfo(TypedDict): + id: int + callback: Callable + interval: float + next_time: float class NoAwaitPool: def __init__(self, loop: AbstractEventLoop): self.task_list: list[Task] = [] + self.timer_map: dict[int, TimerInfo] = {} self.loop = loop self.running = True + self.should_refresh_task = False + self.next_timer_time: Optional[float] = None + self.on_error: list[Callable] = [] self.gc_task = loop.create_task(self._run_gc()) + 
self.timer_task = loop.create_task(self._run_timer()) + + atexit.register(self.end_task) async def end(self): - print("Stopping NoAwait Tasks...") - self.running = False - for task in self.task_list: - await self._finish_task(task) - - await self.gc_task + if self.running: + print("Stopping NoAwait Tasks...") + self.running = False + for task in self.task_list: + await self._finish_task(task) + + await self.gc_task + await self.timer_task + + def end_task(self): + if self.running and not self.loop.is_closed(): + self.loop.run_until_complete(self.end()) + async def _wrap_task(self, task: Task): + try: + await task + except Exception as e: + handled = False + for handler in self.on_error: + try: + handler_ret = handler(e) + await handler_ret + handled = True + except Exception as handler_err: + print("Exception on error handler: " + str(handler_err), file=sys.stderr) + traceback.print_exc() + + if not handled: + print(e, file=sys.stderr) + traceback.print_exc() + finally: + self.should_refresh_task = True + def add_task(self, coroutine: Coroutine): task = self.loop.create_task(coroutine) self.task_list.append(task) + def add_timer(self, callback: Callable, interval: float) -> int: + id = random.randint(0, 1000000000) + while id in self.timer_map: + id = random.randint(0, 1000000000) + + now = self.loop.time() + next_time = now + interval + self.timer_map[id] = { + "id": id, + "callback": callback, + "interval": interval, + "next_time": next_time + } + + if self.next_timer_time is None or next_time < self.next_timer_time: + self.next_timer_time = next_time + + return id + + def remove_timer(self, id: int): + if id in self.timer_map: + del self.timer_map[id] + def wrap(self, f): @wraps(f) def decorated_function(*args, **kwargs): @@ -47,8 +111,7 @@ class NoAwaitPool: for handler in self.on_error: try: handler_ret = handler(e) - if handler_ret is Coroutine: - await handler_ret + await handler_ret handled = True except Exception as handler_err: print("Exception on error handler: " + str(handler_err), file=sys.stderr) @@ -57,16 +120,46 @@ class NoAwaitPool: if not handled: print(e, file=sys.stderr) traceback.print_exc() - async def _run_gc(self): while self.running: - should_remove = [] - for task in self.task_list: - if task.done(): - await self._finish_task(task) - should_remove.append(task) - for task in should_remove: - self.task_list.remove(task) + if self.should_refresh_task: + should_remove = [] + for task in self.task_list: + if task.done(): + await self._finish_task(task) + should_remove.append(task) + for task in should_remove: + self.task_list.remove(task) + + await asyncio.sleep(0.1) + + async def _run_timer(self): + while self.running: + now = self.loop.time() + if self.next_timer_time is not None and now >= self.next_timer_time: + self.next_timer_time = None + for timer in self.timer_map.values(): + if now >= timer["next_time"]: + timer["next_time"] = now + timer["interval"] + try: + result = timer["callback"]() + self.add_task(result) + except Exception as e: + handled = False + for handler in self.on_error: + try: + handler_ret = handler(e) + self.add_task(handler_ret) + handled = True + except Exception as handler_err: + print("Exception on error handler: " + str(handler_err), file=sys.stderr) + traceback.print_exc() + if not handled: + print(e, file=sys.stderr) + traceback.print_exc() + if self.next_timer_time is None or timer["next_time"] < self.next_timer_time: + self.next_timer_time = timer["next_time"] + await asyncio.sleep(0.1) \ No newline at end of file diff --git 
a/requirements-embedding.txt b/requirements-embedding.txt new file mode 100644 index 0000000..3958881 --- /dev/null +++ b/requirements-embedding.txt @@ -0,0 +1,5 @@ +transformers +--index-url https://download.pytorch.org/whl/cpu +torch +torchvision +torchaudio \ No newline at end of file diff --git a/service/bert_embedding.py b/service/bert_embedding.py new file mode 100644 index 0000000..56450a4 --- /dev/null +++ b/service/bert_embedding.py @@ -0,0 +1,143 @@ +from __future__ import annotations +import time +import config +import asyncio +import random +import threading +from typing import Optional, TypedDict +import torch +from transformers import pipeline +from local import loop + +from service.tiktoken import TikTokenService + +BERT_EMBEDDING_QUEUE_TIMEOUT = 1 + +class BERTEmbeddingQueueTaskInfo(TypedDict): + task_id: int + text: str + embedding: torch.Tensor + +class BERTEmbeddingQueue: + def init(self): + self.embedding_model = pipeline("feature-extraction", model="bert-base-chinese") + self.task_map: dict[int, BERTEmbeddingQueueTaskInfo] = {} + self.task_list: list[BERTEmbeddingQueueTaskInfo] = [] + self.lock = threading.Lock() + + self.thread: Optional[threading.Thread] = None + self.running = False + + async def get_embeddings(self, text: str): + text = "[CLS]" + text + "[SEP]" + task_id = random.randint(0, 1000000000) + with self.lock: + while task_id in self.task_map: + task_id = random.randint(0, 1000000000) + + task_info = { + "task_id": task_id, + "text": text, + "embedding": None + } + self.task_map[task_id] = task_info + self.task_list.append(task_info) + + self.start_queue() + while True: + task_info = self.pop_task(task_id) + if task_info is not None: + return task_info["embedding"] + + await asyncio.sleep(0.01) + + def pop_task(self, task_id): + with self.lock: + if task_id in self.task_map: + task_info = self.task_map[task_id] + if task_info["embedding"] is not None: + del self.task_map[task_id] + return task_info + + return None + + def run(self): + running = True + last_task_time = None + while running and self.running: + current_time = time.time() + task = None + with self.lock: + if len(self.task_list) > 0: + task = self.task_list.pop(0) + + if task is not None: + embeddings = self.embedding_model(task["text"]) + + with self.lock: + task["embedding"] = embeddings[0][1] + + last_task_time = time.time() + elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT: + self.thread = None + self.running = False + running = False + else: + time.sleep(0.01) + + def start_queue(self): + if not self.running: + self.running = True + self.thread = threading.Thread(target=self.run) + self.thread.start() + +bert_embedding_queue = BERTEmbeddingQueue() +bert_embedding_queue.init() + +class BERTEmbeddingService: + instance = None + + @staticmethod + async def create() -> BERTEmbeddingService: + if BERTEmbeddingService.instance is None: + BERTEmbeddingService.instance = BERTEmbeddingService() + await BERTEmbeddingService.instance.init() + return BERTEmbeddingService.instance + + async def init(self): + self.tiktoken = await TikTokenService.create() + self.embedding_queue = BERTEmbeddingQueue() + await loop.run_in_executor(None, self.embedding_queue.init) + + async def get_embeddings(self, docs, on_progress=None): + if len(docs) == 0: + return ([], 0) + + if on_progress is not None: + await on_progress(0, len(docs)) + + embeddings = [] + token_usage = 0 + + for doc in docs: + if "text" in doc: + tokens = await self.tiktoken.get_tokens(doc["text"]) 
+ token_usage += tokens + embeddings.append({ + "id": doc["id"], + "text": doc["text"], + "embedding": self.model.encode(doc["text"]), + "tokens": tokens + }) + else: + embeddings.append({ + "id": doc["id"], + "text": doc["text"], + "embedding": None, + "tokens": 0 + }) + + if on_progress is not None: + await on_progress(1, len(docs)) + + return (embeddings, token_usage) \ No newline at end of file diff --git a/service/chat_complete.py b/service/chat_complete.py index e2d6e54..f634536 100644 --- a/service/chat_complete.py +++ b/service/chat_complete.py @@ -1,12 +1,18 @@ from __future__ import annotations +import time import traceback from typing import Optional, Tuple, TypedDict -from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationChunkModel + +import sqlalchemy +from api.model.chat_complete.conversation import ( + ConversationChunkHelper, + ConversationChunkModel, +) import sys from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel import config -import utils.config +import utils.config, utils.web from aiohttp import web from api.model.embedding_search.title_collection import TitleCollectionModel @@ -18,21 +24,21 @@ from service.mediawiki_api import MediaWikiApi from service.openai_api import OpenAIApi from service.tiktoken import TikTokenService - class ChatCompleteServicePrepareResponse(TypedDict): extract_doc: list question_tokens: int - + conversation_id: int + chunk_id: int class ChatCompleteServiceResponse(TypedDict): message: str message_tokens: int total_tokens: int finish_reason: str - conversation_id: int + question_message_id: str + response_message_id: str delta_data: dict - class ChatCompleteService: def __init__(self, dbs: DatabaseService, title: str): self.dbs = dbs @@ -58,6 +64,7 @@ class ChatCompleteService: self.question = "" self.question_tokens: Optional[int] = None self.conversation_id: Optional[int] = None + self.conversation_start_time: Optional[int] = None self.delta_data = {} @@ -81,22 +88,31 @@ class ChatCompleteService: async def get_question_tokens(self, question: str): return await self.tiktoken.get_tokens(question) - async def prepare_chat_complete(self, question: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None, - question_tokens: Optional[int] = None, - embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse: + async def prepare_chat_complete( + self, + question: str, + conversation_id: Optional[str] = None, + user_id: Optional[int] = None, + question_tokens: Optional[int] = None, + edit_message_id: Optional[str] = None, + embedding_search: Optional[EmbeddingSearchArgs] = None, + ) -> ChatCompleteServicePrepareResponse: if user_id is not None: user_id = int(user_id) self.user_id = user_id self.question = question + self.conversation_start_time = int(time.time()) self.conversation_info = None if conversation_id is not None: self.conversation_id = int(conversation_id) - self.conversation_info = await self.conversation_helper.find_by_id(self.conversation_id) + self.conversation_info = await self.conversation_helper.find_by_id( + self.conversation_id + ) else: self.conversation_id = None - + if self.conversation_info is not None: if self.conversation_info.user_id != user_id: raise web.HTTPUnauthorized() @@ -106,97 +122,201 @@ class ChatCompleteService: else: self.question_tokens = question_tokens - if (len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS and - self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS): + if ( + 
len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS + and self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS + ): # If the question is too long, we need to truncate it raise web.HTTPRequestEntityTooLarge() + self.conversation_chunk = None + if self.conversation_info is not None: + chunk_id_list = await self.conversation_chunk_helper.get_chunk_id_list(self.conversation_id) + + if edit_message_id and "," in edit_message_id: + (edit_chunk_id, edit_msg_id) = edit_message_id.split(",") + edit_chunk_id = int(edit_chunk_id) + + # Remove overrided conversation chunks + start_overrided = False + should_remove_chunk_ids = [] + for chunk_id in chunk_id_list: + if start_overrided: + should_remove_chunk_ids.append(chunk_id) + else: + if chunk_id == edit_chunk_id: + start_overrided = True + + if len(should_remove_chunk_ids) > 0: + await self.conversation_chunk_helper.remove(should_remove_chunk_ids) + + self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk( + self.conversation_id + ) + # Remove outdated message + edit_message_pos = None + old_tokens = 0 + for i in range(0, len(self.conversation_chunk.message_data)): + msg_data = self.conversation_chunk.message_data[i] + if msg_data["id"] == edit_msg_id: + edit_message_pos = i + break + if "tokens" in msg_data and msg_data["tokens"]: + old_tokens += msg_data["tokens"] + + if edit_message_pos: + self.conversation_chunk.message_data = self.conversation_chunk.message_data[0:edit_message_pos] + flag_modified(self.conversation_chunk, "message_data") + self.conversation_chunk.tokens = old_tokens + else: + self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk( + self.conversation_id + ) + + # If the conversation is too long, we need to make a summary + if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS: + summary, tokens = await self.make_summary( + self.conversation_chunk.message_data + ) + new_message_log = [ + { + "role": "summary", + "content": summary, + "tokens": tokens, + "time": int(time.time()), + } + ] + + self.conversation_chunk = ConversationChunkModel( + conversation_id=self.conversation_id, + message_data=new_message_log, + tokens=tokens, + ) + + self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_chunk) + else: + # 创建新对话 + title_info = self.embedding_search.title_info + self.conversation_info = ConversationModel( + user_id=self.user_id, + module="chatcomplete", + page_id=title_info["page_id"], + rev_id=title_info["rev_id"], + ) + self.conversation_info = await self.conversation_helper.add( + self.conversation_info, + ) + + self.conversation_chunk = ConversationChunkModel( + conversation_id=self.conversation_info.id, + message_data=[], + tokens=0, + ) + self.conversation_chunk = await self.conversation_chunk_helper.add( + self.conversation_chunk + ) + # Extract document from wiki page index self.extract_doc = None if embedding_search is not None: - self.extract_doc, token_usage = await self.embedding_search.search(question, **embedding_search) + self.extract_doc, token_usage = await self.embedding_search.search( + question, **embedding_search + ) if self.extract_doc is not None: self.question_tokens += token_usage return ChatCompleteServicePrepareResponse( extract_doc=self.extract_doc, - question_tokens=self.question_tokens + question_tokens=self.question_tokens, + conversation_id=self.conversation_info.id, + chunk_id=self.conversation_chunk.id ) - async def finish_chat_complete(self, on_message: Optional[callable] = None) -> 
ChatCompleteServiceResponse: + async def finish_chat_complete( + self, on_message: Optional[callable] = None + ) -> ChatCompleteServiceResponse: delta_data = {} - self.conversation_chunk = None message_log = [] - if self.conversation_info is not None: - self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(self.conversation_id) - - # If the conversation is too long, we need to make a summary - if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS: - summary, tokens = await self.make_summary(self.conversation_chunk.message_data) - new_message_log = [ - {"role": "summary", "content": summary, "tokens": tokens} - ] - - self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_id, new_message_log, tokens) - - self.delta_data["conversation_chunk_id"] = self.conversation_chunk.id - - message_log = [] + if self.conversation_chunk is not None: for message in self.conversation_chunk.message_data: - message_log.append({ - "role": message["role"], - "content": message["content"], - }) + message_log.append( + { + "role": message["role"], + "content": message["content"], + } + ) if self.extract_doc is not None: - doc_prompt_content = "\n".join(["%d. %s" % ( - i + 1, doc["markdown"] or doc["text"]) for i, doc in enumerate(self.extract_doc)]) + doc_prompt_content = "\n".join( + [ + "%d. %s" % (i + 1, doc["markdown"] or doc["text"]) + for i, doc in enumerate(self.extract_doc) + ] + ) - doc_prompt = utils.config.get_prompt("extracted_doc", "prompt", { - "content": doc_prompt_content}) + doc_prompt = utils.config.get_prompt( + "extracted_doc", "prompt", {"content": doc_prompt_content} + ) message_log.append({"role": "user", "content": doc_prompt}) system_prompt = utils.config.get_prompt("chat", "system_prompt") # Start chat complete if on_message is not None: - response = await self.openai_api.chat_complete_stream(self.question, system_prompt, message_log, on_message) + response = await self.openai_api.chat_complete_stream( + self.question, system_prompt, message_log, on_message + ) else: - response = await self.openai_api.chat_complete(self.question, system_prompt, message_log) - - if self.conversation_info is None: - # Create a new conversation - message_log_list = [ - {"role": "user", "content": self.question, "tokens": self.question_tokens}, - {"role": "assistant", - "content": response["message"], "tokens": response["message_tokens"]}, - ] - title = None - try: - title, token_usage = await self.make_title(message_log_list) - delta_data["title"] = title - except Exception as e: - title = config.CHATCOMPLETE_DEFAULT_CONVERSATION_TITLE - print(str(e), file=sys.stderr) - traceback.print_exc(file=sys.stderr) + response = await self.openai_api.chat_complete( + self.question, system_prompt, message_log + ) + + description = response["message"][0:150] + + question_msg_id = utils.web.generate_uuid() + response_msg_id = utils.web.generate_uuid() + + new_message_data = [ + { + "id": question_msg_id, + "role": "user", + "content": self.question, + "tokens": self.question_tokens, + "time": self.conversation_start_time, + }, + { + "id": response_msg_id, + "role": "assistant", + "content": response["message"], + "tokens": response["message_tokens"], + "time": int(time.time()), + }, + ] + if self.conversation_info is not None: total_token_usage = self.question_tokens + response["message_tokens"] - - title_info = self.embedding_search.title_info - self.conversation_info = await self.conversation_helper.add(self.user_id, "chatcomplete", 
page_id=title_info["page_id"], rev_id=title_info["rev_id"], title=title) - self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_info.id, message_log_list, total_token_usage) - else: - # Update the conversation chunk - await self.conversation_helper.refresh_updated_at(self.conversation_id) - - self.conversation_chunk.message_data.append( - {"role": "user", "content": self.question, "tokens": self.question_tokens}) - self.conversation_chunk.message_data.append( - {"role": "assistant", "content": response["message"], "tokens": response["message_tokens"]}) + # Generate title if not exists + if self.conversation_info.title is None: + title = None + try: + title, token_usage = await self.make_title(new_message_data) + delta_data["title"] = title + except Exception as e: + print(str(e), file=sys.stderr) + traceback.print_exc(file=sys.stderr) + + self.conversation_info.title = title + + # Update conversation info + self.conversation_info.description = description + + await self.conversation_helper.update(self.conversation_info) + + # Update conversation chunk + self.conversation_chunk.message_data.extend(new_message_data) flag_modified(self.conversation_chunk, "message_data") - self.conversation_chunk.tokens += self.question_tokens + \ - response["message_tokens"] + self.conversation_chunk.tokens += total_token_usage await self.conversation_chunk_helper.update(self.conversation_chunk) @@ -205,17 +325,18 @@ class ChatCompleteService: message_tokens=response["message_tokens"], total_tokens=response["total_tokens"], finish_reason=response["finish_reason"], - conversation_id=self.conversation_info.id, - delta_data=delta_data + question_message_id=question_msg_id, + response_message_id=response_msg_id, + delta_data=delta_data, ) - + async def set_latest_point_cost(self, point_cost: int) -> bool: if self.conversation_chunk is None: return False if len(self.conversation_chunk.message_data) == 0: return False - + for i in range(len(self.conversation_chunk.message_data) - 1, -1, -1): if self.conversation_chunk.message_data[i]["role"] == "assistant": self.conversation_chunk.message_data[i]["point_cost"] = point_cost @@ -224,44 +345,50 @@ class ChatCompleteService: return True - async def make_summary(self, message_log_list: list) -> tuple[str, int]: chat_log: list[str] = [] for message_data in message_log_list: - if message_data["role"] == 'summary': + if message_data["role"] == "summary": chat_log.append(message_data["content"]) - elif message_data["role"] == 'assistant': + elif message_data["role"] == "assistant": chat_log.append( - f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}') + f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}' + ) else: chat_log.append(f'User: {message_data["content"]}') - chat_log_str = '\n'.join(chat_log) + chat_log_str = "\n".join(chat_log) - summary_system_prompt = utils.config.get_prompt( - "summary", "system_prompt") + summary_system_prompt = utils.config.get_prompt("summary", "system_prompt") summary_prompt = utils.config.get_prompt( - "summary", "prompt", {"content": chat_log_str}) + "summary", "prompt", {"content": chat_log_str} + ) - response = await self.openai_api.chat_complete(summary_prompt, summary_system_prompt) + response = await self.openai_api.chat_complete( + summary_prompt, summary_system_prompt + ) return response["message"], response["message_tokens"] async def make_title(self, message_log_list: list) -> tuple[str, int]: chat_log: list[str] = [] for message_data in message_log_list: - if 
message_data["role"] == 'assistant': + if message_data["role"] == "assistant": chat_log.append( - f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}') - elif message_data["role"] == 'user': + f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}' + ) + elif message_data["role"] == "user": chat_log.append(f'User: {message_data["content"]}') - chat_log_str = '\n'.join(chat_log) + chat_log_str = "\n".join(chat_log) title_system_prompt = utils.config.get_prompt("title", "system_prompt") title_prompt = utils.config.get_prompt( - "title", "prompt", {"content": chat_log_str}) + "title", "prompt", {"content": chat_log_str} + ) - response = await self.openai_api.chat_complete(title_prompt, title_system_prompt) + response = await self.openai_api.chat_complete( + title_prompt, title_system_prompt + ) return response["message"], response["message_tokens"] diff --git a/service/embedding_search.py b/service/embedding_search.py index 7a977de..fa50ef5 100644 --- a/service/embedding_search.py +++ b/service/embedding_search.py @@ -9,12 +9,17 @@ from service.openai_api import OpenAIApi from service.tiktoken import TikTokenService from utils.wiki import getWikiSentences +class EmbeddingRunningException(Exception): + pass + class EmbeddingSearchArgs(TypedDict): limit: Optional[int] in_collection: Optional[bool] distance_limit: Optional[float] class EmbeddingSearchService: + indexing_page_ids: list[int] = [] + def __init__(self, dbs: DatabaseService, title: str): self.dbs = dbs @@ -92,6 +97,9 @@ class EmbeddingSearchService: self.page_id = self.page_info["pageid"] + if self.page_id in self.indexing_page_ids: + raise EmbeddingRunningException("Page index is running now") + # Create collection self.collection_info = await self.title_collection.find_by_title(self.base_title) if self.collection_info is None: @@ -129,7 +137,6 @@ class EmbeddingSearchService: if self.unindexed_docs is None: return False - chunk_limit = 500 chunk_len = 0 diff --git a/service/mediawiki_api.py b/service/mediawiki_api.py index 2238605..4cbbbc0 100644 --- a/service/mediawiki_api.py +++ b/service/mediawiki_api.py @@ -1,3 +1,4 @@ +from __future__ import annotations import json import sys import time @@ -15,24 +16,49 @@ class MediaWikiApiException(Exception): def __str__(self) -> str: return self.info -class MediaWikiPageNotFoundException(MediaWikiApiException): - pass +class MediaWikiPageNotFoundException(Exception): + def __init__(self, info: str, code: Optional[str] = None) -> None: + super().__init__(info) + self.info = info + self.code = code + self.message = self.info + + def __str__(self) -> str: + return self.info + +class MediaWikiUserNoEnoughPointsException(Exception): + def __init__(self, info: str, code: Optional[str] = None) -> None: + super().__init__(info) + self.info = info + self.code = code + self.message = self.info + + def __str__(self) -> str: + return self.info + +class ChatCompleteGetPointUsageResponse(TypedDict): + point_cost: int class ChatCompleteReportUsageResponse(TypedDict): point_cost: int transaction_id: str class MediaWikiApi: - cookie_jar = aiohttp.CookieJar(unsafe=True) + instance: MediaWikiApi = None @staticmethod def create(): - return MediaWikiApi(config.MW_API) + if MediaWikiApi.instance is None: + MediaWikiApi.instance = MediaWikiApi(config.MW_API) + + return MediaWikiApi.instance def __init__(self, api_url: str): self.api_url = api_url - self.login_time = 0.0 + + self.cookie_jar = aiohttp.CookieJar(unsafe=True) self.login_identity = None + self.login_time = 0.0 async def 
         async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
@@ -50,7 +76,7 @@ class MediaWikiApi:
                     raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
 
                 if "missing" in data["query"]["pages"][0]:
-                    raise MediaWikiPageNotFoundException()
+                    raise MediaWikiPageNotFoundException("Page %s not found." % title)
 
                 return data["query"]["pages"][0]
 
@@ -98,7 +124,21 @@ class MediaWikiApi:
                     ret["user"] = data["query"]["userinfo"]["name"]
 
                 return ret
+
+    async def is_logged_in(self):
+        async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
+            params = {
+                "action": "query",
+                "format": "json",
+                "formatversion": "2",
+                "meta": "userinfo"
+            }
+            async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
+                data = await resp.json()
+                if "error" in data:
+                    raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
+                return data["query"]["userinfo"]["id"] != 0
 
     async def get_token(self, token_type: str):
         async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
@@ -145,8 +185,10 @@ class MediaWikiApi:
 
     async def refresh_login(self):
         if self.login_identity is None:
+            print("Failed to refresh MediaWiki bot login: no saved login identity")
             return False
-        if time.time() - self.login_time > 30:
+        if time.time() - self.login_time > 3600:
+            print("Refreshing MediaWiki bot login session")
             return await self.robot_login(self.login_identity["username"], self.login_identity["password"])
 
     async def chat_complete_user_info(self, user_id: int):
@@ -169,6 +211,32 @@ class MediaWikiApi:
                     raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
 
                 return data["chatcompletebot"]["userinfo"]
+
+    async def chat_complete_get_point_cost(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteGetPointUsageResponse:
+        await self.refresh_login()
+
+        async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
+            post_data = {
+                "action": "chatcompletebot",
+                "method": "reportusage",
+                "step": "check",
+                "userid": int(user_id) if user_id is not None else None,
+                "useraction": user_action,
+                "tokens": int(tokens) if tokens is not None else None,
+                "extractlines": int(extractlines) if extractlines is not None else None,
+                "format": "json",
+                "formatversion": "2",
+            }
+            # Filter out None values
+            post_data = {k: v for k, v in post_data.items() if v is not None}
+            async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
+                data = await resp.json()
+                if "error" in data:
+                    print(data)
+                    raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
+
+                point_cost = int(data["chatcompletebot"]["reportusage"]["pointcost"] or 0)
+                return ChatCompleteGetPointUsageResponse(point_cost=point_cost)
 
     async def chat_complete_start_transaction(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteReportUsageResponse:
         await self.refresh_login()
@@ -178,10 +246,10 @@ class MediaWikiApi:
                 "action": "chatcompletebot",
                 "method": "reportusage",
                 "step": "start",
-                "userid": int(user_id),
+                "userid": int(user_id) if user_id is not None else None,
                 "useraction": user_action,
-                "tokens": int(tokens),
-                "extractlines": int(extractlines),
+                "tokens": int(tokens) if tokens is not None else None,
+                "extractlines": int(extractlines) if extractlines is not None else None,
                 "format": "json",
                 "formatversion": "2",
             }
@@ -190,10 +258,13 @@ class MediaWikiApi:
             async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
                 data = await resp.json()
                 if "error" in data:
-                    print(data)
-                    raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
+                    if data["error"]["code"] == "noenoughpoints":
+                        raise MediaWikiUserNoEnoughPointsException(data["error"]["info"], data["error"]["code"])
+                    else:
+                        raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
 
-                return ChatCompleteReportUsageResponse(point_cost=int(data["chatcompletebot"]["reportusage"]["pointcost"]),
+                point_cost = int(data["chatcompletebot"]["reportusage"]["pointcost"] or 0)
+                return ChatCompleteReportUsageResponse(point_cost=point_cost,
                                                        transaction_id=data["chatcompletebot"]["reportusage"]["transactionid"])
 
     async def chat_complete_end_transaction(self, transaction_id: str, tokens: Optional[int] = None):
@@ -237,7 +308,7 @@ class MediaWikiApi:
             }
             # Filter out None values
             post_data = {k: v for k, v in post_data.items() if v is not None}
-            async with session.get(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
+            async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
                 data = await resp.json()
                 if "error" in data:
                     raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
diff --git a/service/openai_api.py b/service/openai_api.py
index 29c82f6..8d62088 100644
--- a/service/openai_api.py
+++ b/service/openai_api.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import asyncio
 import json
 
 from typing import Callable, Optional, TypedDict
@@ -36,20 +37,20 @@ class OpenAIApi:
     def build_header(self):
         if config.OPENAI_API_TYPE == "azure":
             return {
-                "Content-Type": "application/json",
-                "Accept": "application/json",
+                "content-type": "application/json",
+                "accept": "application/json",
                 "api-key": self.api_key
             }
         else:
             return {
-                "Authorization": f"Bearer {self.api_key}",
-                "Content-Type": "application/json",
-                "Accept": "application/json",
+                "authorization": f"Bearer {self.api_key}",
+                "content-type": "application/json",
+                "accept": "application/json",
             }
 
     def get_url(self, method: str):
         if config.OPENAI_API_TYPE == "azure":
-            if method == "completions":
+            if method == "chat/completions":
                 return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME + "/" + method
             elif method == "embeddings":
                 return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME + "/" + method
@@ -92,26 +93,41 @@ class OpenAIApi:
             if config.OPENAI_API_TYPE == "azure":
                 # Azure api does not support batch
                 for index, text in enumerate(text_list):
-                    async with session.post(url,
-                                            headers=self.build_header(),
-                                            params=params,
-                                            json={"input": text},
-                                            timeout=30,
-                                            proxy=config.REQUEST_PROXY) as resp:
-
-                        data = await resp.json()
-
-                        one_data = data["data"]
-                        if len(one_data) > 0:
-                            embedding = one_data[0]["embedding"]
-                            if embedding is not None:
-                                embedding = np.array(embedding)
-                                doc_list[index]["embedding"] = embedding
-
-                        token_usage += int(data["usage"]["total_tokens"])
-
-                        if on_index_progress is not None:
-                            await on_index_progress(index, len(text_list))
+                    retry_num = 0
+                    max_retry_num = 3
+                    while retry_num < max_retry_num:
+                        try:
+                            async with session.post(url,
+                                                    headers=self.build_header(),
+                                                    params=params,
+                                                    json={"input": text},
+                                                    timeout=30,
+                                                    proxy=config.REQUEST_PROXY) as resp:
+
+                                data = await resp.json()
+
+                                one_data = data["data"]
+                                if len(one_data) > 0:
+                                    embedding = one_data[0]["embedding"]
+                                    if embedding is not None:
+                                        embedding = np.array(embedding)
+                                        doc_list[index]["embedding"] = embedding
+
+                                token_usage += int(data["usage"]["total_tokens"])
+
+                                if on_index_progress is not None:
+                                    await on_index_progress(index, len(text_list))
+
+                            break
+                        except Exception as e:
+                            retry_num += 1
+
+                            if retry_num >= max_retry_num:
+                                raise e
+
+                            print("Error: %s" % e)
+                            print("Retrying...")
+                            await asyncio.sleep(0.5)
             else:
                 async with session.post(url,
                                         headers=self.build_header(),
@@ -158,7 +174,7 @@ class OpenAIApi:
     async def chat_complete(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], user = None):
         messageList = await self.make_message_list(question, system_prompt, conversation)
 
-        url = self.get_url("completions")
+        url = self.get_url("chat/completions")
 
         params = {}
         post_data = {
@@ -175,7 +191,7 @@ class OpenAIApi:
 
         async with aiohttp.ClientSession() as session:
             async with session.post(url,
-                                    headers=self.build_header,
+                                    headers=self.build_header(),
                                     params=params,
                                     json=post_data,
                                     timeout=30,
@@ -210,24 +226,38 @@ class OpenAIApi:
         for message in messageList:
             prompt_tokens += await tiktoken.get_tokens(message["content"])
 
-        params = {
-            "model": "gpt-3.5-turbo",
+        url = self.get_url("chat/completions")
+
+        params = {}
+        post_data = {
             "messages": messageList,
-            "stream": True,
             "user": user,
+            "stream": True,
+            "n": 1,
+            "max_tokens": 768,
+            "stop": None,
+            "temperature": 1,
+            "top_p": 0.95
         }
-        params = {k: v for k, v in params.items() if v is not None}
+
+        if config.OPENAI_API_TYPE == "azure":
+            params["api-version"] = "2023-05-15"
+        else:
+            post_data["model"] = "gpt-3.5-turbo"
+
+        post_data = {k: v for k, v in post_data.items() if v is not None}
 
         res_message: list[str] = []
         finish_reason = None
 
         async with sse_client.EventSource(
-            self.api_url + "/v1/chat/completions",
+            url,
             option={
                 "method": "POST"
             },
-            headers={"Authorization": f"Bearer {self.api_key}"},
-            json=params,
+            headers=self.build_header(),
+            params=params,
+            json=post_data,
             proxy=config.REQUEST_PROXY
         ) as session:
             async for event in session:
@@ -238,11 +268,13 @@ class OpenAIApi:
                 [DONE]
                 """
                 content_started = False
-
-                if event.data == "[DONE]":
+
+                event_data = event.data.strip()
+
+                if event_data == "[DONE]":
                     break
-                elif event.data[0] == "{" and event.data[-1] == "}":
-                    data = json.loads(event.data)
+                elif event_data[0] == "{" and event_data[-1] == "}":
+                    data = json.loads(event_data)
 
                     if "choices" in data and len(data["choices"]) > 0:
                         choice = data["choices"][0]
@@ -262,11 +294,14 @@ class OpenAIApi:
 
                         res_message.append(delta_message)
 
-                        if config.DEBUG:
-                            print(delta_message, end="", flush=True)
+                        # if config.DEBUG:
+                        #     print(delta_message, end="", flush=True)
 
                         if on_message is not None:
                             await on_message(delta_message)
+
+                if finish_reason is not None:
+                    break
 
         res_message_str = "".join(res_message)
         message_tokens = await tiktoken.get_tokens(res_message_str)
diff --git a/simple_query_builder.py b/simple_query_builder.py
deleted file mode 100644
index 7677a04..0000000
--- a/simple_query_builder.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from __future__ import annotations
-import asyncpg
-
-class SimpleQueryBuilder:
-    def __init__(self):
-        self._table_name = ""
-        self._select = ["*"]
-        self._where = []
-        self._having = []
-        self._order_by = None
-        self._order_by_desc = False
-
-    def table(self, table_name: str):
-        self._table_name = table_name
-        return self
-
-    def fields(self, fields: list[str]):
-        self.select = fields
-        return self
-
-    def where(self, where: str, condition: str, param):
-        self._where.append((where, condition, param))
-        return self
-
-    def having(self, having: str, condition: str, param):
-        self._having.append((having, condition, param))
-        return self
-
-    def build(self):
-        sql = "SELECT %s FROM %s" % (", ".join(self._select), self._table_name)
-
-        params = []
-        paramsLen = 0
-
-        if len(self._where) > 0:
-            sql += " WHERE "
-            for where, condition, param in self._where:
-                params.append(param)
-                paramsLen += 1
-                sql += "%s %s $%d AND " % (where, condition, paramsLen)
-
-        if self._order_by is not None:
-            sql += " ORDER BY %s %s" % (self._order_by, "DESC" if self._order_by_desc else "ASC")
-
-        if len(self._having) > 0:
-            sql += " HAVING "
-            for having, condition, param in self._having:
-                params.append(param)
-                paramsLen += 1
-                sql += "%s %s $%d AND " % (having, condition, paramsLen)
\ No newline at end of file
diff --git a/test/base.py b/test/base.py
new file mode 100644
index 0000000..bb81f7d
--- /dev/null
+++ b/test/base.py
@@ -0,0 +1,4 @@
+import sys
+import pathlib
+
+sys.path.append(str(pathlib.Path(__file__).parent.parent))
\ No newline at end of file
diff --git a/test/bert_embedding_queue.py b/test/bert_embedding_queue.py
new file mode 100644
index 0000000..a5b0cd2
--- /dev/null
+++ b/test/bert_embedding_queue.py
@@ -0,0 +1,27 @@
+import asyncio
+import time
+import base
+from local import loop, noawait
+from service.bert_embedding import bert_embedding_queue
+
+async def main():
+    embedding_list = []
+    start_time = time.time()
+    queue = []
+    with open("test/test.md", "r", encoding="utf-8") as fp:
+        text = fp.read()
+        lines = text.split("\n")
+        for line in lines:
+            line = line.strip()
+            if line == "":
+                continue
+
+            queue.append(bert_embedding_queue.get_embeddings(line))
+        embedding_list = await asyncio.gather(*queue)
+    end_time = time.time()
+    print("time cost: %.4f" % (end_time - start_time))
+    print("dimensions: %d" % len(embedding_list[0]))
+    await noawait.end()
+
+if __name__ == '__main__':
+    loop.run_until_complete(main())
\ No newline at end of file
diff --git a/test/chatcomplete.py b/test/chatcomplete.py
new file mode 100644
index 0000000..3fc840c
--- /dev/null
+++ b/test/chatcomplete.py
@@ -0,0 +1,37 @@
+import traceback
+import base
+import local
+from service.chat_complete import ChatCompleteService
+from service.database import DatabaseService
+
+from service.tiktoken import TikTokenService
+
+async def main():
+    dbs = await DatabaseService.create()
+    tiktoken = await TikTokenService.create()
+
+    async with ChatCompleteService(dbs, "代号:曙光的世界/黄昏的阿瓦隆") as chat_complete:
+        question = "你是谁?"
+        question_tokens = await tiktoken.get_tokens(question)
+
+        try:
+            prepare_res = await chat_complete.prepare_chat_complete(question, None, 1, question_tokens, {
+                "distance_limit": 0.6,
+                "limit": 10
+            })
+
+            print(prepare_res)
+            async def on_message(message: str):
+                # print(message)
+                pass
+
+            res = await chat_complete.finish_chat_complete(on_message)
+            print(res)
+        except Exception as err:
+            print(err)
+            traceback.print_exc()
+
+    await local.noawait.end()
+
+if __name__ == '__main__':
+    local.loop.run_until_complete(main())
\ No newline at end of file
diff --git a/test.py b/test/embedding_search.py
similarity index 94%
rename from test.py
rename to test/embedding_search.py
index 6b8adda..5be0875 100644
--- a/test.py
+++ b/test/embedding_search.py
@@ -1,37 +1,38 @@
-import local
-from service.database import DatabaseService
-
-from service.embedding_search import EmbeddingSearchService
-
-async def main():
-    dbs = await DatabaseService.create()
-
-    async with EmbeddingSearchService(dbs, "代号:曙光的世界/黄昏的阿瓦隆") as embedding_search:
-        await embedding_search.prepare_update_index()
-
-        async def on_index_progress(current, length):
-            print("\r索引进度:%.1f%%" % (current / length * 100), end="", flush=True)
-
-        print("")
-        await embedding_search.update_page_index(on_index_progress)
-        print("")
-
-        while True:
-            query = input("请输入要搜索的问题 (.exit 退出):")
-            if query == ".exit":
-                break
-            res, token_usage = await embedding_search.search(query, 5)
-            total_length = 0
-            if res:
-                for one in res:
-                    total_length += len(one["markdown"])
-                    print("%s, distance=%.4f" % (one["markdown"], one["distance"]))
-            else:
-                print("未搜索到相关内容")
-
-            print("总长度:%d" % total_length)
-
-    await local.noawait.end()
-
-if __name__ == '__main__':
+import base
+import local
+from service.database import DatabaseService
+
+from service.embedding_search import EmbeddingSearchService
+
+async def main():
+    dbs = await DatabaseService.create()
+
+    async with EmbeddingSearchService(dbs, "代号:曙光的世界/黄昏的阿瓦隆") as embedding_search:
+        await embedding_search.prepare_update_index()
+
+        async def on_index_progress(current, length):
+            print("\r索引进度:%.1f%%" % (current / length * 100), end="", flush=True)
+
+        print("")
+        await embedding_search.update_page_index(on_index_progress)
+        print("")
+
+        while True:
+            query = input("请输入要搜索的问题 (.exit 退出):")
+            if query == ".exit":
+                break
+            res, token_usage = await embedding_search.search(query, 5)
+            total_length = 0
+            if res:
+                for one in res:
+                    total_length += len(one["markdown"])
+                    print("%s, distance=%.4f" % (one["markdown"], one["distance"]))
+            else:
+                print("未搜索到相关内容")
+
+            print("总长度:%d" % total_length)
+
+    await local.noawait.end()
+
+if __name__ == '__main__':
     local.loop.run_until_complete(main())
\ No newline at end of file
diff --git a/test.md b/test/test.md
similarity index 100%
rename from test.md
rename to test/test.md
diff --git a/test_chatcomplete.js b/test/test_chatcomplete.js
similarity index 100%
rename from test_chatcomplete.js
rename to test/test_chatcomplete.js
diff --git a/test/timer.py b/test/timer.py
new file mode 100644
index 0000000..b4c9e56
--- /dev/null
+++ b/test/timer.py
@@ -0,0 +1,20 @@
+import asyncio
+import base
+from local import loop, noawait
+
+async def test_timer1():
+    print("timer1")
+
+async def test_timer2():
+    print("timer2")
+
+async def main():
+    timer_id = noawait.add_timer(test_timer1, 1)
+    timer_id = noawait.add_timer(test_timer2, 2)
+    print("Timer id: %d" % timer_id)
+    while True:
+        await asyncio.sleep(1)
+    await noawait.end()
+
+if __name__ == '__main__':
+    loop.run_until_complete(main())
\ No newline at end of file
diff --git a/utils/web.py b/utils/web.py
index 508642b..fcd22d2 100644
--- a/utils/web.py
+++ b/utils/web.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from functools import wraps
+import json
 from typing import Any, Optional, Dict
 from aiohttp import web
 import jwt
@@ -8,22 +9,32 @@ import uuid
 
 ParamRule = Dict[str, Any]
 
-class ParamInvalidException(Exception):
+class ParamInvalidException(web.HTTPBadRequest):
     def __init__(self, param_list: list[str], rules: dict[str, ParamRule]):
         self.code = "param_invalid"
         self.param_list = param_list
         self.rules = rules
         param_list_str = "'" + ("', '".join(param_list)) + "'"
-        super().__init__(f"Param invalid: {param_list_str}")
+        super().__init__(f"Param invalid: {param_list_str}",
+                         content_type="application/json",
+                         body=json.dumps({
+                             "status": -1,
+                             "error": {
+                                 "code": self.code,
+                                 "message": f"Param invalid: {param_list_str}"
+                             }
+                         }))
 
 async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None):
     params: dict[str, Any] = {}
+    for key, value in request.match_info.items():
+        params[key] = value
     for key, value in request.query.items():
        params[key] = value
 
     if request.method == 'POST':
         if request.headers.get('content-type') == 'application/json':
             data = await request.json()
-            if data is not None and data is dict:
+            if data is not None and isinstance(data, dict):
                 for key, value in data.items():
                     params[key] = value
         else:
@@ -34,7 +45,7 @@ async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]]
     if rules is not None:
         invalid_params: list[str] = []
         for key, rule in rules.items():
-            if "required" in rule and rule["required"] and params[key] is None:
+            if "required" in rule and rule["required"] and key not in params.keys():
                 invalid_params.append(key)
                 continue
 
@@ -50,9 +61,20 @@ async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]]
                 elif rule["type"] == float:
                     params[key] = float(params[key])
                 elif rule["type"] == bool:
-                    val = params[key].lower()
-                    if val == "false" or val == "0":
-                        params[key] = False
+                    val = params[key]
+                    if isinstance(val, bool):
+                        params[key] = val
+                    elif isinstance(val, str):
+                        val = val.lower()
+                        if val == "false" or val == "0":
+                            params[key] = False
+                        else:
+                            params[key] = True
+                    elif isinstance(val, int):
+                        if val == 0:
+                            params[key] = False
+                        else:
+                            params[key] = True
                     else:
                         params[key] = True
             except ValueError: