完成对话fork功能
parent
197e3e1f1a
commit
003a9a7948
@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from local import noawait
|
||||
from typing import Optional, Callable, Union
|
||||
from service.chat_complete import ChatCompleteService, ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse
|
||||
from service.database import DatabaseService
|
||||
from service.embedding_search import EmbeddingSearchArgs
|
||||
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
|
||||
from service.tiktoken import TikTokenService
|
||||
import utils.web
|
||||
|
||||
chat_complete_tasks: dict[str, ChatCompleteTask] = {}
|
||||
|
||||
class ChatCompleteTask:
|
||||
@staticmethod
|
||||
def get_by_id(task_id: str) -> Union[ChatCompleteTask, None]:
|
||||
return chat_complete_tasks.get(task_id)
|
||||
|
||||
def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
|
||||
self.task_id = utils.web.generate_uuid()
|
||||
self.on_message: list[Callable] = []
|
||||
self.on_finished: list[Callable] = []
|
||||
self.on_error: list[Callable] = []
|
||||
self.chunks: list[str] = []
|
||||
|
||||
self.chat_complete_service: ChatCompleteService
|
||||
self.chat_complete: ChatCompleteService
|
||||
|
||||
self.dbs = dbs
|
||||
self.user_id = user_id
|
||||
self.page_title = page_title
|
||||
self.is_system = is_system
|
||||
|
||||
self.transatcion_id: Optional[str] = None
|
||||
self.point_cost: int = 0
|
||||
|
||||
self.is_finished = False
|
||||
self.finished_time: Optional[float] = None
|
||||
self.result: Optional[ChatCompleteServiceResponse] = None
|
||||
self.error: Optional[Exception] = None
|
||||
|
||||
async def init(self, question: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None,
|
||||
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse:
|
||||
self.tiktoken = await TikTokenService.create()
|
||||
|
||||
self.mwapi = MediaWikiApi.create()
|
||||
|
||||
self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
|
||||
self.chat_complete = await self.chat_complete_service.__aenter__()
|
||||
|
||||
if await self.chat_complete.page_index_exists():
|
||||
question_tokens = await self.tiktoken.get_tokens(question)
|
||||
|
||||
extract_limit = embedding_search["limit"] or 10
|
||||
|
||||
self.transatcion_id: Optional[str] = None
|
||||
self.point_cost: int = 0
|
||||
if not self.is_system:
|
||||
usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete",
|
||||
question_tokens, extract_limit)
|
||||
self.transatcion_id = usage_res["transaction_id"]
|
||||
self.point_cost = usage_res["point_cost"]
|
||||
|
||||
res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
|
||||
user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
|
||||
|
||||
return res
|
||||
else:
|
||||
await self._exit()
|
||||
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
|
||||
|
||||
async def _on_message(self, delta_message: str):
|
||||
self.chunks.append(delta_message)
|
||||
|
||||
for callback in self.on_message:
|
||||
try:
|
||||
await callback(delta_message)
|
||||
except Exception as e:
|
||||
print("Error while processing on_message callback: %s" % e, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
async def _on_finished(self):
|
||||
for callback in self.on_finished:
|
||||
try:
|
||||
await callback(self.result)
|
||||
except Exception as e:
|
||||
print("Error while processing on_finished callback: %s" % e, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
async def _on_error(self, err: Exception):
|
||||
self.error = err
|
||||
for callback in self.on_error:
|
||||
try:
|
||||
await callback(err)
|
||||
except Exception as e:
|
||||
print("Error while processing on_error callback: %s" % e, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
async def run(self) -> ChatCompleteServiceResponse:
|
||||
chat_complete_tasks[self.task_id] = self
|
||||
try:
|
||||
chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
|
||||
|
||||
await self.chat_complete.set_latest_point_cost(self.point_cost)
|
||||
|
||||
self.result = chat_res
|
||||
|
||||
if self.transatcion_id:
|
||||
await self.mwapi.ai_toolbox_end_transaction(self.transatcion_id, chat_res["total_tokens"])
|
||||
|
||||
await self._on_finished()
|
||||
except Exception as e:
|
||||
err_msg = f"Error while processing chat complete request: {e}"
|
||||
|
||||
print(err_msg, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
if self.transatcion_id:
|
||||
await self.mwapi.ai_toolbox_cancel_transaction(self.transatcion_id, error=err_msg)
|
||||
|
||||
await self._on_error(e)
|
||||
finally:
|
||||
await self._exit()
|
||||
|
||||
async def _exit(self):
|
||||
await self.chat_complete_service.__aexit__(None, None, None)
|
||||
del chat_complete_tasks[self.task_id]
|
||||
self.is_finished = True
|
||||
self.finished_time = time.time()
|
||||
|
||||
TASK_EXPIRE_TIME = 60 * 10
|
||||
|
||||
async def chat_complete_task_gc():
|
||||
now = time.time()
|
||||
for task_id in chat_complete_tasks.keys():
|
||||
task = chat_complete_tasks[task_id]
|
||||
if task.is_finished and task.finished_time is not None and now > task.finished_time + TASK_EXPIRE_TIME:
|
||||
del chat_complete_tasks[task_id]
|
||||
|
||||
noawait.add_timer(chat_complete_task_gc, 60)
|
@ -1,4 +1,37 @@
|
||||
from __future__ import annotations
|
||||
from typing import TypeVar
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from service.database import DatabaseService
|
||||
|
||||
class BaseModel(DeclarativeBase):
|
||||
pass
|
||||
pass
|
||||
|
||||
class BaseHelper:
|
||||
def __init__(self, dbs: DatabaseService):
|
||||
self.dbs = dbs
|
||||
self.initialized = False
|
||||
|
||||
async def __aenter__(self):
|
||||
if not self.initialized:
|
||||
self.create_session = self.dbs.create_session
|
||||
self.session = self.dbs.create_session()
|
||||
await self.session.__aenter__()
|
||||
self.initialized = True
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.session.__aexit__(exc_type, exc, tb)
|
||||
pass
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
def clone_model(model: T) -> T:
|
||||
data_dict = {}
|
||||
for c in sqlalchemy.inspect(model).mapper.column_attrs:
|
||||
if c.key == "id":
|
||||
continue
|
||||
data_dict[c.key] = getattr(model, c.key)
|
||||
return model.__class__(**data_dict)
|
@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
import sqlalchemy
|
||||
from api.model.base import BaseHelper, BaseModel
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import mapped_column, relationship, Mapped
|
||||
|
||||
class BotPersonaModel(BaseModel):
|
||||
__tablename__ = "chat_complete_bot_persona"
|
||||
|
||||
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
bot_id: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
|
||||
bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
|
||||
bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
|
||||
bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
|
||||
system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
|
||||
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON)
|
||||
default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
|
||||
|
||||
class BotPersonaHelper(BaseHelper):
|
||||
async def add(self, obj: BotPersonaModel):
|
||||
self.session.add(obj)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(obj)
|
||||
return obj
|
||||
|
||||
async def update(self, obj: BotPersonaModel):
|
||||
obj = await self.session.merge(obj)
|
||||
await self.session.commit()
|
||||
return obj
|
||||
|
||||
async def get_list(self):
|
||||
stmt = select(BotPersonaModel).with_only_columns([
|
||||
BotPersonaModel.id,
|
||||
BotPersonaModel.bot_name,
|
||||
BotPersonaModel.bot_avatar,
|
||||
BotPersonaModel.bot_description
|
||||
])
|
||||
return await self.session.scalars(stmt)
|
||||
|
||||
async def find_by_id(self, id: int):
|
||||
stmt = select(BotPersonaModel).where(BotPersonaModel.id == id)
|
||||
return await self.session.scalar(stmt)
|
||||
|
||||
async def find_by_bot_id(self, bot_id: str):
|
||||
stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id)
|
||||
return await self.session.scalar(stmt)
|
Loading…
Reference in New Issue