From 003a9a79486784b9ec37f679f74034d791295b53 Mon Sep 17 00:00:00 2001 From: Lex Lim Date: Wed, 21 Jun 2023 05:34:54 +0000 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E5=AF=B9=E8=AF=9Dfork?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/controller/ChatComplete.py | 243 +++++++++--------- api/controller/task/ChatCompleteTask.py | 142 ++++++++++ api/model/base.py | 35 ++- api/model/chat_complete/bot_persona.py | 48 ++++ api/model/chat_complete/conversation.py | 21 +- .../embedding_search/title_collection.py | 20 +- api/model/toolkit_ui/conversation.py | 22 +- api/model/toolkit_ui/page_title.py | 22 +- api/route.py | 1 + main.py | 1 + 10 files changed, 349 insertions(+), 206 deletions(-) create mode 100644 api/controller/task/ChatCompleteTask.py create mode 100644 api/model/chat_complete/bot_persona.py diff --git a/api/controller/ChatComplete.py b/api/controller/ChatComplete.py index a751d1e..9b5a9c0 100644 --- a/api/controller/ChatComplete.py +++ b/api/controller/ChatComplete.py @@ -3,6 +3,8 @@ import asyncio import sys import time import traceback +from api.controller.task.ChatCompleteTask import ChatCompleteTask +from api.model.base import clone_model from api.model.toolkit_ui.conversation import ConversationHelper from local import noawait from typing import Optional, Callable, TypedDict @@ -17,131 +19,6 @@ from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, from service.tiktoken import TikTokenService import utils.web -chat_complete_tasks: dict[str, ChatCompleteTask] = {} - -class ChatCompleteTask: - def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False): - self.task_id = utils.web.generate_uuid() - self.on_message: list[Callable] = [] - self.on_finished: list[Callable] = [] - self.on_error: list[Callable] = [] - self.chunks: list[str] = [] - - self.chat_complete_service: ChatCompleteService - self.chat_complete: ChatCompleteService - - self.dbs = dbs - self.user_id = user_id - self.page_title = page_title - self.is_system = is_system - - self.transatcion_id: Optional[str] = None - self.point_cost: int = 0 - - self.is_finished = False - self.finished_time: Optional[float] = None - self.result: Optional[ChatCompleteServiceResponse] = None - self.error: Optional[Exception] = None - - async def init(self, question: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None, - embedding_search: Optional[EmbeddingSearchArgs] = None): - self.tiktoken = await TikTokenService.create() - - self.mwapi = MediaWikiApi.create() - - self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title) - self.chat_complete = await self.chat_complete_service.__aenter__() - - if await self.chat_complete.page_index_exists(): - question_tokens = await self.tiktoken.get_tokens(question) - - extract_limit = embedding_search["limit"] or 10 - - self.transatcion_id: Optional[str] = None - self.point_cost: int = 0 - if not self.is_system: - usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete", - question_tokens, extract_limit) - self.transatcion_id = usage_res["transaction_id"] - self.point_cost = usage_res["point_cost"] - - res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id, - user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search) - - return res - else: - await self._exit() - raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title) - - async def _on_message(self, delta_message: str): - self.chunks.append(delta_message) - - for callback in self.on_message: - try: - await callback(delta_message) - except Exception as e: - print("Error while processing on_message callback: %s" % e, file=sys.stderr) - traceback.print_exc() - - async def _on_finished(self): - for callback in self.on_finished: - try: - await callback(self.result) - except Exception as e: - print("Error while processing on_finished callback: %s" % e, file=sys.stderr) - traceback.print_exc() - - async def _on_error(self, err: Exception): - self.error = err - for callback in self.on_error: - try: - await callback(err) - except Exception as e: - print("Error while processing on_error callback: %s" % e, file=sys.stderr) - traceback.print_exc() - - async def run(self): - try: - chat_res = await self.chat_complete.finish_chat_complete(self._on_message) - - await self.chat_complete.set_latest_point_cost(self.point_cost) - - self.result = chat_res - - if self.transatcion_id: - await self.mwapi.ai_toolbox_end_transaction(self.transatcion_id, chat_res["total_tokens"]) - - await self._on_finished() - except Exception as e: - err_msg = f"Error while processing chat complete request: {e}" - - print(err_msg, file=sys.stderr) - traceback.print_exc() - - if self.transatcion_id: - await self.mwapi.ai_toolbox_cancel_transaction(self.transatcion_id, error=err_msg) - - await self._on_error(e) - finally: - await self._exit() - - async def _exit(self): - await self.chat_complete_service.__aexit__(None, None, None) - del chat_complete_tasks[self.task_id] - self.is_finished = True - self.finished_time = time.time() - -TASK_EXPIRE_TIME = 60 * 10 - -async def chat_complete_task_gc(): - now = time.time() - for task_id in chat_complete_tasks.keys(): - task = chat_complete_tasks[task_id] - if task.is_finished and task.finished_time is not None and now > task.finished_time + TASK_EXPIRE_TIME: - del chat_complete_tasks[task_id] - -noawait.add_timer(chat_complete_task_gc, 60) - class ChatComplete: @staticmethod @utils.web.token_auth @@ -238,6 +115,118 @@ class ChatComplete: return await utils.web.api_response(1, chunk_dict, request=request) + @staticmethod + @utils.web.token_auth + async def fork_conversation(request: web.Request): + params = await utils.web.get_param(request, { + "user_id": { + "required": False, + "type": int, + }, + "id": { + "required": True, + "type": int, + }, + "message_id": { + "required": False, + "type": str + }, + "new_title": { + "required": False, + "type": str + } + }) + + if request.get("caller") == "user": + user_id = request.get("user") + else: + user_id = params.get("user_id") + + conversation_id: int = params.get("id") + packed_message_id: str = params.get("message_id") + new_title = params.get("new_title") + + if packed_message_id is not None: + (chunk_id, msg_id) = packed_message_id.split(",") + chunk_id = int(chunk_id) + else: + chunk_id = None + msg_id = None + + db = await DatabaseService.create(request.app) + async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper: + conversation_info = await conversation_helper.find_by_id(conversation_id) + if conversation_info is None: + return await utils.web.api_response(-1, error={ + "code": "conversation-not-found", + "message": "Conversation not found." + }, http_status=404, request=request) + + if conversation_info.user_id != user_id: + return await utils.web.api_response(-1, error={ + "code": "permission-denied", + "message": "Permission denied." + }, http_status=403, request=request) + + # Clone selected chunk + if chunk_id is not None: + chunk_info = await conversation_chunk_helper.find_by_id(chunk_id) + if chunk_info is None or chunk_info.conversation_id != conversation_id: + return await utils.web.api_response(-1, error={ + "code": "conversation-chunk-not-found", + "message": "Conversation chunk not found." + }, http_status=404, request=request) + else: + chunk_info = await conversation_chunk_helper.get_newest_chunk(conversation_id) + + new_conversation: ConversationModel = clone_model(conversation_info) + + if new_title is not None: + new_conversation.title = new_title + + new_conversation = await conversation_helper.add(new_conversation) + + if chunk_info is not None: + new_chunk: ConversationChunkModel = clone_model(chunk_info) + new_chunk.conversation_id = new_conversation.id + + if msg_id is not None: + # Remove message after selected message + split_message_pos = None + for i in range(0, len(new_chunk.message_data)): + msg_data = new_chunk.message_data[i] + if msg_data["id"] == msg_id: + split_message_pos = i + break + + new_chunk.message_data = new_chunk.message_data[0:split_message_pos + 1] + + new_chunk.message_data.insert(0, { + "id": utils.web.generate_uuid(), + "role": "notice", + "type": "forked", + "data": { + "original_conversation_id": conversation_info.id, + "original_title": conversation_info.title, + } + }) + + # Update conversation description + last_assistant_message = None + for msg in new_chunk.message_data: + if msg["role"] == "assistant": + last_assistant_message = msg + + if last_assistant_message is not None: + new_conversation.description = last_assistant_message["content"][0:150] + conversation_helper.update(new_conversation) + + new_chunk = await conversation_chunk_helper.add(new_chunk) + + return await utils.web.api_response(1, { + "conversation_id": new_conversation.id, + }, request=request) + @staticmethod @utils.web.token_auth async def get_tokens(request: web.Request): @@ -349,8 +338,6 @@ class ChatComplete: "limit": extract_limit, "in_collection": in_collection, }) - - chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task noawait.add_task(chat_complete_task.run()) @@ -404,7 +391,7 @@ class ChatComplete: task_id = params.get("task_id") - task = chat_complete_tasks.get(task_id) + task = ChatCompleteTask.get_by_id(task_id) if task is None: await ws.send_json({ 'event': 'error', diff --git a/api/controller/task/ChatCompleteTask.py b/api/controller/task/ChatCompleteTask.py new file mode 100644 index 0000000..0952414 --- /dev/null +++ b/api/controller/task/ChatCompleteTask.py @@ -0,0 +1,142 @@ +from __future__ import annotations +import sys +import time +import traceback +from local import noawait +from typing import Optional, Callable, Union +from service.chat_complete import ChatCompleteService, ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse +from service.database import DatabaseService +from service.embedding_search import EmbeddingSearchArgs +from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException +from service.tiktoken import TikTokenService +import utils.web + +chat_complete_tasks: dict[str, ChatCompleteTask] = {} + +class ChatCompleteTask: + @staticmethod + def get_by_id(task_id: str) -> Union[ChatCompleteTask, None]: + return chat_complete_tasks.get(task_id) + + def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False): + self.task_id = utils.web.generate_uuid() + self.on_message: list[Callable] = [] + self.on_finished: list[Callable] = [] + self.on_error: list[Callable] = [] + self.chunks: list[str] = [] + + self.chat_complete_service: ChatCompleteService + self.chat_complete: ChatCompleteService + + self.dbs = dbs + self.user_id = user_id + self.page_title = page_title + self.is_system = is_system + + self.transatcion_id: Optional[str] = None + self.point_cost: int = 0 + + self.is_finished = False + self.finished_time: Optional[float] = None + self.result: Optional[ChatCompleteServiceResponse] = None + self.error: Optional[Exception] = None + + async def init(self, question: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None, + embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse: + self.tiktoken = await TikTokenService.create() + + self.mwapi = MediaWikiApi.create() + + self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title) + self.chat_complete = await self.chat_complete_service.__aenter__() + + if await self.chat_complete.page_index_exists(): + question_tokens = await self.tiktoken.get_tokens(question) + + extract_limit = embedding_search["limit"] or 10 + + self.transatcion_id: Optional[str] = None + self.point_cost: int = 0 + if not self.is_system: + usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete", + question_tokens, extract_limit) + self.transatcion_id = usage_res["transaction_id"] + self.point_cost = usage_res["point_cost"] + + res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id, + user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search) + + return res + else: + await self._exit() + raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title) + + async def _on_message(self, delta_message: str): + self.chunks.append(delta_message) + + for callback in self.on_message: + try: + await callback(delta_message) + except Exception as e: + print("Error while processing on_message callback: %s" % e, file=sys.stderr) + traceback.print_exc() + + async def _on_finished(self): + for callback in self.on_finished: + try: + await callback(self.result) + except Exception as e: + print("Error while processing on_finished callback: %s" % e, file=sys.stderr) + traceback.print_exc() + + async def _on_error(self, err: Exception): + self.error = err + for callback in self.on_error: + try: + await callback(err) + except Exception as e: + print("Error while processing on_error callback: %s" % e, file=sys.stderr) + traceback.print_exc() + + async def run(self) -> ChatCompleteServiceResponse: + chat_complete_tasks[self.task_id] = self + try: + chat_res = await self.chat_complete.finish_chat_complete(self._on_message) + + await self.chat_complete.set_latest_point_cost(self.point_cost) + + self.result = chat_res + + if self.transatcion_id: + await self.mwapi.ai_toolbox_end_transaction(self.transatcion_id, chat_res["total_tokens"]) + + await self._on_finished() + except Exception as e: + err_msg = f"Error while processing chat complete request: {e}" + + print(err_msg, file=sys.stderr) + traceback.print_exc() + + if self.transatcion_id: + await self.mwapi.ai_toolbox_cancel_transaction(self.transatcion_id, error=err_msg) + + await self._on_error(e) + finally: + await self._exit() + + async def _exit(self): + await self.chat_complete_service.__aexit__(None, None, None) + del chat_complete_tasks[self.task_id] + self.is_finished = True + self.finished_time = time.time() + +TASK_EXPIRE_TIME = 60 * 10 + +async def chat_complete_task_gc(): + now = time.time() + for task_id in chat_complete_tasks.keys(): + task = chat_complete_tasks[task_id] + if task.is_finished and task.finished_time is not None and now > task.finished_time + TASK_EXPIRE_TIME: + del chat_complete_tasks[task_id] + +noawait.add_timer(chat_complete_task_gc, 60) \ No newline at end of file diff --git a/api/model/base.py b/api/model/base.py index 02aadde..e3fccac 100644 --- a/api/model/base.py +++ b/api/model/base.py @@ -1,4 +1,37 @@ +from __future__ import annotations +from typing import TypeVar +import sqlalchemy from sqlalchemy.orm import DeclarativeBase +from service.database import DatabaseService + class BaseModel(DeclarativeBase): - pass \ No newline at end of file + pass + +class BaseHelper: + def __init__(self, dbs: DatabaseService): + self.dbs = dbs + self.initialized = False + + async def __aenter__(self): + if not self.initialized: + self.create_session = self.dbs.create_session + self.session = self.dbs.create_session() + await self.session.__aenter__() + self.initialized = True + + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.session.__aexit__(exc_type, exc, tb) + pass + +T = TypeVar("T", bound=BaseModel) + +def clone_model(model: T) -> T: + data_dict = {} + for c in sqlalchemy.inspect(model).mapper.column_attrs: + if c.key == "id": + continue + data_dict[c.key] = getattr(model, c.key) + return model.__class__(**data_dict) \ No newline at end of file diff --git a/api/model/chat_complete/bot_persona.py b/api/model/chat_complete/bot_persona.py new file mode 100644 index 0000000..f7362b1 --- /dev/null +++ b/api/model/chat_complete/bot_persona.py @@ -0,0 +1,48 @@ +from __future__ import annotations +import sqlalchemy +from api.model.base import BaseHelper, BaseModel + +import sqlalchemy +from sqlalchemy import select, update +from sqlalchemy.orm import mapped_column, relationship, Mapped + +class BotPersonaModel(BaseModel): + __tablename__ = "chat_complete_bot_persona" + + id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True) + bot_id: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True) + bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True) + bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) + bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) + system_prompt: Mapped[str] = mapped_column(sqlalchemy.String) + message_log: Mapped[list] = mapped_column(sqlalchemy.JSON) + default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) + +class BotPersonaHelper(BaseHelper): + async def add(self, obj: BotPersonaModel): + self.session.add(obj) + await self.session.commit() + await self.session.refresh(obj) + return obj + + async def update(self, obj: BotPersonaModel): + obj = await self.session.merge(obj) + await self.session.commit() + return obj + + async def get_list(self): + stmt = select(BotPersonaModel).with_only_columns([ + BotPersonaModel.id, + BotPersonaModel.bot_name, + BotPersonaModel.bot_avatar, + BotPersonaModel.bot_description + ]) + return await self.session.scalars(stmt) + + async def find_by_id(self, id: int): + stmt = select(BotPersonaModel).where(BotPersonaModel.id == id) + return await self.session.scalar(stmt) + + async def find_by_bot_id(self, bot_id: str): + stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id) + return await self.session.scalar(stmt) \ No newline at end of file diff --git a/api/model/chat_complete/conversation.py b/api/model/chat_complete/conversation.py index f13c02f..f29096a 100644 --- a/api/model/chat_complete/conversation.py +++ b/api/model/chat_complete/conversation.py @@ -5,7 +5,7 @@ import sqlalchemy from sqlalchemy import select, update from sqlalchemy.orm import mapped_column, relationship, Mapped -from api.model.base import BaseModel +from api.model.base import BaseHelper, BaseModel from api.model.toolkit_ui.conversation import ConversationModel from service.database import DatabaseService from service.event import EventService @@ -20,24 +20,7 @@ class ConversationChunkModel(BaseModel): tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0) updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True) -class ConversationChunkHelper: - def __init__(self, dbs: DatabaseService): - self.dbs = dbs - self.initialized = False - - async def __aenter__(self): - if not self.initialized: - self.create_session = self.dbs.create_session - self.session = self.dbs.create_session() - await self.session.__aenter__() - self.initialized = True - - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.session.__aexit__(exc_type, exc, tb) - pass - +class ConversationChunkHelper(BaseHelper): async def add(self, obj: ConversationChunkModel): obj.updated_at = int(time.time()) self.session.add(obj) diff --git a/api/model/embedding_search/title_collection.py b/api/model/embedding_search/title_collection.py index fb0302b..fad054a 100644 --- a/api/model/embedding_search/title_collection.py +++ b/api/model/embedding_search/title_collection.py @@ -3,7 +3,7 @@ import sqlalchemy from sqlalchemy import select, update, delete from sqlalchemy.orm import mapped_column, Mapped -from api.model.base import BaseModel +from api.model.base import BaseHelper, BaseModel from service.database import DatabaseService class TitleCollectionModel(BaseModel): @@ -13,23 +13,7 @@ class TitleCollectionModel(BaseModel): title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True, nullable=True) -class TitleCollectionHelper: - def __init__(self, dbs: DatabaseService): - self.dbs = dbs - self.initialized = False - - async def __aenter__(self): - if not self.initialized: - self.create_session = self.dbs.create_session - self.session = self.dbs.create_session() - await self.session.__aenter__() - self.initialized = True - - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.session.__aexit__(exc_type, exc, tb) - +class TitleCollectionHelper(BaseHelper): async def add(self, title: str, page_id: Optional[int] = None) -> Union[int, bool]: stmt = select(TitleCollectionModel.id).where(TitleCollectionModel.title == title) result = await self.session.scalar(stmt) diff --git a/api/model/toolkit_ui/conversation.py b/api/model/toolkit_ui/conversation.py index 0f5d0f1..8da1c2b 100644 --- a/api/model/toolkit_ui/conversation.py +++ b/api/model/toolkit_ui/conversation.py @@ -6,7 +6,7 @@ import sqlalchemy from sqlalchemy import update from sqlalchemy.orm import mapped_column, relationship, Mapped -from api.model.base import BaseModel +from api.model.base import BaseHelper, BaseModel from api.model.toolkit_ui.page_title import PageTitleModel from service.database import DatabaseService @@ -30,25 +30,7 @@ class ConversationModel(BaseModel): page_info: Mapped[PageTitleModel] = relationship("PageTitleModel", lazy="joined") -class ConversationHelper: - def __init__(self, dbs: DatabaseService): - self.dbs = dbs - self.initialized = False - - async def __aenter__(self): - if not self.initialized: - self.create_session = self.dbs.create_session - self.session = self.dbs.create_session() - await self.session.__aenter__() - - self.initialized = True - - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.session.__aexit__(exc_type, exc, tb) - pass - +class ConversationHelper(BaseHelper): async def add(self, obj: ConversationModel): obj.updated_at = int(time.time()) self.session.add(obj) diff --git a/api/model/toolkit_ui/page_title.py b/api/model/toolkit_ui/page_title.py index 984b4e4..a395e43 100644 --- a/api/model/toolkit_ui/page_title.py +++ b/api/model/toolkit_ui/page_title.py @@ -6,7 +6,7 @@ import sqlalchemy from sqlalchemy import select, update from sqlalchemy.orm import mapped_column, Mapped -from api.model.base import BaseModel +from api.model.base import BaseHelper, BaseModel from service.database import DatabaseService @@ -20,25 +20,7 @@ class PageTitleModel(BaseModel): updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True) -class PageTitleHelper: - def __init__(self, dbs: DatabaseService): - self.dbs = dbs - self.initialized = False - - async def __aenter__(self): - if not self.initialized: - self.create_session = self.dbs.create_session - self.session = self.dbs.create_session() - await self.session.__aenter__() - - self.initialized = True - - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.session.__aexit__(exc_type, exc, tb) - pass - +class PageTitleHelper(BaseHelper): async def find_by_page_id(self, page_id: int): stmt = select(PageTitleModel).where(PageTitleModel.page_id == page_id) return await self.session.scalar(stmt) diff --git a/api/route.py b/api/route.py index dfb3b54..e6a2eda 100644 --- a/api/route.py +++ b/api/route.py @@ -31,6 +31,7 @@ def init(app: web.Application): web.route('*', '/chatcomplete/conversation_chunk/list', ChatComplete.get_conversation_chunk_list), web.route('*', '/chatcomplete/conversation_chunk/info', ChatComplete.get_conversation_chunk), + web.route('POST', '/chatcomplete/conversation/fork', ChatComplete.fork_conversation), web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete), web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream), web.route('POST', '/chatcomplete/get_point_cost', ChatComplete.get_point_cost), diff --git a/main.py b/main.py index 0cc4471..562e183 100644 --- a/main.py +++ b/main.py @@ -11,6 +11,7 @@ from service.mediawiki_api import MediaWikiApi from api.model.base import BaseModel from api.model.toolkit_ui.conversation import ConversationModel as _ from api.model.chat_complete.conversation import ConversationChunkModel as _ +from api.model.chat_complete.bot_persona import BotPersonaModel as _ from api.model.embedding_search.title_collection import TitleCollectionModel as _ from api.model.embedding_search.title_index import TitleIndexModel as _