完成对话fork功能
parent
197e3e1f1a
commit
003a9a7948
@ -0,0 +1,142 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from local import noawait
|
||||||
|
from typing import Optional, Callable, Union
|
||||||
|
from service.chat_complete import ChatCompleteService, ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse
|
||||||
|
from service.database import DatabaseService
|
||||||
|
from service.embedding_search import EmbeddingSearchArgs
|
||||||
|
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
|
||||||
|
from service.tiktoken import TikTokenService
|
||||||
|
import utils.web
|
||||||
|
|
||||||
|
chat_complete_tasks: dict[str, ChatCompleteTask] = {}
|
||||||
|
|
||||||
|
class ChatCompleteTask:
|
||||||
|
@staticmethod
|
||||||
|
def get_by_id(task_id: str) -> Union[ChatCompleteTask, None]:
|
||||||
|
return chat_complete_tasks.get(task_id)
|
||||||
|
|
||||||
|
def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
|
||||||
|
self.task_id = utils.web.generate_uuid()
|
||||||
|
self.on_message: list[Callable] = []
|
||||||
|
self.on_finished: list[Callable] = []
|
||||||
|
self.on_error: list[Callable] = []
|
||||||
|
self.chunks: list[str] = []
|
||||||
|
|
||||||
|
self.chat_complete_service: ChatCompleteService
|
||||||
|
self.chat_complete: ChatCompleteService
|
||||||
|
|
||||||
|
self.dbs = dbs
|
||||||
|
self.user_id = user_id
|
||||||
|
self.page_title = page_title
|
||||||
|
self.is_system = is_system
|
||||||
|
|
||||||
|
self.transatcion_id: Optional[str] = None
|
||||||
|
self.point_cost: int = 0
|
||||||
|
|
||||||
|
self.is_finished = False
|
||||||
|
self.finished_time: Optional[float] = None
|
||||||
|
self.result: Optional[ChatCompleteServiceResponse] = None
|
||||||
|
self.error: Optional[Exception] = None
|
||||||
|
|
||||||
|
async def init(self, question: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None,
|
||||||
|
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse:
|
||||||
|
self.tiktoken = await TikTokenService.create()
|
||||||
|
|
||||||
|
self.mwapi = MediaWikiApi.create()
|
||||||
|
|
||||||
|
self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
|
||||||
|
self.chat_complete = await self.chat_complete_service.__aenter__()
|
||||||
|
|
||||||
|
if await self.chat_complete.page_index_exists():
|
||||||
|
question_tokens = await self.tiktoken.get_tokens(question)
|
||||||
|
|
||||||
|
extract_limit = embedding_search["limit"] or 10
|
||||||
|
|
||||||
|
self.transatcion_id: Optional[str] = None
|
||||||
|
self.point_cost: int = 0
|
||||||
|
if not self.is_system:
|
||||||
|
usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete",
|
||||||
|
question_tokens, extract_limit)
|
||||||
|
self.transatcion_id = usage_res["transaction_id"]
|
||||||
|
self.point_cost = usage_res["point_cost"]
|
||||||
|
|
||||||
|
res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
|
||||||
|
user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
|
||||||
|
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
await self._exit()
|
||||||
|
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
|
||||||
|
|
||||||
|
async def _on_message(self, delta_message: str):
|
||||||
|
self.chunks.append(delta_message)
|
||||||
|
|
||||||
|
for callback in self.on_message:
|
||||||
|
try:
|
||||||
|
await callback(delta_message)
|
||||||
|
except Exception as e:
|
||||||
|
print("Error while processing on_message callback: %s" % e, file=sys.stderr)
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def _on_finished(self):
|
||||||
|
for callback in self.on_finished:
|
||||||
|
try:
|
||||||
|
await callback(self.result)
|
||||||
|
except Exception as e:
|
||||||
|
print("Error while processing on_finished callback: %s" % e, file=sys.stderr)
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def _on_error(self, err: Exception):
|
||||||
|
self.error = err
|
||||||
|
for callback in self.on_error:
|
||||||
|
try:
|
||||||
|
await callback(err)
|
||||||
|
except Exception as e:
|
||||||
|
print("Error while processing on_error callback: %s" % e, file=sys.stderr)
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def run(self) -> ChatCompleteServiceResponse:
|
||||||
|
chat_complete_tasks[self.task_id] = self
|
||||||
|
try:
|
||||||
|
chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
|
||||||
|
|
||||||
|
await self.chat_complete.set_latest_point_cost(self.point_cost)
|
||||||
|
|
||||||
|
self.result = chat_res
|
||||||
|
|
||||||
|
if self.transatcion_id:
|
||||||
|
await self.mwapi.ai_toolbox_end_transaction(self.transatcion_id, chat_res["total_tokens"])
|
||||||
|
|
||||||
|
await self._on_finished()
|
||||||
|
except Exception as e:
|
||||||
|
err_msg = f"Error while processing chat complete request: {e}"
|
||||||
|
|
||||||
|
print(err_msg, file=sys.stderr)
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
if self.transatcion_id:
|
||||||
|
await self.mwapi.ai_toolbox_cancel_transaction(self.transatcion_id, error=err_msg)
|
||||||
|
|
||||||
|
await self._on_error(e)
|
||||||
|
finally:
|
||||||
|
await self._exit()
|
||||||
|
|
||||||
|
async def _exit(self):
|
||||||
|
await self.chat_complete_service.__aexit__(None, None, None)
|
||||||
|
del chat_complete_tasks[self.task_id]
|
||||||
|
self.is_finished = True
|
||||||
|
self.finished_time = time.time()
|
||||||
|
|
||||||
|
TASK_EXPIRE_TIME = 60 * 10
|
||||||
|
|
||||||
|
async def chat_complete_task_gc():
|
||||||
|
now = time.time()
|
||||||
|
for task_id in chat_complete_tasks.keys():
|
||||||
|
task = chat_complete_tasks[task_id]
|
||||||
|
if task.is_finished and task.finished_time is not None and now > task.finished_time + TASK_EXPIRE_TIME:
|
||||||
|
del chat_complete_tasks[task_id]
|
||||||
|
|
||||||
|
noawait.add_timer(chat_complete_task_gc, 60)
|
@ -1,4 +1,37 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import TypeVar
|
||||||
|
import sqlalchemy
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from service.database import DatabaseService
|
||||||
|
|
||||||
class BaseModel(DeclarativeBase):
|
class BaseModel(DeclarativeBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class BaseHelper:
|
||||||
|
def __init__(self, dbs: DatabaseService):
|
||||||
|
self.dbs = dbs
|
||||||
|
self.initialized = False
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
if not self.initialized:
|
||||||
|
self.create_session = self.dbs.create_session
|
||||||
|
self.session = self.dbs.create_session()
|
||||||
|
await self.session.__aenter__()
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
await self.session.__aexit__(exc_type, exc, tb)
|
||||||
|
pass
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
def clone_model(model: T) -> T:
|
||||||
|
data_dict = {}
|
||||||
|
for c in sqlalchemy.inspect(model).mapper.column_attrs:
|
||||||
|
if c.key == "id":
|
||||||
|
continue
|
||||||
|
data_dict[c.key] = getattr(model, c.key)
|
||||||
|
return model.__class__(**data_dict)
|
@ -0,0 +1,48 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import sqlalchemy
|
||||||
|
from api.model.base import BaseHelper, BaseModel
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
from sqlalchemy.orm import mapped_column, relationship, Mapped
|
||||||
|
|
||||||
|
class BotPersonaModel(BaseModel):
|
||||||
|
__tablename__ = "chat_complete_bot_persona"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||||
|
bot_id: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
|
||||||
|
bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
|
||||||
|
bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
|
||||||
|
bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
|
||||||
|
system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
|
||||||
|
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON)
|
||||||
|
default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
|
||||||
|
|
||||||
|
class BotPersonaHelper(BaseHelper):
|
||||||
|
async def add(self, obj: BotPersonaModel):
|
||||||
|
self.session.add(obj)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(obj)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
async def update(self, obj: BotPersonaModel):
|
||||||
|
obj = await self.session.merge(obj)
|
||||||
|
await self.session.commit()
|
||||||
|
return obj
|
||||||
|
|
||||||
|
async def get_list(self):
|
||||||
|
stmt = select(BotPersonaModel).with_only_columns([
|
||||||
|
BotPersonaModel.id,
|
||||||
|
BotPersonaModel.bot_name,
|
||||||
|
BotPersonaModel.bot_avatar,
|
||||||
|
BotPersonaModel.bot_description
|
||||||
|
])
|
||||||
|
return await self.session.scalars(stmt)
|
||||||
|
|
||||||
|
async def find_by_id(self, id: int):
|
||||||
|
stmt = select(BotPersonaModel).where(BotPersonaModel.id == id)
|
||||||
|
return await self.session.scalar(stmt)
|
||||||
|
|
||||||
|
async def find_by_bot_id(self, bot_id: str):
|
||||||
|
stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id)
|
||||||
|
return await self.session.scalar(stmt)
|
Loading…
Reference in New Issue