完成对话fork功能

master
落雨楓 2 years ago
parent 197e3e1f1a
commit 003a9a7948

@ -3,6 +3,8 @@ import asyncio
import sys
import time
import traceback
from api.controller.task.ChatCompleteTask import ChatCompleteTask
from api.model.base import clone_model
from api.model.toolkit_ui.conversation import ConversationHelper
from local import noawait
from typing import Optional, Callable, TypedDict
@ -17,131 +19,6 @@ from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException,
from service.tiktoken import TikTokenService
import utils.web
chat_complete_tasks: dict[str, ChatCompleteTask] = {}
class ChatCompleteTask:
def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
self.task_id = utils.web.generate_uuid()
self.on_message: list[Callable] = []
self.on_finished: list[Callable] = []
self.on_error: list[Callable] = []
self.chunks: list[str] = []
self.chat_complete_service: ChatCompleteService
self.chat_complete: ChatCompleteService
self.dbs = dbs
self.user_id = user_id
self.page_title = page_title
self.is_system = is_system
self.transatcion_id: Optional[str] = None
self.point_cost: int = 0
self.is_finished = False
self.finished_time: Optional[float] = None
self.result: Optional[ChatCompleteServiceResponse] = None
self.error: Optional[Exception] = None
async def init(self, question: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None):
self.tiktoken = await TikTokenService.create()
self.mwapi = MediaWikiApi.create()
self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
self.chat_complete = await self.chat_complete_service.__aenter__()
if await self.chat_complete.page_index_exists():
question_tokens = await self.tiktoken.get_tokens(question)
extract_limit = embedding_search["limit"] or 10
self.transatcion_id: Optional[str] = None
self.point_cost: int = 0
if not self.is_system:
usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete",
question_tokens, extract_limit)
self.transatcion_id = usage_res["transaction_id"]
self.point_cost = usage_res["point_cost"]
res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
return res
else:
await self._exit()
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
async def _on_message(self, delta_message: str):
self.chunks.append(delta_message)
for callback in self.on_message:
try:
await callback(delta_message)
except Exception as e:
print("Error while processing on_message callback: %s" % e, file=sys.stderr)
traceback.print_exc()
async def _on_finished(self):
for callback in self.on_finished:
try:
await callback(self.result)
except Exception as e:
print("Error while processing on_finished callback: %s" % e, file=sys.stderr)
traceback.print_exc()
async def _on_error(self, err: Exception):
self.error = err
for callback in self.on_error:
try:
await callback(err)
except Exception as e:
print("Error while processing on_error callback: %s" % e, file=sys.stderr)
traceback.print_exc()
async def run(self):
try:
chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
await self.chat_complete.set_latest_point_cost(self.point_cost)
self.result = chat_res
if self.transatcion_id:
await self.mwapi.ai_toolbox_end_transaction(self.transatcion_id, chat_res["total_tokens"])
await self._on_finished()
except Exception as e:
err_msg = f"Error while processing chat complete request: {e}"
print(err_msg, file=sys.stderr)
traceback.print_exc()
if self.transatcion_id:
await self.mwapi.ai_toolbox_cancel_transaction(self.transatcion_id, error=err_msg)
await self._on_error(e)
finally:
await self._exit()
async def _exit(self):
await self.chat_complete_service.__aexit__(None, None, None)
del chat_complete_tasks[self.task_id]
self.is_finished = True
self.finished_time = time.time()
TASK_EXPIRE_TIME = 60 * 10
async def chat_complete_task_gc():
now = time.time()
for task_id in chat_complete_tasks.keys():
task = chat_complete_tasks[task_id]
if task.is_finished and task.finished_time is not None and now > task.finished_time + TASK_EXPIRE_TIME:
del chat_complete_tasks[task_id]
noawait.add_timer(chat_complete_task_gc, 60)
class ChatComplete:
@staticmethod
@utils.web.token_auth
@ -238,6 +115,118 @@ class ChatComplete:
return await utils.web.api_response(1, chunk_dict, request=request)
@staticmethod
@utils.web.token_auth
async def fork_conversation(request: web.Request):
params = await utils.web.get_param(request, {
"user_id": {
"required": False,
"type": int,
},
"id": {
"required": True,
"type": int,
},
"message_id": {
"required": False,
"type": str
},
"new_title": {
"required": False,
"type": str
}
})
if request.get("caller") == "user":
user_id = request.get("user")
else:
user_id = params.get("user_id")
conversation_id: int = params.get("id")
packed_message_id: str = params.get("message_id")
new_title = params.get("new_title")
if packed_message_id is not None:
(chunk_id, msg_id) = packed_message_id.split(",")
chunk_id = int(chunk_id)
else:
chunk_id = None
msg_id = None
db = await DatabaseService.create(request.app)
async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
conversation_info = await conversation_helper.find_by_id(conversation_id)
if conversation_info is None:
return await utils.web.api_response(-1, error={
"code": "conversation-not-found",
"message": "Conversation not found."
}, http_status=404, request=request)
if conversation_info.user_id != user_id:
return await utils.web.api_response(-1, error={
"code": "permission-denied",
"message": "Permission denied."
}, http_status=403, request=request)
# Clone selected chunk
if chunk_id is not None:
chunk_info = await conversation_chunk_helper.find_by_id(chunk_id)
if chunk_info is None or chunk_info.conversation_id != conversation_id:
return await utils.web.api_response(-1, error={
"code": "conversation-chunk-not-found",
"message": "Conversation chunk not found."
}, http_status=404, request=request)
else:
chunk_info = await conversation_chunk_helper.get_newest_chunk(conversation_id)
new_conversation: ConversationModel = clone_model(conversation_info)
if new_title is not None:
new_conversation.title = new_title
new_conversation = await conversation_helper.add(new_conversation)
if chunk_info is not None:
new_chunk: ConversationChunkModel = clone_model(chunk_info)
new_chunk.conversation_id = new_conversation.id
if msg_id is not None:
# Remove message after selected message
split_message_pos = None
for i in range(0, len(new_chunk.message_data)):
msg_data = new_chunk.message_data[i]
if msg_data["id"] == msg_id:
split_message_pos = i
break
new_chunk.message_data = new_chunk.message_data[0:split_message_pos + 1]
new_chunk.message_data.insert(0, {
"id": utils.web.generate_uuid(),
"role": "notice",
"type": "forked",
"data": {
"original_conversation_id": conversation_info.id,
"original_title": conversation_info.title,
}
})
# Update conversation description
last_assistant_message = None
for msg in new_chunk.message_data:
if msg["role"] == "assistant":
last_assistant_message = msg
if last_assistant_message is not None:
new_conversation.description = last_assistant_message["content"][0:150]
conversation_helper.update(new_conversation)
new_chunk = await conversation_chunk_helper.add(new_chunk)
return await utils.web.api_response(1, {
"conversation_id": new_conversation.id,
}, request=request)
@staticmethod
@utils.web.token_auth
async def get_tokens(request: web.Request):
@ -349,8 +338,6 @@ class ChatComplete:
"limit": extract_limit,
"in_collection": in_collection,
})
chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task
noawait.add_task(chat_complete_task.run())
@ -404,7 +391,7 @@ class ChatComplete:
task_id = params.get("task_id")
task = chat_complete_tasks.get(task_id)
task = ChatCompleteTask.get_by_id(task_id)
if task is None:
await ws.send_json({
'event': 'error',

@ -0,0 +1,142 @@
from __future__ import annotations
import sys
import time
import traceback
from local import noawait
from typing import Optional, Callable, Union
from service.chat_complete import ChatCompleteService, ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
from service.tiktoken import TikTokenService
import utils.web
chat_complete_tasks: dict[str, ChatCompleteTask] = {}
class ChatCompleteTask:
@staticmethod
def get_by_id(task_id: str) -> Union[ChatCompleteTask, None]:
return chat_complete_tasks.get(task_id)
def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
self.task_id = utils.web.generate_uuid()
self.on_message: list[Callable] = []
self.on_finished: list[Callable] = []
self.on_error: list[Callable] = []
self.chunks: list[str] = []
self.chat_complete_service: ChatCompleteService
self.chat_complete: ChatCompleteService
self.dbs = dbs
self.user_id = user_id
self.page_title = page_title
self.is_system = is_system
self.transatcion_id: Optional[str] = None
self.point_cost: int = 0
self.is_finished = False
self.finished_time: Optional[float] = None
self.result: Optional[ChatCompleteServiceResponse] = None
self.error: Optional[Exception] = None
async def init(self, question: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse:
self.tiktoken = await TikTokenService.create()
self.mwapi = MediaWikiApi.create()
self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
self.chat_complete = await self.chat_complete_service.__aenter__()
if await self.chat_complete.page_index_exists():
question_tokens = await self.tiktoken.get_tokens(question)
extract_limit = embedding_search["limit"] or 10
self.transatcion_id: Optional[str] = None
self.point_cost: int = 0
if not self.is_system:
usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete",
question_tokens, extract_limit)
self.transatcion_id = usage_res["transaction_id"]
self.point_cost = usage_res["point_cost"]
res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
return res
else:
await self._exit()
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
async def _on_message(self, delta_message: str):
self.chunks.append(delta_message)
for callback in self.on_message:
try:
await callback(delta_message)
except Exception as e:
print("Error while processing on_message callback: %s" % e, file=sys.stderr)
traceback.print_exc()
async def _on_finished(self):
for callback in self.on_finished:
try:
await callback(self.result)
except Exception as e:
print("Error while processing on_finished callback: %s" % e, file=sys.stderr)
traceback.print_exc()
async def _on_error(self, err: Exception):
self.error = err
for callback in self.on_error:
try:
await callback(err)
except Exception as e:
print("Error while processing on_error callback: %s" % e, file=sys.stderr)
traceback.print_exc()
async def run(self) -> ChatCompleteServiceResponse:
chat_complete_tasks[self.task_id] = self
try:
chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
await self.chat_complete.set_latest_point_cost(self.point_cost)
self.result = chat_res
if self.transatcion_id:
await self.mwapi.ai_toolbox_end_transaction(self.transatcion_id, chat_res["total_tokens"])
await self._on_finished()
except Exception as e:
err_msg = f"Error while processing chat complete request: {e}"
print(err_msg, file=sys.stderr)
traceback.print_exc()
if self.transatcion_id:
await self.mwapi.ai_toolbox_cancel_transaction(self.transatcion_id, error=err_msg)
await self._on_error(e)
finally:
await self._exit()
async def _exit(self):
await self.chat_complete_service.__aexit__(None, None, None)
del chat_complete_tasks[self.task_id]
self.is_finished = True
self.finished_time = time.time()
TASK_EXPIRE_TIME = 60 * 10
async def chat_complete_task_gc():
now = time.time()
for task_id in chat_complete_tasks.keys():
task = chat_complete_tasks[task_id]
if task.is_finished and task.finished_time is not None and now > task.finished_time + TASK_EXPIRE_TIME:
del chat_complete_tasks[task_id]
noawait.add_timer(chat_complete_task_gc, 60)

@ -1,4 +1,37 @@
from __future__ import annotations
from typing import TypeVar
import sqlalchemy
from sqlalchemy.orm import DeclarativeBase
from service.database import DatabaseService
class BaseModel(DeclarativeBase):
pass
pass
class BaseHelper:
def __init__(self, dbs: DatabaseService):
self.dbs = dbs
self.initialized = False
async def __aenter__(self):
if not self.initialized:
self.create_session = self.dbs.create_session
self.session = self.dbs.create_session()
await self.session.__aenter__()
self.initialized = True
return self
async def __aexit__(self, exc_type, exc, tb):
await self.session.__aexit__(exc_type, exc, tb)
pass
T = TypeVar("T", bound=BaseModel)
def clone_model(model: T) -> T:
data_dict = {}
for c in sqlalchemy.inspect(model).mapper.column_attrs:
if c.key == "id":
continue
data_dict[c.key] = getattr(model, c.key)
return model.__class__(**data_dict)

@ -0,0 +1,48 @@
from __future__ import annotations
import sqlalchemy
from api.model.base import BaseHelper, BaseModel
import sqlalchemy
from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, Mapped
class BotPersonaModel(BaseModel):
__tablename__ = "chat_complete_bot_persona"
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
bot_id: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON)
default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
class BotPersonaHelper(BaseHelper):
async def add(self, obj: BotPersonaModel):
self.session.add(obj)
await self.session.commit()
await self.session.refresh(obj)
return obj
async def update(self, obj: BotPersonaModel):
obj = await self.session.merge(obj)
await self.session.commit()
return obj
async def get_list(self):
stmt = select(BotPersonaModel).with_only_columns([
BotPersonaModel.id,
BotPersonaModel.bot_name,
BotPersonaModel.bot_avatar,
BotPersonaModel.bot_description
])
return await self.session.scalars(stmt)
async def find_by_id(self, id: int):
stmt = select(BotPersonaModel).where(BotPersonaModel.id == id)
return await self.session.scalar(stmt)
async def find_by_bot_id(self, bot_id: str):
stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id)
return await self.session.scalar(stmt)

@ -5,7 +5,7 @@ import sqlalchemy
from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, Mapped
from api.model.base import BaseModel
from api.model.base import BaseHelper, BaseModel
from api.model.toolkit_ui.conversation import ConversationModel
from service.database import DatabaseService
from service.event import EventService
@ -20,24 +20,7 @@ class ConversationChunkModel(BaseModel):
tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0)
updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)
class ConversationChunkHelper:
def __init__(self, dbs: DatabaseService):
self.dbs = dbs
self.initialized = False
async def __aenter__(self):
if not self.initialized:
self.create_session = self.dbs.create_session
self.session = self.dbs.create_session()
await self.session.__aenter__()
self.initialized = True
return self
async def __aexit__(self, exc_type, exc, tb):
await self.session.__aexit__(exc_type, exc, tb)
pass
class ConversationChunkHelper(BaseHelper):
async def add(self, obj: ConversationChunkModel):
obj.updated_at = int(time.time())
self.session.add(obj)

@ -3,7 +3,7 @@ import sqlalchemy
from sqlalchemy import select, update, delete
from sqlalchemy.orm import mapped_column, Mapped
from api.model.base import BaseModel
from api.model.base import BaseHelper, BaseModel
from service.database import DatabaseService
class TitleCollectionModel(BaseModel):
@ -13,23 +13,7 @@ class TitleCollectionModel(BaseModel):
title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True, nullable=True)
class TitleCollectionHelper:
def __init__(self, dbs: DatabaseService):
self.dbs = dbs
self.initialized = False
async def __aenter__(self):
if not self.initialized:
self.create_session = self.dbs.create_session
self.session = self.dbs.create_session()
await self.session.__aenter__()
self.initialized = True
return self
async def __aexit__(self, exc_type, exc, tb):
await self.session.__aexit__(exc_type, exc, tb)
class TitleCollectionHelper(BaseHelper):
async def add(self, title: str, page_id: Optional[int] = None) -> Union[int, bool]:
stmt = select(TitleCollectionModel.id).where(TitleCollectionModel.title == title)
result = await self.session.scalar(stmt)

@ -6,7 +6,7 @@ import sqlalchemy
from sqlalchemy import update
from sqlalchemy.orm import mapped_column, relationship, Mapped
from api.model.base import BaseModel
from api.model.base import BaseHelper, BaseModel
from api.model.toolkit_ui.page_title import PageTitleModel
from service.database import DatabaseService
@ -30,25 +30,7 @@ class ConversationModel(BaseModel):
page_info: Mapped[PageTitleModel] = relationship("PageTitleModel", lazy="joined")
class ConversationHelper:
def __init__(self, dbs: DatabaseService):
self.dbs = dbs
self.initialized = False
async def __aenter__(self):
if not self.initialized:
self.create_session = self.dbs.create_session
self.session = self.dbs.create_session()
await self.session.__aenter__()
self.initialized = True
return self
async def __aexit__(self, exc_type, exc, tb):
await self.session.__aexit__(exc_type, exc, tb)
pass
class ConversationHelper(BaseHelper):
async def add(self, obj: ConversationModel):
obj.updated_at = int(time.time())
self.session.add(obj)

@ -6,7 +6,7 @@ import sqlalchemy
from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, Mapped
from api.model.base import BaseModel
from api.model.base import BaseHelper, BaseModel
from service.database import DatabaseService
@ -20,25 +20,7 @@ class PageTitleModel(BaseModel):
updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)
class PageTitleHelper:
def __init__(self, dbs: DatabaseService):
self.dbs = dbs
self.initialized = False
async def __aenter__(self):
if not self.initialized:
self.create_session = self.dbs.create_session
self.session = self.dbs.create_session()
await self.session.__aenter__()
self.initialized = True
return self
async def __aexit__(self, exc_type, exc, tb):
await self.session.__aexit__(exc_type, exc, tb)
pass
class PageTitleHelper(BaseHelper):
async def find_by_page_id(self, page_id: int):
stmt = select(PageTitleModel).where(PageTitleModel.page_id == page_id)
return await self.session.scalar(stmt)

@ -31,6 +31,7 @@ def init(app: web.Application):
web.route('*', '/chatcomplete/conversation_chunk/list', ChatComplete.get_conversation_chunk_list),
web.route('*', '/chatcomplete/conversation_chunk/info', ChatComplete.get_conversation_chunk),
web.route('POST', '/chatcomplete/conversation/fork', ChatComplete.fork_conversation),
web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete),
web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream),
web.route('POST', '/chatcomplete/get_point_cost', ChatComplete.get_point_cost),

@ -11,6 +11,7 @@ from service.mediawiki_api import MediaWikiApi
from api.model.base import BaseModel
from api.model.toolkit_ui.conversation import ConversationModel as _
from api.model.chat_complete.conversation import ConversationChunkModel as _
from api.model.chat_complete.bot_persona import BotPersonaModel as _
from api.model.embedding_search.title_collection import TitleCollectionModel as _
from api.model.embedding_search.title_index import TitleIndexModel as _

Loading…
Cancel
Save