You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
3.6 KiB
Python
87 lines
3.6 KiB
Python
from __future__ import annotations

import datetime
from typing import Optional

import sqlalchemy
from sqlalchemy import update
from sqlalchemy.orm import mapped_column, relationship, Mapped

from api.model.base import BaseModel
from api.model.toolkit_ui.conversation import ConversationModel
from service.database import DatabaseService
from service.event import EventService
|
|
|
|
class ConversationChunkModel(BaseModel):
    """ORM row holding one chunk of a chat-complete conversation's messages.

    A conversation may own several chunks (rows sharing ``conversation_id``);
    ``ConversationChunkHelper.get_newest_chunk`` treats the one with the
    highest ``id`` as the newest.
    """

    __tablename__ = "chat_complete_conversation_chunk"

    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    # Owning conversation; indexed because chunks are queried and deleted by it.
    conversation_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(ConversationModel.id), index=True)
    # JSON-encoded message list; the column is nullable, so the annotation is Optional.
    message_data: Mapped[Optional[list]] = mapped_column(sqlalchemy.JSON, nullable=True)
    # Caller-supplied token count for this chunk (0 when not given).
    tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0)
    # TIMESTAMP columns load as datetime.datetime, not int.
    updated_at: Mapped[datetime.datetime] = mapped_column(sqlalchemy.TIMESTAMP, index=True)
|
|
|
|
class ConversationChunkHelper:
    """Async data-access helper for ``ConversationChunkModel`` rows.

    Use it as ``async with ConversationChunkHelper(dbs) as helper``. Entering
    opens a session from ``dbs.create_session()``, and ``update``,
    ``update_message_log``, ``get_newest_chunk``, ``remove`` and
    ``remove_by_conversation_id`` run on that session. Exiting closes it.
    ``add`` opens its own short-lived session instead.
    """

    def __init__(self, dbs: DatabaseService):
        self.dbs = dbs
        # True while a session opened by __aenter__ is active.
        self.initialized = False

    async def __aenter__(self):
        if not self.initialized:
            # Also exposed on the helper; add() reads dbs.create_session directly.
            self.create_session = self.dbs.create_session
            self.session = self.dbs.create_session()
            await self.session.__aenter__()
            self.initialized = True

        return self

    async def __aexit__(self, exc_type, exc, tb):
        try:
            await self.session.__aexit__(exc_type, exc, tb)
        finally:
            # The session is closed now; a later ``async with`` must open a
            # fresh one rather than reuse the closed session.
            self.initialized = False

    async def add(self, conversation_id: int, message_data: list, tokens: int) -> ConversationChunkModel:
        """Insert a new chunk for ``conversation_id`` and return it.

        Runs in its own session, not the one opened by ``async with``, so it
        also works on a helper that has not been entered. The object is
        refreshed before that session closes, so ``id`` and ``updated_at``
        (both filled in by the database) are loaded on the returned chunk.
        """
        async with self.dbs.create_session() as session:
            chunk = ConversationChunkModel(
                conversation_id=conversation_id,
                message_data=message_data,
                tokens=tokens,
                updated_at=sqlalchemy.func.current_timestamp(),
            )
            session.add(chunk)
            await session.commit()
            await session.refresh(chunk)
            return chunk

    async def update(self, chunk: ConversationChunkModel) -> ConversationChunkModel:
        """Merge ``chunk`` into the helper's session, bump ``updated_at`` and commit.

        Returns the session-bound (merged) instance, refreshed from the database.
        """
        chunk.updated_at = sqlalchemy.func.current_timestamp()
        chunk = await self.session.merge(chunk)
        await self.session.commit()
        # updated_at was set to a SQL expression, so the ORM expires it after
        # the flush. Reload it now because AsyncSession cannot lazy-load
        # expired attributes on access.
        await self.session.refresh(chunk)
        return chunk

    async def update_message_log(self, chunk_id: int, message_data: list, tokens: int) -> None:
        """Overwrite ``message_data``/``tokens`` of chunk ``chunk_id`` and bump ``updated_at``.

        Issues a single UPDATE statement. An unknown ``chunk_id`` matches no
        rows, and no error is raised.
        """
        stmt = (
            sqlalchemy.update(ConversationChunkModel)
            .where(ConversationChunkModel.id == chunk_id)
            .values(
                message_data=message_data,
                tokens=tokens,
                updated_at=sqlalchemy.func.current_timestamp(),
            )
        )
        await self.session.execute(stmt)
        await self.session.commit()

    async def get_newest_chunk(self, conversation_id: int) -> ConversationChunkModel | None:
        """Return the chunk with the highest ``id`` for ``conversation_id``, or ``None`` if there is none."""
        stmt = (
            sqlalchemy.select(ConversationChunkModel)
            .where(ConversationChunkModel.conversation_id == conversation_id)
            .order_by(ConversationChunkModel.id.desc())
            .limit(1)
        )
        return await self.session.scalar(stmt)

    async def remove(self, id: int) -> None:
        """Delete the chunk whose primary key is ``id``. Does nothing if it does not exist."""
        # NOTE: the parameter shadows the builtin ``id``; the name is kept for keyword callers.
        stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id == id)
        await self.session.execute(stmt)
        await self.session.commit()

    async def remove_by_conversation_id(self, conversation_id: int) -> None:
        """Delete every chunk that belongs to ``conversation_id``."""
        stmt = sqlalchemy.delete(ConversationChunkModel).where(
            ConversationChunkModel.conversation_id == conversation_id
        )
        await self.session.execute(stmt)
        await self.session.commit()
|
|
|
|
async def on_conversation_removed(event):
    """Delete all chunks of a conversation after it has been removed.

    ``event`` is indexed like a mapping. Events without a ``"conversation"``
    entry are ignored. Otherwise ``event["conversation"]["id"]`` selects the
    conversation, and ``event["dbs"]`` supplies the database service.
    """
    if "conversation" not in event:
        return

    conversation_id = event["conversation"]["id"]
    # The helper's delete methods use the session opened by __aenter__, so the
    # helper must be entered; calling them on a bare instance raises AttributeError.
    async with ConversationChunkHelper(event["dbs"]) as chunk_helper:
        await chunk_helper.remove_by_conversation_id(conversation_id)
|
|
|
|
# Purge a chat-complete conversation's chunks whenever that conversation is removed.
EventService.create().add_listener(
    "conversation/removed/chatcomplete",
    on_conversation_removed,
)