修正gitignore排除目录问题

master
落雨楓 7 months ago
parent e55c25b9e3
commit 071ab94829

@ -5,12 +5,11 @@ import init.service.mw_api as _ # Init mediawiki api
import init.service.database as _ # Init database import init.service.database as _ # Init database
import init.service.tiktoken as _ # Init tiktoken import init.service.tiktoken as _ # Init tiktoken
import toolbox_ui as _ # Init toolbox ui import init.server.toolbox_ui as _ # Init toolbox ui
import embedding_search as _ # Init embedding search import init.server.embedding_search as _ # Init embedding search
# Auto create database tables # Auto create database tables
from server.model.chat_complete.conversation import ConversationChunkModel as _ from server.model.chat_complete.conversation import ConversationChunkModel as _
from server.model.chat_complete.bot_persona_category import BotPersonaCategoryModel as _
from server.model.chat_complete.bot_persona import BotPersonaModel as _ from server.model.chat_complete.bot_persona import BotPersonaModel as _
# Route # Route

@ -2,7 +2,7 @@ from __future__ import annotations
from aiohttp import web from aiohttp import web
from utils.server import register_server_module from utils.server import register_server_module
from lib.config import Config from libs.config import Config
from service.mediawiki_api import MediaWikiApi from service.mediawiki_api import MediaWikiApi

@ -9,7 +9,7 @@ async def init_tiktoken(app: web.Application):
print("Tiktoken model loaded.") print("Tiktoken model loaded.")
async def init(app: web.Application): def init(app: web.Application):
app.on_startup.append(init_tiktoken) app.on_startup.append(init_tiktoken)

@ -0,0 +1,95 @@
from __future__ import annotations
import os
import os.path
import toml
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
config_observer = Observer()
class Config:
file_name: str
values: dict = {}
@staticmethod
def load_config(file):
global config_observer
Config.file_name = file
Config.reload()
if config_observer.is_alive():
config_observer.stop()
config_observer.join()
config_observer.schedule(Config.ConfigEventHandler(), file)
config_observer.start()
@staticmethod
def reload():
with open(Config.file_name, "r", encoding="utf-8") as f:
Config.values = toml.load(f)
@staticmethod
def get(key: str, default=None, type=None, empty_is_none=False):
key_path = key.split(".")
value = Config.values
for k in key_path:
if k in value:
value = value[k]
else:
return default
if empty_is_none and value == "":
return None
if type == bool:
if isinstance(value, bool):
return value
elif isinstance(value, int) or isinstance(value, float):
return value != 0
else:
return str(value).lower() in ("yes", "true", "1")
elif type == int:
return int(value)
elif type == float:
return float(value)
elif type == str:
return str(value)
elif type == list:
if not isinstance(value, list):
return []
elif type == dict:
if not isinstance(value, dict):
return {}
else:
return value
@staticmethod
def set(key: str, value):
key_path = key.split(".")
obj = Config.values
for k in key_path[:-1]:
if k not in obj:
obj[k] = {}
obj = obj[k]
obj[key_path[-1]] = value
@staticmethod
def unload():
global config_observer
if config_observer.is_alive():
config_observer.stop()
config_observer.join()
class ConfigEventHandler(FileSystemEventHandler):
def on_modified(self, event):
if os.path.basename(event.src_path) == os.path.basename(Config.file_name):
print("Config file changed, reloading...")
Config.reload()

@ -0,0 +1,165 @@
from __future__ import annotations
from asyncio import AbstractEventLoop, Task
import asyncio
import atexit
from functools import wraps
import random
import sys
import traceback
from typing import Callable, Coroutine, Optional, TypedDict
class TimerInfo(TypedDict):
id: int
callback: Callable
interval: float
next_time: float
class NoAwaitPool:
def __init__(self, loop: AbstractEventLoop):
self.task_list: list[Task] = []
self.timer_map: dict[int, TimerInfo] = {}
self.loop = loop
self.running = True
self.should_refresh_task = False
self.next_timer_time: Optional[float] = None
self.on_error: list[Callable] = []
self.gc_task = loop.create_task(self._run_gc())
self.timer_task = loop.create_task(self._run_timer())
atexit.register(self.end_task)
async def end(self):
if self.running:
print("Stopping NoAwait Tasks...")
self.running = False
for task in self.task_list:
await self._finish_task(task)
await self.gc_task
await self.timer_task
def end_task(self):
if self.running and not self.loop.is_closed():
self.loop.run_until_complete(self.end())
async def _wrap_task(self, task: Task):
try:
await task
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
await handler_ret
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
finally:
self.should_refresh_task = True
def add_task(self, coroutine: Coroutine):
task = self.loop.create_task(coroutine)
self.task_list.append(task)
def add_timer(self, callback: Callable, interval: float) -> int:
id = random.randint(0, 1000000000)
while id in self.timer_map:
id = random.randint(0, 1000000000)
now = self.loop.time()
next_time = now + interval
self.timer_map[id] = {
"id": id,
"callback": callback,
"interval": interval,
"next_time": next_time
}
if self.next_timer_time is None or next_time < self.next_timer_time:
self.next_timer_time = next_time
return id
def remove_timer(self, id: int):
if id in self.timer_map:
del self.timer_map[id]
def wrap(self, f):
@wraps(f)
def decorated_function(*args, **kwargs):
coroutine = f(*args, **kwargs)
self.add_task(coroutine)
return decorated_function
async def _finish_task(self, task: Task):
try:
if not task.done():
task.cancel()
await task
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
await handler_ret
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
async def _run_gc(self):
while self.running:
if self.should_refresh_task:
should_remove = []
for task in self.task_list:
if task.done():
await self._finish_task(task)
should_remove.append(task)
for task in should_remove:
self.task_list.remove(task)
await asyncio.sleep(0.1)
async def _run_timer(self):
while self.running:
now = self.loop.time()
if self.next_timer_time is not None and now >= self.next_timer_time:
self.next_timer_time = None
for timer in self.timer_map.values():
if now >= timer["next_time"]:
timer["next_time"] = now + timer["interval"]
try:
result = timer["callback"]()
self.add_task(result)
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
self.add_task(handler_ret)
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
if self.next_timer_time is None or timer["next_time"] < self.next_timer_time:
self.next_timer_time = timer["next_time"]
await asyncio.sleep(0.1)

@ -1,6 +1,6 @@
import sys import sys
import traceback import traceback
from lib.config import Config from libs.config import Config
from utils.server import get_server_modules from utils.server import get_server_modules
Config.load_config("config.toml") Config.load_config("config.toml")
@ -8,9 +8,7 @@ Config.load_config("config.toml")
from utils.local import loop, noawait from utils.local import loop, noawait
from aiohttp import web from aiohttp import web
import utils.local as local import utils.local as local
import server.route as api_route
import utils.web import utils.web
from service.mediawiki_api import MediaWikiApi
async def index(request: web.Request): async def index(request: web.Request):
@ -53,8 +51,9 @@ async def error_handler(request, handler):
) )
async def stop_noawait_pool(app: web.Application): async def on_shutdown(app: web.Application):
await noawait.end() await noawait.end()
Config.unload()
if __name__ == "__main__": if __name__ == "__main__":
local.debug = Config.get("server.debug", False, bool) local.debug = Config.get("server.debug", False, bool)
@ -79,15 +78,14 @@ if __name__ == "__main__":
# Initialize server modules # Initialize server modules
server_modules = get_server_modules() server_modules = get_server_modules()
print(server_modules)
for server_module in server_modules: for server_module in server_modules:
print("Loading server module: " + server_module["name"])
server_module["init"](app) server_module["init"](app)
app.on_shutdown.append(stop_noawait_pool) app.on_shutdown.append(on_shutdown)
app.router.add_route("*", "/", index) app.router.add_route("*", "/", index)
api_route.init(app)
server_port = Config.get("port", 8144, int) server_port = Config.get("port", 8144, int)

@ -1,4 +1,9 @@
import sys import sys
import pathlib import pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent)) root_path = pathlib.Path(__file__).parent.parent
sys.path.append(".")
from libs.config import Config
Config.load_config(str(root_path) + "/config.toml")

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import time import time
import traceback import traceback
from libs.config import Config
from server.controller.task.ChatCompleteTask import ChatCompleteTask from server.controller.task.ChatCompleteTask import ChatCompleteTask
from server.model.base import clone_model from server.model.base import clone_model
from server.model.chat_complete.bot_persona import BotPersonaHelper from server.model.chat_complete.bot_persona import BotPersonaHelper
@ -9,9 +10,9 @@ from server.model.toolbox_ui.conversation import ConversationHelper
from utils.local import noawait from utils.local import noawait
from aiohttp import web from aiohttp import web
from server.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel from server.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse, calculate_point_usage
from service.database import DatabaseService from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException from service.mediawiki_api import MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
import utils.web import utils.web
@ -63,6 +64,7 @@ class ChatComplete:
return await utils.web.api_response(1, conversation_chunk_list, request=request) return await utils.web.api_response(1, conversation_chunk_list, request=request)
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth
async def get_conversation_chunk(request: web.Request): async def get_conversation_chunk(request: web.Request):
@ -111,6 +113,7 @@ class ChatComplete:
return await utils.web.api_response(1, chunk_dict, request=request) return await utils.web.api_response(1, chunk_dict, request=request)
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth
async def fork_conversation(request: web.Request): async def fork_conversation(request: web.Request):
@ -223,6 +226,7 @@ class ChatComplete:
"conversation_id": new_conversation.id, "conversation_id": new_conversation.id,
}, request=request) }, request=request)
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth
async def get_tokens(request: web.Request): async def get_tokens(request: web.Request):
@ -240,14 +244,19 @@ class ChatComplete:
return await utils.web.api_response(1, {"tokens": tokens}, request=request) return await utils.web.api_response(1, {"tokens": tokens}, request=request)
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth
async def get_point_cost(request: web.Request): async def get_point_usage(request: web.Request):
params = await utils.web.get_param(request, { params = await utils.web.get_param(request, {
"question": { "question": {
"type": str, "type": str,
"required": True, "required": True,
}, },
"bot_id": {
"type": str,
"required": True,
},
"extract_limit": { "extract_limit": {
"type": int, "type": int,
"required": False, "required": False,
@ -256,39 +265,41 @@ class ChatComplete:
}) })
user_id = request.get("user") user_id = request.get("user")
caller = request.get("caller")
question = params.get("question") question = params.get("question")
extract_limit = params.get("extract_limit") bot_id = params.get("bot_id")
tiktoken = await TikTokenService.create() tiktoken = await TikTokenService.create()
mwapi = MediaWikiApi.create() db = await DatabaseService.create(request.app)
tokens = await tiktoken.get_tokens(question) tokens = await tiktoken.get_tokens(question)
try: estimated_extract_tokens_per_doc = Config.get("estimated_extract_tokens_per_doc", 50, int)
res = await mwapi.ai_toolbox_get_point_cost(user_id, "chatcomplete", tokens, extract_limit)
predict_tokens = tokens + estimated_extract_tokens_per_doc
async with BotPersonaHelper(db) as bot_persona_helper:
persona_info = await bot_persona_helper.find_by_bot_id(bot_id)
if persona_info is None:
return await utils.web.api_response(-1, error={
"code": "bot-not-found",
"message": "Bot not found."
}, http_status=404, request=request)
point_usage = calculate_point_usage(predict_tokens, persona_info.cost_fixed, persona_info.cost_fixed_tokens, persona_info.cost_per_token)
return await utils.web.api_response(1, { return await utils.web.api_response(1, {
"point_cost": res["point_cost"], "point_usage": point_usage,
"tokens": tokens, "tokens": tokens,
"predict_tokens": predict_tokens,
}, request=request) }, request=request)
except Exception as e:
err_msg = f"Error while get chat complete point cost: {e}"
traceback.print_exc()
return await utils.web.api_response(-1, error={
"code": "internal-server-error",
"message": err_msg
}, http_status=500, request=request)
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth
async def get_persona_list(request: web.Request): async def get_persona_list(request: web.Request):
params = await utils.web.get_param(request, { params = await utils.web.get_param(request, {
"category_id": {
"type": int,
"required": False,
},
"page": { "page": {
"type": int, "type": int,
"required": False, "required": False,
@ -296,13 +307,12 @@ class ChatComplete:
} }
}) })
category_id = params.get("category_id")
page = params.get("page") page = params.get("page")
db = await DatabaseService.create(request.app) db = await DatabaseService.create(request.app)
async with BotPersonaHelper(db) as bot_persona_helper: async with BotPersonaHelper(db) as bot_persona_helper:
persona_list = await bot_persona_helper.get_list(page=page, category_id=category_id) persona_list = await bot_persona_helper.get_list(page=page)
page_count = await bot_persona_helper.get_page_count(category_id=category_id) page_count = await bot_persona_helper.get_page_count()
persona_data_list = [] persona_data_list = []
for persona in persona_list: for persona in persona_list:
@ -312,7 +322,9 @@ class ChatComplete:
"bot_name": persona.bot_name, "bot_name": persona.bot_name,
"bot_avatar": persona.bot_avatar, "bot_avatar": persona.bot_avatar,
"bot_description": persona.bot_description, "bot_description": persona.bot_description,
"updated_at": persona.updated_at, "model_id": persona.model_id,
"model_name": persona.model_name,
"cost_fixed": persona.cost_fixed
}) })
return await utils.web.api_response(1, { return await utils.web.api_response(1, {
@ -320,6 +332,7 @@ class ChatComplete:
"page_count": page_count, "page_count": page_count,
}, request=request) }, request=request)
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth
async def get_persona_info(request: web.Request): async def get_persona_info(request: web.Request):
@ -354,6 +367,7 @@ class ChatComplete:
return await utils.web.api_response(1, persona_info_res, request=request) return await utils.web.api_response(1, persona_info_res, request=request)
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth
async def start_chat_complete(request: web.Request): async def start_chat_complete(request: web.Request):
@ -372,7 +386,7 @@ class ChatComplete:
}, },
"bot_id": { "bot_id": {
"type": str, "type": str,
"required": False, "required": True,
}, },
"extract_limit": { "extract_limit": {
"type": int, "type": int,
@ -453,6 +467,7 @@ class ChatComplete:
"message": err_msg "message": err_msg
}, http_status=500, request=request) }, http_status=500, request=request)
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth
async def chat_complete_stream(request: web.Request): async def chat_complete_stream(request: web.Request):
@ -541,7 +556,7 @@ class ChatComplete:
try: try:
ignored_keys = ["message"] ignored_keys = ["message"]
response_result = { response_result = {
"point_cost": task.point_cost, "point_usage": task.point_usage,
} }
for k, v in result.items(): for k, v in result.items():
if k not in ignored_keys: if k not in ignored_keys:

@ -1,6 +1,7 @@
import sys import sys
import traceback import traceback
from aiohttp import web from aiohttp import web
from libs.config import Config
from server.model.embedding_search.title_collection import TitleCollectionHelper from server.model.embedding_search.title_collection import TitleCollectionHelper
from server.model.embedding_search.title_index import TitleIndexHelper from server.model.embedding_search.title_index import TitleIndexHelper
from service.database import DatabaseService from service.database import DatabaseService
@ -53,10 +54,10 @@ class EmbeddingSearch:
page_current += 1 page_current += 1
async with EmbeddingSearchService(db, one_title) as embedding_search: async with EmbeddingSearchService(db, one_title) as embedding_search:
if await embedding_search.should_update_page_index(): if await embedding_search.should_update_page_index():
if request.get("caller") == "user": # if request.get("caller") == "user":
user_id = request.get("user") # user_id = request.get("user")
usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage") # usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
transatcion_id = usage_res.get("transaction_id") # transatcion_id = usage_res.get("transaction_id")
await embedding_search.prepare_update_index() await embedding_search.prepare_update_index()
@ -144,10 +145,10 @@ class EmbeddingSearch:
async with EmbeddingSearchService(db, page_title) as embedding_search: async with EmbeddingSearchService(db, page_title) as embedding_search:
if await embedding_search.should_update_page_index(): if await embedding_search.should_update_page_index():
if request.get("caller") == "user": # if request.get("caller") == "user":
user_id = request.get("user") # user_id = request.get("user")
usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage") # usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
transatcion_id = usage_res.get("transaction_id") # transatcion_id = usage_res.get("transaction_id")
await embedding_search.prepare_update_index() await embedding_search.prepare_update_index()
@ -236,7 +237,7 @@ class EmbeddingSearch:
"distancelimit": { "distancelimit": {
"required": False, "required": False,
"type": float, "type": float,
"default": 0.6 "default": None
}, },
}) })

@ -2,12 +2,15 @@ from __future__ import annotations
import sys import sys
import time import time
import traceback import traceback
from libs.config import Config
from server.model.chat_complete.bot_persona import BotPersonaHelper
from utils.local import noawait from utils.local import noawait
from typing import Optional, Callable, Union from typing import Optional, Callable, Union
from service.chat_complete import ( from service.chat_complete import (
ChatCompleteService, ChatCompleteService,
ChatCompleteServicePrepareResponse, ChatCompleteServicePrepareResponse,
ChatCompleteServiceResponse, ChatCompleteServiceResponse,
calculate_point_usage,
) )
from service.database import DatabaseService from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs from service.embedding_search import EmbeddingSearchArgs
@ -41,7 +44,7 @@ class ChatCompleteTask:
self.is_system = is_system self.is_system = is_system
self.transatcion_id: Optional[str] = None self.transatcion_id: Optional[str] = None
self.point_cost: int = 0 self.point_usage: int = 0
self.is_finished = False self.is_finished = False
self.finished_time: Optional[float] = None self.finished_time: Optional[float] = None
@ -51,9 +54,9 @@ class ChatCompleteTask:
async def init( async def init(
self, self,
question: str, question: str,
bot_id: str,
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
edit_message_id: Optional[str] = None, edit_message_id: Optional[str] = None,
bot_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None, embedding_search: Optional[EmbeddingSearchArgs] = None,
) -> ChatCompleteServicePrepareResponse: ) -> ChatCompleteServicePrepareResponse:
self.tiktoken = await TikTokenService.create() self.tiktoken = await TikTokenService.create()
@ -69,19 +72,31 @@ class ChatCompleteTask:
extract_limit = embedding_search["limit"] or 10 extract_limit = embedding_search["limit"] or 10
estimated_extract_tokens_per_doc = Config.get("estimated_extract_tokens_per_doc", 50, int)
predict_tokens = extract_limit + estimated_extract_tokens_per_doc
async with BotPersonaHelper(self.dbs) as bot_persona_helper:
bot_persona = await bot_persona_helper.find_by_bot_id(bot_id)
self.point_usage: int = calculate_point_usage(predict_tokens, bot_persona.cost_fixed,
bot_persona.cost_fixed_tokens, bot_persona.cost_per_token)
self.transatcion_id: Optional[str] = None self.transatcion_id: Optional[str] = None
self.point_cost: int = 0
if not self.is_system: if not self.is_system:
usage_res = await self.mwapi.ai_toolbox_start_transaction( usage_res = await self.mwapi.ai_toolbox_start_transaction(
self.user_id, "chatcomplete", question_tokens, extract_limit self.user_id, "chatcomplete",
bot_id=bot_id,
tokens=predict_tokens,
point_usage=self.point_usage
) )
self.transatcion_id = usage_res["transaction_id"] self.transatcion_id = usage_res["transaction_id"]
self.point_cost = usage_res["point_cost"]
res = await self.chat_complete.prepare_chat_complete( res = await self.chat_complete.prepare_chat_complete(
question, question,
conversation_id=conversation_id, conversation_id=conversation_id,
user_id=self.user_id, user_id=self.user_id,
question_tokens=question_tokens,
edit_message_id=edit_message_id, edit_message_id=edit_message_id,
bot_id=bot_id, bot_id=bot_id,
embedding_search=embedding_search, embedding_search=embedding_search,
@ -136,13 +151,13 @@ class ChatCompleteTask:
try: try:
chat_res = await self.chat_complete.finish_chat_complete(self._on_message) chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
await self.chat_complete.set_latest_point_cost(self.point_cost) await self.chat_complete.set_latest_point_usage(self.point_usage)
self.result = chat_res self.result = chat_res
if self.transatcion_id: if self.transatcion_id:
await self.mwapi.ai_toolbox_end_transaction( await self.mwapi.ai_toolbox_end_transaction(
self.transatcion_id, chat_res["total_tokens"] self.transatcion_id, self.point_usage # TODO: 根据实际使用Tokens扣除积分
) )
await self._on_finished() await self._on_finished()

@ -5,10 +5,9 @@ import sqlalchemy
from server.model.base import BaseHelper, BaseModel from server.model.base import BaseHelper, BaseModel
import sqlalchemy import sqlalchemy
from sqlalchemy import select, update from sqlalchemy import select
from sqlalchemy.orm import mapped_column, relationship, load_only, Mapped from sqlalchemy.orm import mapped_column, load_only, Mapped
from server.model.chat_complete.bot_persona_category import BotPersonaCategoryModel
from service.database import DatabaseService from service.database import DatabaseService
@ -22,17 +21,16 @@ class BotPersonaModel(BaseModel):
bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True) bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
category_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey(
BotPersonaCategoryModel.id, ondelete="CASCADE", onupdate="CASCADE"
),
index=True,
)
api_id: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) api_id: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
model_id: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True, index=True)
model_name: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
system_prompt: Mapped[str] = mapped_column(sqlalchemy.String) system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON) message_log: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True)
default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
updated_at: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) cost_fixed: Mapped[float] = mapped_column(sqlalchemy.Float, nullable=True)
cost_fixed_tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
cost_per_token: Mapped[float] = mapped_column(sqlalchemy.Float, nullable=True)
order: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True, nullable=True)
class BotPersonaHelper(BaseHelper): class BotPersonaHelper(BaseHelper):
@ -50,8 +48,7 @@ class BotPersonaHelper(BaseHelper):
async def get_list( async def get_list(
self, self,
page: Optional[int] = 1, page: Optional[int] = 1,
page_size: Optional[int] = 20, page_size: Optional[int] = 20
category_id: Optional[int] = None,
): ):
offset_index = (page - 1) * page_size offset_index = (page - 1) * page_size
@ -64,25 +61,22 @@ class BotPersonaHelper(BaseHelper):
BotPersonaModel.bot_name, BotPersonaModel.bot_name,
BotPersonaModel.bot_avatar, BotPersonaModel.bot_avatar,
BotPersonaModel.bot_description, BotPersonaModel.bot_description,
BotPersonaModel.updated_at, BotPersonaModel.model_id,
BotPersonaModel.model_name,
BotPersonaModel.cost_fixed,
BotPersonaModel.order,
) )
) )
.order_by(BotPersonaModel.updated_at.desc()) .order_by(BotPersonaModel.order.desc())
.offset(offset_index) .offset(offset_index)
.limit(page_size) .limit(page_size)
) )
if category_id is not None:
stmt = stmt.where(BotPersonaModel.category_id == category_id)
return await self.session.scalars(stmt) return await self.session.scalars(stmt)
async def get_page_count(self, page_size=50, category_id: Optional[int] = None): async def get_page_count(self, page_size=50):
stmt = select(sqlalchemy.func.count()).select_from(BotPersonaModel) stmt = select(sqlalchemy.func.count()).select_from(BotPersonaModel)
if category_id is not None:
stmt = stmt.where(BotPersonaModel.category_id == category_id)
item_count = await self.session.scalar(stmt) item_count = await self.session.scalar(stmt)
if item_count is None: if item_count is None:
item_count = 0 item_count = 0

@ -1,43 +0,0 @@
from __future__ import annotations
import sqlalchemy
from server.model.base import BaseHelper, BaseModel
import sqlalchemy
from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, Mapped
class BotPersonaCategoryModel(BaseModel):
__tablename__ = "chat_complete_bot_persona_category"
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
icon: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
font_icon: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
color: Mapped[str] = mapped_column(sqlalchemy.String(10), nullable=True)
order: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
class BotPersonaHelper(BaseHelper):
async def add(self, obj: BotPersonaCategoryModel):
self.session.add(obj)
await self.session.commit()
await self.session.refresh(obj)
return obj
async def update(self, obj: BotPersonaCategoryModel):
obj = await self.session.merge(obj)
await self.session.commit()
return obj
async def remove(self, id: int):
stmt = sqlalchemy.delete(BotPersonaCategoryModel).where(BotPersonaCategoryModel.id == id)
await self.session.execute(stmt)
await self.session.commit()
async def get_list(self):
stmt = select(BotPersonaCategoryModel)
return await self.session.scalars(stmt)
async def find_by_id(self, id: int):
stmt = select(BotPersonaCategoryModel).where(BotPersonaCategoryModel.id == id)
return await self.session.scalar(stmt)

@ -6,7 +6,7 @@ import asyncpg
from server.model.base import BaseModel from server.model.base import BaseModel
import numpy as np import numpy as np
import sqlalchemy import sqlalchemy
from lib.config import Config from libs.config import Config
from sqlalchemy import Index, select, update, delete, Select from sqlalchemy import Index, select, update, delete, Select
from sqlalchemy.orm import mapped_column, Mapped from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession

@ -8,7 +8,7 @@ import sqlalchemy
from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer
from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncEngine
from lib.config import Config from libs.config import Config
from server.model.base import BaseHelper, BaseModel from server.model.base import BaseHelper, BaseModel
from service.database import DatabaseService from service.database import DatabaseService

@ -9,7 +9,7 @@ def register_route(app: web.Application):
web.route('POST', '/chatcomplete/conversation/fork', ChatComplete.fork_conversation), web.route('POST', '/chatcomplete/conversation/fork', ChatComplete.fork_conversation),
web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete), web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete),
web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream), web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream),
web.route('POST', '/chatcomplete/get_point_cost', ChatComplete.get_point_cost), web.route('POST', '/chatcomplete/get_point_usage', ChatComplete.get_point_usage),
web.route('*', '/chatcomplete/persona/list', ChatComplete.get_persona_list), web.route('*', '/chatcomplete/persona/list', ChatComplete.get_persona_list),
web.route('*', '/chatcomplete/persona/info', ChatComplete.get_persona_info), web.route('*', '/chatcomplete/persona/info', ChatComplete.get_persona_info),
]) ])

@ -1,4 +1,5 @@
from __future__ import annotations from __future__ import annotations
import math
import time import time
import traceback import traceback
from typing import Optional, Tuple, TypedDict from typing import Optional, Tuple, TypedDict
@ -11,7 +12,7 @@ from server.model.chat_complete.conversation import (
import sys import sys
from server.model.toolbox_ui.conversation import ConversationHelper, ConversationModel from server.model.toolbox_ui.conversation import ConversationHelper, ConversationModel
from lib.config import Config from libs.config import Config
import utils.config, utils.web import utils.config, utils.web
from aiohttp import web from aiohttp import web
@ -72,6 +73,7 @@ class ChatCompleteService:
self.question = "" self.question = ""
self.question_tokens: Optional[int] = None self.question_tokens: Optional[int] = None
self.bot_id: str = "" self.bot_id: str = ""
self.model: Optional[str] = None
self.conversation_id: Optional[int] = None self.conversation_id: Optional[int] = None
self.conversation_start_time: Optional[int] = None self.conversation_start_time: Optional[int] = None
@ -104,11 +106,11 @@ class ChatCompleteService:
async def prepare_chat_complete( async def prepare_chat_complete(
self, self,
question: str, question: str,
bot_id: str,
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
user_id: Optional[int] = None, user_id: Optional[int] = None,
question_tokens: Optional[int] = None, question_tokens: Optional[int] = None,
edit_message_id: Optional[str] = None, edit_message_id: Optional[str] = None,
bot_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None, embedding_search: Optional[EmbeddingSearchArgs] = None,
) -> ChatCompleteServicePrepareResponse: ) -> ChatCompleteServicePrepareResponse:
if user_id is not None: if user_id is not None:
@ -117,7 +119,7 @@ class ChatCompleteService:
self.user_id = user_id self.user_id = user_id
self.question = question self.question = question
self.conversation_start_time = int(time.time()) self.conversation_start_time = int(time.time())
self.bot_id = bot_id or None self.bot_id = bot_id
self.conversation_info = None self.conversation_info = None
if conversation_id is not None: if conversation_id is not None:
@ -153,14 +155,17 @@ class ChatCompleteService:
if bot_persona is None: if bot_persona is None:
self.bot_id = "default" self.bot_id = "default"
self.model = None
bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id) bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id)
else: else:
self.model = bot_persona.model_id
self.chat_system_prompt = bot_persona.system_prompt self.chat_system_prompt = bot_persona.system_prompt
default_api = Config.get("chatcomplete.default_api", None, str) default_api = Config.get("chatcomplete.default_api", None, str)
try: try:
self.openai_api = OpenAIApi.create(bot_persona.api_id or default_api) self.openai_api = OpenAIApi.create(bot_persona.api_id or default_api)
except OpenAIApiTypeInvalidException: except OpenAIApiTypeInvalidException:
print(f"Invalid API type: {bot_persona.api_id}", file=sys.stderr)
self.openai_api = OpenAIApi.create(default_api) self.openai_api = OpenAIApi.create(default_api)
self.conversation_chunk = None self.conversation_chunk = None
@ -236,10 +241,11 @@ class ChatCompleteService:
init_message_data = [] init_message_data = []
if bot_persona is not None: if bot_persona is not None:
current_time = int(time.time()) current_time = int(time.time())
for message in bot_persona.message_log: if bot_persona.message_log is not None:
message["id"] = utils.web.generate_uuid() for message in bot_persona.message_log:
message["time"] = current_time message["id"] = utils.web.generate_uuid()
init_message_data.append(message) message["time"] = current_time
init_message_data.append(message)
title_info = self.embedding_search.title_index title_info = self.embedding_search.title_index
self.conversation_info = ConversationModel( self.conversation_info = ConversationModel(
@ -319,11 +325,11 @@ class ChatCompleteService:
# Start chat complete # Start chat complete
if on_message is not None: if on_message is not None:
response = await self.openai_api.chat_complete_stream( response = await self.openai_api.chat_complete_stream(
self.question, system_prompt, message_log, on_message self.question, system_prompt, self.model, message_log, on_message
) )
else: else:
response = await self.openai_api.chat_complete( response = await self.openai_api.chat_complete(
self.question, system_prompt, message_log self.question, system_prompt, self.model, message_log
) )
description = response["message"][0:150] description = response["message"][0:150]
@ -384,7 +390,7 @@ class ChatCompleteService:
delta_data=delta_data, delta_data=delta_data,
) )
async def set_latest_point_cost(self, point_cost: int) -> bool: async def set_latest_point_usage(self, point_usage: int) -> bool:
if self.conversation_chunk is None: if self.conversation_chunk is None:
return False return False
@ -393,7 +399,7 @@ class ChatCompleteService:
for i in range(len(self.conversation_chunk.message_data) - 1, -1, -1): for i in range(len(self.conversation_chunk.message_data) - 1, -1, -1):
if self.conversation_chunk.message_data[i].get("role") == "assistant": if self.conversation_chunk.message_data[i].get("role") == "assistant":
self.conversation_chunk.message_data[i]["point_cost"] = point_cost self.conversation_chunk.message_data[i]["point_usage"] = point_usage
flag_modified(self.conversation_chunk, "message_data") flag_modified(self.conversation_chunk, "message_data")
await self.conversation_chunk_helper.update(self.conversation_chunk) await self.conversation_chunk_helper.update(self.conversation_chunk)
@ -402,6 +408,10 @@ class ChatCompleteService:
async def make_summary(self, message_log_list: list) -> tuple[str, int]: async def make_summary(self, message_log_list: list) -> tuple[str, int]:
chat_log: list[str] = [] chat_log: list[str] = []
bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str) bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)
api_id = Config.get("chatcomplete.system_api_id", "default", str)
model_id = Config.get("chatcomplete.system_api_id", "default", str)
openai_api = OpenAIApi.create(api_id)
for message_data in message_log_list: for message_data in message_log_list:
if "content" in message_data: if "content" in message_data:
@ -422,8 +432,8 @@ class ChatCompleteService:
"make_summary", "prompt", {"content": chat_log_str} "make_summary", "prompt", {"content": chat_log_str}
) )
response = await self.openai_api.chat_complete( response = await openai_api.chat_complete(
summary_prompt, summary_system_prompt summary_prompt, summary_system_prompt, model_id
) )
return response["message"], response["message_tokens"] return response["message"], response["message_tokens"]
@ -431,6 +441,10 @@ class ChatCompleteService:
async def make_title(self, message_log_list: list) -> tuple[str, int]: async def make_title(self, message_log_list: list) -> tuple[str, int]:
chat_log: list[str] = [] chat_log: list[str] = []
bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str) bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)
api_id = Config.get("chatcomplete.system_api_id", "default", str)
model_id = Config.get("chatcomplete.system_api_id", "default", str)
openai_api = OpenAIApi.create(api_id)
for message_data in message_log_list: for message_data in message_log_list:
if "content" in message_data: if "content" in message_data:
@ -451,7 +465,18 @@ class ChatCompleteService:
"make_title", "prompt", {"content": chat_log_str} "make_title", "prompt", {"content": chat_log_str}
) )
response = await self.openai_api.chat_complete( response = await openai_api.chat_complete(
title_prompt, title_system_prompt title_prompt, title_system_prompt, model_id
) )
if response["message"] is None:
print(response)
raise Exception("Title generation failed")
return response["message"][0:250], response["message_tokens"] return response["message"][0:250], response["message_tokens"]
def calculate_point_usage(tokens: int, cost_fixed: float, cost_fixed_tokens: int, cost_per_token: float) -> int:
if tokens <= cost_fixed_tokens:
return cost_fixed
return cost_fixed + math.ceil((tokens - cost_fixed_tokens) * cost_per_token)

@ -4,7 +4,7 @@ from urllib.parse import quote_plus
from aiohttp import web from aiohttp import web
import asyncpg import asyncpg
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from lib.config import Config from libs.config import Config
def get_dsn(): def get_dsn():
db_conf = Config.get("database") db_conf = Config.get("database")

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Optional, TypedDict from typing import Optional, TypedDict
import sqlalchemy import sqlalchemy
from libs.config import Config
from server.model.embedding_search.title_collection import ( from server.model.embedding_search.title_collection import (
TitleCollectionHelper, TitleCollectionHelper,
TitleCollectionModel, TitleCollectionModel,
@ -10,7 +11,6 @@ from server.model.embedding_search.title_index import TitleIndexHelper, TitleInd
from server.model.embedding_search.page_index import PageIndexHelper from server.model.embedding_search.page_index import PageIndexHelper
from service.database import DatabaseService from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi
from service.text_embedding import TextEmbeddingService from service.text_embedding import TextEmbeddingService
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
from utils.wiki import getWikiSentences from utils.wiki import getWikiSentences
@ -43,7 +43,6 @@ class EmbeddingSearchService:
self.tiktoken: TikTokenService = None self.tiktoken: TikTokenService = None
self.mwapi = MediaWikiApi.create() self.mwapi = MediaWikiApi.create()
self.openai_api = OpenAIApi.create()
self.page_id: Optional[int] = None self.page_id: Optional[int] = None
self.collection_id: Optional[int] = None self.collection_id: Optional[int] = None
@ -284,11 +283,14 @@ class EmbeddingSearchService:
query: str, query: str,
limit: int = 10, limit: int = 10,
in_collection: bool = False, in_collection: bool = False,
distance_limit: float = 0.6, distance_limit: Optional[float] = None,
): ):
if limit == 0: if limit == 0:
return [], 0 return [], 0
if distance_limit is None:
distance_limit = Config.get("embedding.default_distance_limit")
if self.page_index is None: if self.page_index is None:
raise Exception("Page index is not initialized") raise Exception("Page index is not initialized")
@ -296,12 +298,15 @@ class EmbeddingSearchService:
query_doc, token_usage = await self.text_embedding.get_embeddings(query_doc) query_doc, token_usage = await self.text_embedding.get_embeddings(query_doc)
query_embedding = query_doc[0]["embedding"] query_embedding = query_doc[0]["embedding"]
print(query_embedding)
if query_embedding is None: if query_embedding is None:
return [], token_usage return [], token_usage
res = await self.page_index.search_text_embedding( res = await self.page_index.search_text_embedding(
query_embedding, in_collection, limit, self.page_id query_embedding, in_collection, limit, self.page_id
) )
print(res)
if res: if res:
filtered = [] filtered = []
for one in res: for one in res:

@ -4,7 +4,7 @@ import sys
import time import time
from typing import Optional, TypedDict from typing import Optional, TypedDict
import aiohttp import aiohttp
from lib.config import Config from libs.config import Config
class MediaWikiApiException(Exception): class MediaWikiApiException(Exception):
def __init__(self, info: str, code: Optional[str] = None) -> None: def __init__(self, info: str, code: Optional[str] = None) -> None:
@ -41,10 +41,9 @@ class GetAllPagesResponse(TypedDict):
continue_key: Optional[str] continue_key: Optional[str]
class ChatCompleteGetPointUsageResponse(TypedDict): class ChatCompleteGetPointUsageResponse(TypedDict):
point_cost: int point_usage: int
class ChatCompleteReportUsageResponse(TypedDict): class ChatCompleteReportUsageResponse(TypedDict):
point_cost: int
transaction_id: str transaction_id: str
class MediaWikiApi: class MediaWikiApi:
@ -58,6 +57,7 @@ class MediaWikiApi:
return MediaWikiApi.instance return MediaWikiApi.instance
def __init__(self, api_url: str): def __init__(self, api_url: str):
self.api_url = api_url self.api_url = api_url
self.request_proxy = Config.get("request.proxy", type=str, empty_is_none=True) self.request_proxy = Config.get("request.proxy", type=str, empty_is_none=True)
@ -66,6 +66,7 @@ class MediaWikiApi:
self.login_identity = None self.login_identity = None
self.login_time = 0.0 self.login_time = 0.0
async def get_page_info(self, title: str): async def get_page_info(self, title: str):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = { params = {
@ -86,6 +87,7 @@ class MediaWikiApi:
return data["query"]["pages"][0] return data["query"]["pages"][0]
async def parse_page(self, title: str): async def parse_page(self, title: str):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = { params = {
@ -105,6 +107,7 @@ class MediaWikiApi:
return data["parse"]["text"] return data["parse"]["text"]
async def get_site_meta(self): async def get_site_meta(self):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = { params = {
@ -131,6 +134,7 @@ class MediaWikiApi:
return ret return ret
async def get_all_pages(self, continue_key: Optional[str] = None, start_title: Optional[str] = None) -> GetAllPagesResponse: async def get_all_pages(self, continue_key: Optional[str] = None, start_title: Optional[str] = None) -> GetAllPagesResponse:
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = { params = {
@ -163,6 +167,7 @@ class MediaWikiApi:
return ret return ret
async def is_logged_in(self): async def is_logged_in(self):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = { params = {
@ -178,6 +183,7 @@ class MediaWikiApi:
return data["query"]["userinfo"]["id"] != 0 return data["query"]["userinfo"]["id"] != 0
async def get_token(self, token_type: str): async def get_token(self, token_type: str):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = { params = {
@ -194,6 +200,7 @@ class MediaWikiApi:
return data["query"]["tokens"][token_type + "token"] return data["query"]["tokens"][token_type + "token"]
async def robot_login(self, username: str, password: str): async def robot_login(self, username: str, password: str):
if await self.is_logged_in(): if await self.is_logged_in():
self.login_time = time.time() self.login_time = time.time()
@ -225,6 +232,7 @@ class MediaWikiApi:
return True return True
async def refresh_login(self): async def refresh_login(self):
if self.login_identity is None: if self.login_identity is None:
print("刷新MW机器人账号登录状态失败没有保存的用户") print("刷新MW机器人账号登录状态失败没有保存的用户")
@ -233,6 +241,7 @@ class MediaWikiApi:
print("刷新MW机器人账号登录状态") print("刷新MW机器人账号登录状态")
return await self.robot_login(self.login_identity["username"], self.login_identity["password"]) return await self.robot_login(self.login_identity["username"], self.login_identity["password"])
async def search_title(self, keyword: str) -> list[str]: async def search_title(self, keyword: str) -> list[str]:
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = { params = {
@ -245,6 +254,7 @@ class MediaWikiApi:
data = await resp.json() data = await resp.json()
return data[1] return data[1]
async def ai_toolbox_get_user_info(self, user_id: int): async def ai_toolbox_get_user_info(self, user_id: int):
await self.refresh_login() await self.refresh_login()
@ -266,32 +276,9 @@ class MediaWikiApi:
return data["aitoolboxbot"]["userinfo"] return data["aitoolboxbot"]["userinfo"]
async def ai_toolbox_get_point_cost(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteGetPointUsageResponse:
await self.refresh_login()
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
post_data = {
"action": "aitoolboxbot",
"method": "reportusage",
"step": "check",
"userid": int(user_id) if user_id is not None else None,
"useraction": user_action,
"tokens": int(tokens) if tokens is not None else None,
"extractlines": int(extractlines) if extractlines is not None else None,
"format": "json",
"formatversion": "2",
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.post(self.api_url, data=post_data, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
point_cost = int(data["aitoolboxbot"]["reportusage"]["pointcost"] or 0)
return ChatCompleteGetPointUsageResponse(point_cost=point_cost)
async def ai_toolbox_start_transaction(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteReportUsageResponse: async def ai_toolbox_start_transaction(self, user_id: int, user_action: str, bot_id: Optional[str] = None,
tokens: Optional[int] = None, point_usage: Optional[int] = None) -> ChatCompleteReportUsageResponse:
await self.refresh_login() await self.refresh_login()
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
@ -301,8 +288,9 @@ class MediaWikiApi:
"step": "start", "step": "start",
"userid": int(user_id) if user_id is not None else None, "userid": int(user_id) if user_id is not None else None,
"useraction": user_action, "useraction": user_action,
"botid": bot_id,
"tokens": int(tokens) if tokens is not None else None, "tokens": int(tokens) if tokens is not None else None,
"extractlines": int(extractlines) if extractlines is not None else None, "pointusage": int(point_usage) if point_usage is not None else 0,
"format": "json", "format": "json",
"formatversion": "2", "formatversion": "2",
} }
@ -316,11 +304,10 @@ class MediaWikiApi:
else: else:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
point_cost = int(data["aitoolboxbot"]["reportusage"]["pointcost"] or 0) return ChatCompleteReportUsageResponse(transaction_id=data["aitoolboxbot"]["reportusage"]["transactionid"])
return ChatCompleteReportUsageResponse(point_cost=point_cost,
transaction_id=data["aitoolboxbot"]["reportusage"]["transactionid"])
async def ai_toolbox_end_transaction(self, transaction_id: str, tokens: Optional[int] = None):
async def ai_toolbox_end_transaction(self, transaction_id: str, point_usage: Optional[int] = None, tokens: Optional[int] = None):
await self.refresh_login() await self.refresh_login()
try: try:
@ -330,7 +317,7 @@ class MediaWikiApi:
"method": "reportusage", "method": "reportusage",
"step": "end", "step": "end",
"transactionid": transaction_id, "transactionid": transaction_id,
"tokens": tokens, "pointusage": int(point_usage) if point_usage is not None else 0,
"format": "json", "format": "json",
"formatversion": "2", "formatversion": "2",
} }
@ -345,6 +332,7 @@ class MediaWikiApi:
except Exception as e: except Exception as e:
print(e, file=sys.stderr) print(e, file=sys.stderr)
async def ai_toolbox_cancel_transaction(self, transaction_id: str, error: Optional[str] = None): async def ai_toolbox_cancel_transaction(self, transaction_id: str, error: Optional[str] = None):
await self.refresh_login() await self.refresh_login()

@ -4,7 +4,7 @@ import json
from typing import Callable, Optional, TypedDict from typing import Callable, Optional, TypedDict
import aiohttp import aiohttp
from lib.config import Config from libs.config import Config
import numpy as np import numpy as np
from aiohttp_sse_client2 import client as sse_client from aiohttp_sse_client2 import client as sse_client
@ -65,11 +65,18 @@ class OpenAIApi:
def get_url(self, method: str): def get_url(self, method: str):
if self.api_type == "azure": if self.api_type == "azure":
deployments = Config.get(f"chatcomplete.{self.api_id}.deployments")
if method == "chat/completions": if method == "chat/completions":
return self.api_url + "/openai/deployments/" + deployments["chatcomplete"] + "/" + method deployment = Config.get(f"chatcomplete.{self.api_id}.deployment_chatcomplete")
if deployment is None:
raise Exception("deployment for chatcomplete is not set")
return self.api_url + "/openai/deployments/" + deployment + "/" + method
elif method == "embeddings": elif method == "embeddings":
return self.api_url + "/openai/deployments/" + deployments["embedding"] + "/" + method deployment = Config.get(f"chatcomplete.{self.api_id}.deployment_embedding")
if deployment is None:
raise Exception("deployment for embedding is not set")
return self.api_url + "/openai/deployments/" + deployment + "/" + method
else: else:
return self.api_url + "/v1/" + method return self.api_url + "/v1/" + method
@ -175,7 +182,7 @@ class OpenAIApi:
return messageList return messageList
async def chat_complete(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], user = None): async def chat_complete(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], user = None):
messageList = await self.make_message_list(question, system_prompt, conversation) messageList = await self.make_message_list(question, system_prompt, conversation)
url = self.get_url("chat/completions") url = self.get_url("chat/completions")
@ -189,7 +196,7 @@ class OpenAIApi:
if self.api_type == "azure": if self.api_type == "azure":
params["api-version"] = AZURE_CHATCOMPLETE_API_VERSION params["api-version"] = AZURE_CHATCOMPLETE_API_VERSION
else: else:
post_data["model"] = "gpt-3.5-turbo" post_data["model"] = model
post_data = {k: v for k, v in post_data.items() if v is not None} post_data = {k: v for k, v in post_data.items() if v is not None}
@ -218,10 +225,13 @@ class OpenAIApi:
message_tokens=message_tokens, message_tokens=message_tokens,
total_tokens=total_tokens, total_tokens=total_tokens,
finish_reason=finish_reason) finish_reason=finish_reason)
else:
print(data)
raise Exception("Invalid response from chat complete api")
return None return None
async def chat_complete_stream(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], on_message = None, user = None): async def chat_complete_stream(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], on_message = None, user = None):
tiktoken = await TikTokenService.create() tiktoken = await TikTokenService.create()
messageList = await self.make_message_list(question, system_prompt, conversation) messageList = await self.make_message_list(question, system_prompt, conversation)
@ -247,7 +257,7 @@ class OpenAIApi:
if self.api_type == "azure": if self.api_type == "azure":
params["api-version"] = AZURE_CHATCOMPLETE_API_VERSION params["api-version"] = AZURE_CHATCOMPLETE_API_VERSION
else: else:
post_data["model"] = "gpt-3.5-turbo" post_data["model"] = model
post_data = {k: v for k, v in post_data.items() if v is not None} post_data = {k: v for k, v in post_data.items() if v is not None}

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from copy import deepcopy from copy import deepcopy
import time import time
from lib.config import Config from libs.config import Config
import asyncio import asyncio
import random import random
import threading import threading
@ -22,8 +22,8 @@ class Text2VecEmbeddingQueueTaskInfo(TypedDict):
class Text2VecEmbeddingQueue: class Text2VecEmbeddingQueue:
def __init__(self, model: str) -> None: def __init__(self, model: str) -> None:
self.model_name = model self.model_name = model
self.embedding_model: SentenceModel = None
self.embedding_model = SentenceModel(self.model_name)
self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {} self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {}
self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = [] self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = []
self.lock = threading.Lock() self.lock = threading.Lock()
@ -32,6 +32,10 @@ class Text2VecEmbeddingQueue:
self.running = False self.running = False
def post_init(self):
self.embedding_model = SentenceModel(self.model_name)
async def get_embeddings(self, text: str): async def get_embeddings(self, text: str):
task_id = random.randint(0, 1000000000) task_id = random.randint(0, 1000000000)
with self.lock: with self.lock:
@ -54,6 +58,7 @@ class Text2VecEmbeddingQueue:
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
def pop_task(self, task_id): def pop_task(self, task_id):
with self.lock: with self.lock:
if task_id in self.task_map: if task_id in self.task_map:
@ -64,6 +69,7 @@ class Text2VecEmbeddingQueue:
return None return None
def run(self): def run(self):
running = True running = True
last_task_time = None last_task_time = None
@ -88,15 +94,22 @@ class Text2VecEmbeddingQueue:
else: else:
time.sleep(0.01) time.sleep(0.01)
def start_queue(self): def start_queue(self):
if not self.running: if not self.running:
self.running = True self.running = True
self.thread = threading.Thread(target=self.run) self.thread = threading.Thread(target=self.run)
self.thread.start() self.thread.start()
class TextEmbeddingService: class TextEmbeddingService:
instance = None instance = None
def __init__(self):
self.tiktoken: TikTokenService = None
self.text2vec_queue: Text2VecEmbeddingQueue = None
self.openai_api: OpenAIApi = None
@staticmethod @staticmethod
async def create() -> TextEmbeddingService: async def create() -> TextEmbeddingService:
if TextEmbeddingService.instance is None: if TextEmbeddingService.instance is None:
@ -111,21 +124,27 @@ class TextEmbeddingService:
if self.embedding_type == "text2vec": if self.embedding_type == "text2vec":
embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese") embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese")
self.embedding_queue = Text2VecEmbeddingQueue(model=embedding_model) self.text2vec_queue = Text2VecEmbeddingQueue(model=embedding_model)
elif self.embedding_type == "openai": elif self.embedding_type == "openai":
self.openai_api: OpenAIApi = await OpenAIApi.create() api_id = Config.get("embedding.api_id")
self.openai_api: OpenAIApi = await OpenAIApi.create(api_id)
await loop.run_in_executor(None, self.embedding_queue.init) await loop.run_in_executor(None, self.text2vec_queue.post_init)
async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None): async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
total_token_usage = 0
for index, doc in enumerate(doc_list): for index, doc in enumerate(doc_list):
text = doc["text"] text = doc["text"]
embedding = await self.embedding_queue.get_embeddings(text) embedding = await self.text2vec_queue.get_embeddings(text)
doc["embedding"] = embedding doc["embedding"] = embedding
total_token_usage += await self.tiktoken.get_tokens(text)
if on_index_progress is not None: if on_index_progress is not None:
await on_index_progress(index, len(doc_list)) await on_index_progress(index, len(doc_list))
return (doc_list, total_token_usage)
async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None): async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
res_doc_list = deepcopy(doc_list) res_doc_list = deepcopy(doc_list)

@ -4,7 +4,7 @@ import pathlib
root_path = pathlib.Path(__file__).parent.parent root_path = pathlib.Path(__file__).parent.parent
sys.path.append(".") sys.path.append(".")
from lib.config import Config from libs.config import Config
Config.load_config(str(root_path) + "/config.toml") Config.load_config(str(root_path) + "/config.toml")
Config.set("server.debug", True) Config.set("server.debug", True)

@ -16,7 +16,7 @@ async def main():
try: try:
prepare_res = await chat_complete.prepare_chat_complete(question, None, 1, question_tokens, { prepare_res = await chat_complete.prepare_chat_complete(question, None, 1, question_tokens, {
"distance_limit": 0.6, "distance_limit": 20,
"limit": 10 "limit": 10
}) })

@ -14,14 +14,14 @@ async def main():
print("\r索引进度:%.1f%%" % (current / length * 100), end="", flush=True) print("\r索引进度:%.1f%%" % (current / length * 100), end="", flush=True)
print("") print("")
await embedding_search.update_page_index(on_index_progress) #await embedding_search.update_page_index(on_index_progress)
print("") print("")
while True: while True:
query = input("请输入要搜索的问题 (.exit 退出)") query = input("请输入要搜索的问题 (.exit 退出)")
if query == ".exit": if query == ".exit":
break break
res, token_usage = await embedding_search.search(query, 5) res, token_usage = await embedding_search.search(query, 5, False, 0.1)
total_length = 0 total_length = 0
if res: if res:
for one in res: for one in res:

@ -1,5 +1,5 @@
import time import time
from lib.config import Config from libs.config import Config
def get_prompt(name: str, type: str, params: dict = {}): def get_prompt(name: str, type: str, params: dict = {}):
sys_params = { sys_params = {

@ -1,5 +1,5 @@
import asyncio import asyncio
from lib.noawait import NoAwaitPool from libs.noawait import NoAwaitPool
debug = False debug = False
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()

@ -5,7 +5,7 @@ from typing import Any, Optional, Dict
from aiohttp import web from aiohttp import web
import jwt import jwt
import uuid import uuid
from lib.config import Config from libs.config import Config
ParamRule = Dict[str, Any] ParamRule = Dict[str, Any]
@ -15,7 +15,7 @@ class ParamInvalidException(web.HTTPBadRequest):
self.param_list = param_list self.param_list = param_list
self.rules = rules self.rules = rules
param_list_str = "'" + ("', '".join(param_list)) + "'" param_list_str = "'" + ("', '".join(param_list)) + "'"
super().__init__(f"Param invalid: {param_list_str}") super().__init__(reason=f"Param invalid: {param_list_str}")
async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None): async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None):
params: dict[str, Any] = {} params: dict[str, Any] = {}

Loading…
Cancel
Save