Fix gitignore directory exclusion issue

master
落雨楓 7 months ago
parent e55c25b9e3
commit 071ab94829

@ -5,12 +5,11 @@ import init.service.mw_api as _ # Init mediawiki api
import init.service.database as _ # Init database
import init.service.tiktoken as _ # Init tiktoken
import toolbox_ui as _ # Init toolbox ui
import embedding_search as _ # Init embedding search
import init.server.toolbox_ui as _ # Init toolbox ui
import init.server.embedding_search as _ # Init embedding search
# Auto create database tables
from server.model.chat_complete.conversation import ConversationChunkModel as _
from server.model.chat_complete.bot_persona_category import BotPersonaCategoryModel as _
from server.model.chat_complete.bot_persona import BotPersonaModel as _
# Route

@ -2,7 +2,7 @@ from __future__ import annotations
from aiohttp import web
from utils.server import register_server_module
from lib.config import Config
from libs.config import Config
from service.mediawiki_api import MediaWikiApi

@ -9,7 +9,7 @@ async def init_tiktoken(app: web.Application):
print("Tiktoken model loaded.")
async def init(app: web.Application):
def init(app: web.Application):
app.on_startup.append(init_tiktoken)

@ -0,0 +1,95 @@
from __future__ import annotations
import os
import os.path
import toml
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
config_observer = Observer()
class Config:
file_name: str
values: dict = {}
@staticmethod
def load_config(file):
global config_observer
Config.file_name = file
Config.reload()
if config_observer.is_alive():
config_observer.stop()
config_observer.join()
config_observer.schedule(Config.ConfigEventHandler(), file)
config_observer.start()
@staticmethod
def reload():
with open(Config.file_name, "r", encoding="utf-8") as f:
Config.values = toml.load(f)
@staticmethod
def get(key: str, default=None, type=None, empty_is_none=False):
key_path = key.split(".")
value = Config.values
for k in key_path:
if k in value:
value = value[k]
else:
return default
if empty_is_none and value == "":
return None
if type == bool:
if isinstance(value, bool):
return value
elif isinstance(value, int) or isinstance(value, float):
return value != 0
else:
return str(value).lower() in ("yes", "true", "1")
elif type == int:
return int(value)
elif type == float:
return float(value)
elif type == str:
return str(value)
elif type == list:
    if not isinstance(value, list):
        return []
    return value
elif type == dict:
    if not isinstance(value, dict):
        return {}
    return value
else:
    return value
@staticmethod
def set(key: str, value):
key_path = key.split(".")
obj = Config.values
for k in key_path[:-1]:
if k not in obj:
obj[k] = {}
obj = obj[k]
obj[key_path[-1]] = value
@staticmethod
def unload():
global config_observer
if config_observer.is_alive():
config_observer.stop()
config_observer.join()
class ConfigEventHandler(FileSystemEventHandler):
def on_modified(self, event):
if os.path.basename(event.src_path) == os.path.basename(Config.file_name):
print("Config file changed, reloading...")
Config.reload()
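
For orientation, the new config module is consumed elsewhere in this commit roughly as sketched below. Every call shown appears in other files of this diff; only the surrounding scaffolding is illustrative.

from libs.config import Config

Config.load_config("config.toml")   # parse the TOML file and start the watchdog observer for hot reload
debug = Config.get("server.debug", False, bool)     # dotted keys walk nested TOML tables
port = Config.get("port", 8144, int)
proxy = Config.get("request.proxy", type=str, empty_is_none=True)   # "" is treated as unset
Config.set("server.debug", True)    # in-memory override, as the test bootstrap scripts do
Config.unload()                     # stop the file watcher on shutdown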

@ -0,0 +1,165 @@
from __future__ import annotations
from asyncio import AbstractEventLoop, Task
import asyncio
import atexit
from functools import wraps
import random
import sys
import traceback
from typing import Callable, Coroutine, Optional, TypedDict
class TimerInfo(TypedDict):
id: int
callback: Callable
interval: float
next_time: float
class NoAwaitPool:
def __init__(self, loop: AbstractEventLoop):
self.task_list: list[Task] = []
self.timer_map: dict[int, TimerInfo] = {}
self.loop = loop
self.running = True
self.should_refresh_task = False
self.next_timer_time: Optional[float] = None
self.on_error: list[Callable] = []
self.gc_task = loop.create_task(self._run_gc())
self.timer_task = loop.create_task(self._run_timer())
atexit.register(self.end_task)
async def end(self):
if self.running:
print("Stopping NoAwait Tasks...")
self.running = False
for task in self.task_list:
await self._finish_task(task)
await self.gc_task
await self.timer_task
def end_task(self):
if self.running and not self.loop.is_closed():
self.loop.run_until_complete(self.end())
async def _wrap_task(self, task: Task):
try:
await task
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
await handler_ret
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
finally:
self.should_refresh_task = True
def add_task(self, coroutine: Coroutine):
task = self.loop.create_task(coroutine)
self.task_list.append(task)
def add_timer(self, callback: Callable, interval: float) -> int:
id = random.randint(0, 1000000000)
while id in self.timer_map:
id = random.randint(0, 1000000000)
now = self.loop.time()
next_time = now + interval
self.timer_map[id] = {
"id": id,
"callback": callback,
"interval": interval,
"next_time": next_time
}
if self.next_timer_time is None or next_time < self.next_timer_time:
self.next_timer_time = next_time
return id
def remove_timer(self, id: int):
if id in self.timer_map:
del self.timer_map[id]
def wrap(self, f):
@wraps(f)
def decorated_function(*args, **kwargs):
coroutine = f(*args, **kwargs)
self.add_task(coroutine)
return decorated_function
async def _finish_task(self, task: Task):
try:
if not task.done():
task.cancel()
await task
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
await handler_ret
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
async def _run_gc(self):
while self.running:
if self.should_refresh_task:
should_remove = []
for task in self.task_list:
if task.done():
await self._finish_task(task)
should_remove.append(task)
for task in should_remove:
self.task_list.remove(task)
await asyncio.sleep(0.1)
async def _run_timer(self):
while self.running:
now = self.loop.time()
if self.next_timer_time is not None and now >= self.next_timer_time:
self.next_timer_time = None
for timer in self.timer_map.values():
if now >= timer["next_time"]:
timer["next_time"] = now + timer["interval"]
try:
result = timer["callback"]()
self.add_task(result)
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
self.add_task(handler_ret)
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
if self.next_timer_time is None or timer["next_time"] < self.next_timer_time:
self.next_timer_time = timer["next_time"]
await asyncio.sleep(0.1)
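
A minimal sketch of how the new NoAwaitPool is meant to be used. utils/local.py in this diff exposes an instance as noawait (a NoAwaitPool bound to the app's event loop); the coroutine names here are placeholders.

from utils.local import noawait

async def refresh_cache():
    ...                                  # any coroutine; exceptions are routed to the pool's on_error handlers

@noawait.wrap                            # decorated coroutines are scheduled on the pool when called normally
async def log_visit(page: str):
    ...

async def example():
    noawait.add_task(refresh_cache())                  # fire-and-forget; the _run_gc loop reaps it once done
    timer_id = noawait.add_timer(refresh_cache, 60.0)  # callback must return a coroutine; re-run every 60 s
    log_visit("Main_Page")                             # returns immediately, runs in the background
    noawait.remove_timer(timer_id)
    await noawait.end()                                # cancel/await remaining tasks (main.py does this on shutdown)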

@ -1,6 +1,6 @@
import sys
import traceback
from lib.config import Config
from libs.config import Config
from utils.server import get_server_modules
Config.load_config("config.toml")
@ -8,9 +8,7 @@ Config.load_config("config.toml")
from utils.local import loop, noawait
from aiohttp import web
import utils.local as local
import server.route as api_route
import utils.web
from service.mediawiki_api import MediaWikiApi
async def index(request: web.Request):
@ -53,8 +51,9 @@ async def error_handler(request, handler):
)
async def stop_noawait_pool(app: web.Application):
async def on_shutdown(app: web.Application):
await noawait.end()
Config.unload()
if __name__ == "__main__":
local.debug = Config.get("server.debug", False, bool)
@ -79,15 +78,14 @@ if __name__ == "__main__":
# Initialize server modules
server_modules = get_server_modules()
print(server_modules)
for server_module in server_modules:
print("Loading server module: " + server_module["name"])
server_module["init"](app)
app.on_shutdown.append(stop_noawait_pool)
app.on_shutdown.append(on_shutdown)
app.router.add_route("*", "/", index)
api_route.init(app)
server_port = Config.get("port", 8144, int)

@ -1,4 +1,9 @@
import sys
import pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent))
root_path = pathlib.Path(__file__).parent.parent
sys.path.append(".")
from libs.config import Config
Config.load_config(str(root_path) + "/config.toml")

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import time
import traceback
from libs.config import Config
from server.controller.task.ChatCompleteTask import ChatCompleteTask
from server.model.base import clone_model
from server.model.chat_complete.bot_persona import BotPersonaHelper
@ -9,9 +10,9 @@ from server.model.toolbox_ui.conversation import ConversationHelper
from utils.local import noawait
from aiohttp import web
from server.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse
from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse, calculate_point_usage
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
from service.mediawiki_api import MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
from service.tiktoken import TikTokenService
import utils.web
@ -63,6 +64,7 @@ class ChatComplete:
return await utils.web.api_response(1, conversation_chunk_list, request=request)
@staticmethod
@utils.web.token_auth
async def get_conversation_chunk(request: web.Request):
@ -111,6 +113,7 @@ class ChatComplete:
return await utils.web.api_response(1, chunk_dict, request=request)
@staticmethod
@utils.web.token_auth
async def fork_conversation(request: web.Request):
@ -223,6 +226,7 @@ class ChatComplete:
"conversation_id": new_conversation.id,
}, request=request)
@staticmethod
@utils.web.token_auth
async def get_tokens(request: web.Request):
@ -240,14 +244,19 @@ class ChatComplete:
return await utils.web.api_response(1, {"tokens": tokens}, request=request)
@staticmethod
@utils.web.token_auth
async def get_point_cost(request: web.Request):
async def get_point_usage(request: web.Request):
params = await utils.web.get_param(request, {
"question": {
"type": str,
"required": True,
},
"bot_id": {
"type": str,
"required": True,
},
"extract_limit": {
"type": int,
"required": False,
@ -256,39 +265,41 @@ class ChatComplete:
})
user_id = request.get("user")
caller = request.get("caller")
question = params.get("question")
extract_limit = params.get("extract_limit")
bot_id = params.get("bot_id")
tiktoken = await TikTokenService.create()
mwapi = MediaWikiApi.create()
db = await DatabaseService.create(request.app)
tokens = await tiktoken.get_tokens(question)
try:
res = await mwapi.ai_toolbox_get_point_cost(user_id, "chatcomplete", tokens, extract_limit)
estimated_extract_tokens_per_doc = Config.get("estimated_extract_tokens_per_doc", 50, int)
predict_tokens = tokens + estimated_extract_tokens_per_doc
async with BotPersonaHelper(db) as bot_persona_helper:
persona_info = await bot_persona_helper.find_by_bot_id(bot_id)
if persona_info is None:
return await utils.web.api_response(-1, error={
"code": "bot-not-found",
"message": "Bot not found."
}, http_status=404, request=request)
point_usage = calculate_point_usage(predict_tokens, persona_info.cost_fixed, persona_info.cost_fixed_tokens, persona_info.cost_per_token)
return await utils.web.api_response(1, {
"point_cost": res["point_cost"],
"point_usage": point_usage,
"tokens": tokens,
"predict_tokens": predict_tokens,
}, request=request)
except Exception as e:
err_msg = f"Error while get chat complete point cost: {e}"
traceback.print_exc()
return await utils.web.api_response(-1, error={
"code": "internal-server-error",
"message": err_msg
}, http_status=500, request=request)
@staticmethod
@utils.web.token_auth
async def get_persona_list(request: web.Request):
params = await utils.web.get_param(request, {
"category_id": {
"type": int,
"required": False,
},
"page": {
"type": int,
"required": False,
@ -296,13 +307,12 @@ class ChatComplete:
}
})
category_id = params.get("category_id")
page = params.get("page")
db = await DatabaseService.create(request.app)
async with BotPersonaHelper(db) as bot_persona_helper:
persona_list = await bot_persona_helper.get_list(page=page, category_id=category_id)
page_count = await bot_persona_helper.get_page_count(category_id=category_id)
persona_list = await bot_persona_helper.get_list(page=page)
page_count = await bot_persona_helper.get_page_count()
persona_data_list = []
for persona in persona_list:
@ -312,7 +322,9 @@ class ChatComplete:
"bot_name": persona.bot_name,
"bot_avatar": persona.bot_avatar,
"bot_description": persona.bot_description,
"updated_at": persona.updated_at,
"model_id": persona.model_id,
"model_name": persona.model_name,
"cost_fixed": persona.cost_fixed
})
return await utils.web.api_response(1, {
@ -320,6 +332,7 @@ class ChatComplete:
"page_count": page_count,
}, request=request)
@staticmethod
@utils.web.token_auth
async def get_persona_info(request: web.Request):
@ -354,6 +367,7 @@ class ChatComplete:
return await utils.web.api_response(1, persona_info_res, request=request)
@staticmethod
@utils.web.token_auth
async def start_chat_complete(request: web.Request):
@ -372,7 +386,7 @@ class ChatComplete:
},
"bot_id": {
"type": str,
"required": False,
"required": True,
},
"extract_limit": {
"type": int,
@ -453,6 +467,7 @@ class ChatComplete:
"message": err_msg
}, http_status=500, request=request)
@staticmethod
@utils.web.token_auth
async def chat_complete_stream(request: web.Request):
@ -541,7 +556,7 @@ class ChatComplete:
try:
ignored_keys = ["message"]
response_result = {
"point_cost": task.point_cost,
"point_usage": task.point_usage,
}
for k, v in result.items():
if k not in ignored_keys:

@ -1,6 +1,7 @@
import sys
import traceback
from aiohttp import web
from libs.config import Config
from server.model.embedding_search.title_collection import TitleCollectionHelper
from server.model.embedding_search.title_index import TitleIndexHelper
from service.database import DatabaseService
@ -53,10 +54,10 @@ class EmbeddingSearch:
page_current += 1
async with EmbeddingSearchService(db, one_title) as embedding_search:
if await embedding_search.should_update_page_index():
if request.get("caller") == "user":
user_id = request.get("user")
usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
transatcion_id = usage_res.get("transaction_id")
# if request.get("caller") == "user":
# user_id = request.get("user")
# usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
# transatcion_id = usage_res.get("transaction_id")
await embedding_search.prepare_update_index()
@ -144,10 +145,10 @@ class EmbeddingSearch:
async with EmbeddingSearchService(db, page_title) as embedding_search:
if await embedding_search.should_update_page_index():
if request.get("caller") == "user":
user_id = request.get("user")
usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
transatcion_id = usage_res.get("transaction_id")
# if request.get("caller") == "user":
# user_id = request.get("user")
# usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
# transatcion_id = usage_res.get("transaction_id")
await embedding_search.prepare_update_index()
@ -236,7 +237,7 @@ class EmbeddingSearch:
"distancelimit": {
"required": False,
"type": float,
"default": 0.6
"default": None
},
})

@ -2,12 +2,15 @@ from __future__ import annotations
import sys
import time
import traceback
from libs.config import Config
from server.model.chat_complete.bot_persona import BotPersonaHelper
from utils.local import noawait
from typing import Optional, Callable, Union
from service.chat_complete import (
ChatCompleteService,
ChatCompleteServicePrepareResponse,
ChatCompleteServiceResponse,
calculate_point_usage,
)
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs
@ -41,7 +44,7 @@ class ChatCompleteTask:
self.is_system = is_system
self.transatcion_id: Optional[str] = None
self.point_cost: int = 0
self.point_usage: int = 0
self.is_finished = False
self.finished_time: Optional[float] = None
@ -51,9 +54,9 @@ class ChatCompleteTask:
async def init(
self,
question: str,
bot_id: str,
conversation_id: Optional[str] = None,
edit_message_id: Optional[str] = None,
bot_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None,
) -> ChatCompleteServicePrepareResponse:
self.tiktoken = await TikTokenService.create()
@ -69,19 +72,31 @@ class ChatCompleteTask:
extract_limit = embedding_search["limit"] or 10
estimated_extract_tokens_per_doc = Config.get("estimated_extract_tokens_per_doc", 50, int)
predict_tokens = extract_limit + estimated_extract_tokens_per_doc
async with BotPersonaHelper(self.dbs) as bot_persona_helper:
bot_persona = await bot_persona_helper.find_by_bot_id(bot_id)
self.point_usage: int = calculate_point_usage(predict_tokens, bot_persona.cost_fixed,
bot_persona.cost_fixed_tokens, bot_persona.cost_per_token)
self.transatcion_id: Optional[str] = None
self.point_cost: int = 0
if not self.is_system:
usage_res = await self.mwapi.ai_toolbox_start_transaction(
self.user_id, "chatcomplete", question_tokens, extract_limit
self.user_id, "chatcomplete",
bot_id=bot_id,
tokens=predict_tokens,
point_usage=self.point_usage
)
self.transatcion_id = usage_res["transaction_id"]
self.point_cost = usage_res["point_cost"]
res = await self.chat_complete.prepare_chat_complete(
question,
conversation_id=conversation_id,
user_id=self.user_id,
question_tokens=question_tokens,
edit_message_id=edit_message_id,
bot_id=bot_id,
embedding_search=embedding_search,
@ -136,13 +151,13 @@ class ChatCompleteTask:
try:
chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
await self.chat_complete.set_latest_point_cost(self.point_cost)
await self.chat_complete.set_latest_point_usage(self.point_usage)
self.result = chat_res
if self.transatcion_id:
await self.mwapi.ai_toolbox_end_transaction(
self.transatcion_id, chat_res["total_tokens"]
self.transatcion_id, self.point_usage # TODO: deduct points based on the tokens actually used
)
await self._on_finished()

@ -5,10 +5,9 @@ import sqlalchemy
from server.model.base import BaseHelper, BaseModel
import sqlalchemy
from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, load_only, Mapped
from sqlalchemy import select
from sqlalchemy.orm import mapped_column, load_only, Mapped
from server.model.chat_complete.bot_persona_category import BotPersonaCategoryModel
from service.database import DatabaseService
@ -22,17 +21,16 @@ class BotPersonaModel(BaseModel):
bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
category_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey(
BotPersonaCategoryModel.id, ondelete="CASCADE", onupdate="CASCADE"
),
index=True,
)
api_id: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
model_id: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True, index=True)
model_name: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON)
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True)
default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
updated_at: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
cost_fixed: Mapped[float] = mapped_column(sqlalchemy.Float, nullable=True)
cost_fixed_tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
cost_per_token: Mapped[float] = mapped_column(sqlalchemy.Float, nullable=True)
order: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True, nullable=True)
class BotPersonaHelper(BaseHelper):
@ -50,8 +48,7 @@ class BotPersonaHelper(BaseHelper):
async def get_list(
self,
page: Optional[int] = 1,
page_size: Optional[int] = 20,
category_id: Optional[int] = None,
page_size: Optional[int] = 20
):
offset_index = (page - 1) * page_size
@ -64,25 +61,22 @@ class BotPersonaHelper(BaseHelper):
BotPersonaModel.bot_name,
BotPersonaModel.bot_avatar,
BotPersonaModel.bot_description,
BotPersonaModel.updated_at,
BotPersonaModel.model_id,
BotPersonaModel.model_name,
BotPersonaModel.cost_fixed,
BotPersonaModel.order,
)
)
.order_by(BotPersonaModel.updated_at.desc())
.order_by(BotPersonaModel.order.desc())
.offset(offset_index)
.limit(page_size)
)
if category_id is not None:
stmt = stmt.where(BotPersonaModel.category_id == category_id)
return await self.session.scalars(stmt)
async def get_page_count(self, page_size=50, category_id: Optional[int] = None):
async def get_page_count(self, page_size=50):
stmt = select(sqlalchemy.func.count()).select_from(BotPersonaModel)
if category_id is not None:
stmt = stmt.where(BotPersonaModel.category_id == category_id)
item_count = await self.session.scalar(stmt)
if item_count is None:
item_count = 0

@ -1,43 +0,0 @@
from __future__ import annotations
import sqlalchemy
from server.model.base import BaseHelper, BaseModel
import sqlalchemy
from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, Mapped
class BotPersonaCategoryModel(BaseModel):
__tablename__ = "chat_complete_bot_persona_category"
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
icon: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
font_icon: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
color: Mapped[str] = mapped_column(sqlalchemy.String(10), nullable=True)
order: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
class BotPersonaHelper(BaseHelper):
async def add(self, obj: BotPersonaCategoryModel):
self.session.add(obj)
await self.session.commit()
await self.session.refresh(obj)
return obj
async def update(self, obj: BotPersonaCategoryModel):
obj = await self.session.merge(obj)
await self.session.commit()
return obj
async def remove(self, id: int):
stmt = sqlalchemy.delete(BotPersonaCategoryModel).where(BotPersonaCategoryModel.id == id)
await self.session.execute(stmt)
await self.session.commit()
async def get_list(self):
stmt = select(BotPersonaCategoryModel)
return await self.session.scalars(stmt)
async def find_by_id(self, id: int):
stmt = select(BotPersonaCategoryModel).where(BotPersonaCategoryModel.id == id)
return await self.session.scalar(stmt)

@ -6,7 +6,7 @@ import asyncpg
from server.model.base import BaseModel
import numpy as np
import sqlalchemy
from lib.config import Config
from libs.config import Config
from sqlalchemy import Index, select, update, delete, Select
from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession

@ -8,7 +8,7 @@ import sqlalchemy
from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer
from sqlalchemy.ext.asyncio import AsyncEngine
from lib.config import Config
from libs.config import Config
from server.model.base import BaseHelper, BaseModel
from service.database import DatabaseService

@ -9,7 +9,7 @@ def register_route(app: web.Application):
web.route('POST', '/chatcomplete/conversation/fork', ChatComplete.fork_conversation),
web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete),
web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream),
web.route('POST', '/chatcomplete/get_point_cost', ChatComplete.get_point_cost),
web.route('POST', '/chatcomplete/get_point_usage', ChatComplete.get_point_usage),
web.route('*', '/chatcomplete/persona/list', ChatComplete.get_persona_list),
web.route('*', '/chatcomplete/persona/info', ChatComplete.get_persona_info),
])

@ -1,4 +1,5 @@
from __future__ import annotations
import math
import time
import traceback
from typing import Optional, Tuple, TypedDict
@ -11,7 +12,7 @@ from server.model.chat_complete.conversation import (
import sys
from server.model.toolbox_ui.conversation import ConversationHelper, ConversationModel
from lib.config import Config
from libs.config import Config
import utils.config, utils.web
from aiohttp import web
@ -72,6 +73,7 @@ class ChatCompleteService:
self.question = ""
self.question_tokens: Optional[int] = None
self.bot_id: str = ""
self.model: Optional[str] = None
self.conversation_id: Optional[int] = None
self.conversation_start_time: Optional[int] = None
@ -104,11 +106,11 @@ class ChatCompleteService:
async def prepare_chat_complete(
self,
question: str,
bot_id: str,
conversation_id: Optional[str] = None,
user_id: Optional[int] = None,
question_tokens: Optional[int] = None,
edit_message_id: Optional[str] = None,
bot_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None,
) -> ChatCompleteServicePrepareResponse:
if user_id is not None:
@ -117,7 +119,7 @@ class ChatCompleteService:
self.user_id = user_id
self.question = question
self.conversation_start_time = int(time.time())
self.bot_id = bot_id or None
self.bot_id = bot_id
self.conversation_info = None
if conversation_id is not None:
@ -153,14 +155,17 @@ class ChatCompleteService:
if bot_persona is None:
self.bot_id = "default"
self.model = None
bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id)
else:
self.model = bot_persona.model_id
self.chat_system_prompt = bot_persona.system_prompt
default_api = Config.get("chatcomplete.default_api", None, str)
try:
self.openai_api = OpenAIApi.create(bot_persona.api_id or default_api)
except OpenAIApiTypeInvalidException:
print(f"Invalid API type: {bot_persona.api_id}", file=sys.stderr)
self.openai_api = OpenAIApi.create(default_api)
self.conversation_chunk = None
@ -236,6 +241,7 @@ class ChatCompleteService:
init_message_data = []
if bot_persona is not None:
current_time = int(time.time())
if bot_persona.message_log is not None:
for message in bot_persona.message_log:
message["id"] = utils.web.generate_uuid()
message["time"] = current_time
@ -319,11 +325,11 @@ class ChatCompleteService:
# Start chat complete
if on_message is not None:
response = await self.openai_api.chat_complete_stream(
self.question, system_prompt, message_log, on_message
self.question, system_prompt, self.model, message_log, on_message
)
else:
response = await self.openai_api.chat_complete(
self.question, system_prompt, message_log
self.question, system_prompt, self.model, message_log
)
description = response["message"][0:150]
@ -384,7 +390,7 @@ class ChatCompleteService:
delta_data=delta_data,
)
async def set_latest_point_cost(self, point_cost: int) -> bool:
async def set_latest_point_usage(self, point_usage: int) -> bool:
if self.conversation_chunk is None:
return False
@ -393,7 +399,7 @@ class ChatCompleteService:
for i in range(len(self.conversation_chunk.message_data) - 1, -1, -1):
if self.conversation_chunk.message_data[i].get("role") == "assistant":
self.conversation_chunk.message_data[i]["point_cost"] = point_cost
self.conversation_chunk.message_data[i]["point_usage"] = point_usage
flag_modified(self.conversation_chunk, "message_data")
await self.conversation_chunk_helper.update(self.conversation_chunk)
@ -402,6 +408,10 @@ class ChatCompleteService:
async def make_summary(self, message_log_list: list) -> tuple[str, int]:
chat_log: list[str] = []
bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)
api_id = Config.get("chatcomplete.system_api_id", "default", str)
model_id = Config.get("chatcomplete.system_api_id", "default", str)
openai_api = OpenAIApi.create(api_id)
for message_data in message_log_list:
if "content" in message_data:
@ -422,8 +432,8 @@ class ChatCompleteService:
"make_summary", "prompt", {"content": chat_log_str}
)
response = await self.openai_api.chat_complete(
summary_prompt, summary_system_prompt
response = await openai_api.chat_complete(
summary_prompt, summary_system_prompt, model_id
)
return response["message"], response["message_tokens"]
@ -431,6 +441,10 @@ class ChatCompleteService:
async def make_title(self, message_log_list: list) -> tuple[str, int]:
chat_log: list[str] = []
bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)
api_id = Config.get("chatcomplete.system_api_id", "default", str)
model_id = Config.get("chatcomplete.system_api_id", "default", str)
openai_api = OpenAIApi.create(api_id)
for message_data in message_log_list:
if "content" in message_data:
@ -451,7 +465,18 @@ class ChatCompleteService:
"make_title", "prompt", {"content": chat_log_str}
)
response = await self.openai_api.chat_complete(
title_prompt, title_system_prompt
response = await openai_api.chat_complete(
title_prompt, title_system_prompt, model_id
)
if response["message"] is None:
print(response)
raise Exception("Title generation failed")
return response["message"][0:250], response["message_tokens"]
def calculate_point_usage(tokens: int, cost_fixed: float, cost_fixed_tokens: int, cost_per_token: float) -> int:
if tokens <= cost_fixed_tokens:
return cost_fixed
return cost_fixed + math.ceil((tokens - cost_fixed_tokens) * cost_per_token)
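
A quick worked example of the new pricing helper; the numbers are illustrative, not values from the repository's configuration.

# A persona charging 10 points flat for the first 1000 tokens, then 0.01 points per extra token:
calculate_point_usage(800, 10, 1000, 0.01)    # -> 10, still within the fixed-cost allowance
calculate_point_usage(1500, 10, 1000, 0.01)   # -> 15, i.e. 10 + ceil(500 * 0.01)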

@ -4,7 +4,7 @@ from urllib.parse import quote_plus
from aiohttp import web
import asyncpg
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from lib.config import Config
from libs.config import Config
def get_dsn():
db_conf = Config.get("database")

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Optional, TypedDict
import sqlalchemy
from libs.config import Config
from server.model.embedding_search.title_collection import (
TitleCollectionHelper,
TitleCollectionModel,
@ -10,7 +11,6 @@ from server.model.embedding_search.title_index import TitleIndexHelper, TitleInd
from server.model.embedding_search.page_index import PageIndexHelper
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi
from service.text_embedding import TextEmbeddingService
from service.tiktoken import TikTokenService
from utils.wiki import getWikiSentences
@ -43,7 +43,6 @@ class EmbeddingSearchService:
self.tiktoken: TikTokenService = None
self.mwapi = MediaWikiApi.create()
self.openai_api = OpenAIApi.create()
self.page_id: Optional[int] = None
self.collection_id: Optional[int] = None
@ -284,11 +283,14 @@ class EmbeddingSearchService:
query: str,
limit: int = 10,
in_collection: bool = False,
distance_limit: float = 0.6,
distance_limit: Optional[float] = None,
):
if limit == 0:
return [], 0
if distance_limit is None:
distance_limit = Config.get("embedding.default_distance_limit")
if self.page_index is None:
raise Exception("Page index is not initialized")
@ -296,12 +298,15 @@ class EmbeddingSearchService:
query_doc, token_usage = await self.text_embedding.get_embeddings(query_doc)
query_embedding = query_doc[0]["embedding"]
print(query_embedding)
if query_embedding is None:
return [], token_usage
res = await self.page_index.search_text_embedding(
query_embedding, in_collection, limit, self.page_id
)
print(res)
if res:
filtered = []
for one in res:
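
A sketch of the reworked search entry point: with distance_limit omitted it now falls back to the embedding.default_distance_limit config key, while the embedding-search test script later in this diff passes 0.1 explicitly. db and the page title stand in for values from the calling code.

async def search_example(db):
    async with EmbeddingSearchService(db, "Some page title") as embedding_search:
        # distance_limit omitted -> Config.get("embedding.default_distance_limit")
        res, token_usage = await embedding_search.search("example question", limit=5)
        # or pass the threshold explicitly, as the test script does:
        res, token_usage = await embedding_search.search("example question", 5, False, 0.1)
        return res, token_usage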

@ -4,7 +4,7 @@ import sys
import time
from typing import Optional, TypedDict
import aiohttp
from lib.config import Config
from libs.config import Config
class MediaWikiApiException(Exception):
def __init__(self, info: str, code: Optional[str] = None) -> None:
@ -41,10 +41,9 @@ class GetAllPagesResponse(TypedDict):
continue_key: Optional[str]
class ChatCompleteGetPointUsageResponse(TypedDict):
point_cost: int
point_usage: int
class ChatCompleteReportUsageResponse(TypedDict):
point_cost: int
transaction_id: str
class MediaWikiApi:
@ -58,6 +57,7 @@ class MediaWikiApi:
return MediaWikiApi.instance
def __init__(self, api_url: str):
self.api_url = api_url
self.request_proxy = Config.get("request.proxy", type=str, empty_is_none=True)
@ -66,6 +66,7 @@ class MediaWikiApi:
self.login_identity = None
self.login_time = 0.0
async def get_page_info(self, title: str):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = {
@ -86,6 +87,7 @@ class MediaWikiApi:
return data["query"]["pages"][0]
async def parse_page(self, title: str):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = {
@ -105,6 +107,7 @@ class MediaWikiApi:
return data["parse"]["text"]
async def get_site_meta(self):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = {
@ -131,6 +134,7 @@ class MediaWikiApi:
return ret
async def get_all_pages(self, continue_key: Optional[str] = None, start_title: Optional[str] = None) -> GetAllPagesResponse:
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = {
@ -163,6 +167,7 @@ class MediaWikiApi:
return ret
async def is_logged_in(self):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = {
@ -178,6 +183,7 @@ class MediaWikiApi:
return data["query"]["userinfo"]["id"] != 0
async def get_token(self, token_type: str):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = {
@ -194,6 +200,7 @@ class MediaWikiApi:
return data["query"]["tokens"][token_type + "token"]
async def robot_login(self, username: str, password: str):
if await self.is_logged_in():
self.login_time = time.time()
@ -225,6 +232,7 @@ class MediaWikiApi:
return True
async def refresh_login(self):
if self.login_identity is None:
print("刷新MW机器人账号登录状态失败没有保存的用户")
@ -233,6 +241,7 @@ class MediaWikiApi:
print("刷新MW机器人账号登录状态")
return await self.robot_login(self.login_identity["username"], self.login_identity["password"])
async def search_title(self, keyword: str) -> list[str]:
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = {
@ -245,6 +254,7 @@ class MediaWikiApi:
data = await resp.json()
return data[1]
async def ai_toolbox_get_user_info(self, user_id: int):
await self.refresh_login()
@ -266,32 +276,9 @@ class MediaWikiApi:
return data["aitoolboxbot"]["userinfo"]
async def ai_toolbox_get_point_cost(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteGetPointUsageResponse:
await self.refresh_login()
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
post_data = {
"action": "aitoolboxbot",
"method": "reportusage",
"step": "check",
"userid": int(user_id) if user_id is not None else None,
"useraction": user_action,
"tokens": int(tokens) if tokens is not None else None,
"extractlines": int(extractlines) if extractlines is not None else None,
"format": "json",
"formatversion": "2",
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.post(self.api_url, data=post_data, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
point_cost = int(data["aitoolboxbot"]["reportusage"]["pointcost"] or 0)
return ChatCompleteGetPointUsageResponse(point_cost=point_cost)
async def ai_toolbox_start_transaction(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteReportUsageResponse:
async def ai_toolbox_start_transaction(self, user_id: int, user_action: str, bot_id: Optional[str] = None,
tokens: Optional[int] = None, point_usage: Optional[int] = None) -> ChatCompleteReportUsageResponse:
await self.refresh_login()
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
@ -301,8 +288,9 @@ class MediaWikiApi:
"step": "start",
"userid": int(user_id) if user_id is not None else None,
"useraction": user_action,
"botid": bot_id,
"tokens": int(tokens) if tokens is not None else None,
"extractlines": int(extractlines) if extractlines is not None else None,
"pointusage": int(point_usage) if point_usage is not None else 0,
"format": "json",
"formatversion": "2",
}
@ -316,11 +304,10 @@ class MediaWikiApi:
else:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
point_cost = int(data["aitoolboxbot"]["reportusage"]["pointcost"] or 0)
return ChatCompleteReportUsageResponse(point_cost=point_cost,
transaction_id=data["aitoolboxbot"]["reportusage"]["transactionid"])
return ChatCompleteReportUsageResponse(transaction_id=data["aitoolboxbot"]["reportusage"]["transactionid"])
async def ai_toolbox_end_transaction(self, transaction_id: str, tokens: Optional[int] = None):
async def ai_toolbox_end_transaction(self, transaction_id: str, point_usage: Optional[int] = None, tokens: Optional[int] = None):
await self.refresh_login()
try:
@ -330,7 +317,7 @@ class MediaWikiApi:
"method": "reportusage",
"step": "end",
"transactionid": transaction_id,
"tokens": tokens,
"pointusage": int(point_usage) if point_usage is not None else 0,
"format": "json",
"formatversion": "2",
}
@ -345,6 +332,7 @@ class MediaWikiApi:
except Exception as e:
print(e, file=sys.stderr)
async def ai_toolbox_cancel_transaction(self, transaction_id: str, error: Optional[str] = None):
await self.refresh_login()
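
The reshaped billing handshake, as driven by ChatCompleteTask earlier in this diff: the client now predicts the token count, computes the point figure itself via calculate_point_usage(), and the API no longer returns point_cost. A sketch; user_id, bot_id, predict_tokens and point_usage stand in for values computed by the caller.

async def charge_example(user_id, bot_id, predict_tokens, point_usage):
    mwapi = MediaWikiApi.create()
    usage_res = await mwapi.ai_toolbox_start_transaction(
        user_id, "chatcomplete",
        bot_id=bot_id,               # new: the persona identifies the pricing
        tokens=predict_tokens,       # predicted token count replaces the old extractlines parameter
        point_usage=point_usage,     # points computed client-side
    )
    transaction_id = usage_res["transaction_id"]   # point_cost is no longer part of the response
    try:
        ...                                        # run the chat completion
        await mwapi.ai_toolbox_end_transaction(transaction_id, point_usage)
    except Exception as err:
        await mwapi.ai_toolbox_cancel_transaction(transaction_id, error=str(err))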

@ -4,7 +4,7 @@ import json
from typing import Callable, Optional, TypedDict
import aiohttp
from lib.config import Config
from libs.config import Config
import numpy as np
from aiohttp_sse_client2 import client as sse_client
@ -65,11 +65,18 @@ class OpenAIApi:
def get_url(self, method: str):
if self.api_type == "azure":
deployments = Config.get(f"chatcomplete.{self.api_id}.deployments")
if method == "chat/completions":
return self.api_url + "/openai/deployments/" + deployments["chatcomplete"] + "/" + method
deployment = Config.get(f"chatcomplete.{self.api_id}.deployment_chatcomplete")
if deployment is None:
raise Exception("deployment for chatcomplete is not set")
return self.api_url + "/openai/deployments/" + deployment + "/" + method
elif method == "embeddings":
return self.api_url + "/openai/deployments/" + deployments["embedding"] + "/" + method
deployment = Config.get(f"chatcomplete.{self.api_id}.deployment_embedding")
if deployment is None:
raise Exception("deployment for embedding is not set")
return self.api_url + "/openai/deployments/" + deployment + "/" + method
else:
return self.api_url + "/v1/" + method
@ -175,7 +182,7 @@ class OpenAIApi:
return messageList
async def chat_complete(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], user = None):
async def chat_complete(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], user = None):
messageList = await self.make_message_list(question, system_prompt, conversation)
url = self.get_url("chat/completions")
@ -189,7 +196,7 @@ class OpenAIApi:
if self.api_type == "azure":
params["api-version"] = AZURE_CHATCOMPLETE_API_VERSION
else:
post_data["model"] = "gpt-3.5-turbo"
post_data["model"] = model
post_data = {k: v for k, v in post_data.items() if v is not None}
@ -218,10 +225,13 @@ class OpenAIApi:
message_tokens=message_tokens,
total_tokens=total_tokens,
finish_reason=finish_reason)
else:
print(data)
raise Exception("Invalid response from chat complete api")
return None
async def chat_complete_stream(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], on_message = None, user = None):
async def chat_complete_stream(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], on_message = None, user = None):
tiktoken = await TikTokenService.create()
messageList = await self.make_message_list(question, system_prompt, conversation)
@ -247,7 +257,7 @@ class OpenAIApi:
if self.api_type == "azure":
params["api-version"] = AZURE_CHATCOMPLETE_API_VERSION
else:
post_data["model"] = "gpt-3.5-turbo"
post_data["model"] = model
post_data = {k: v for k, v in post_data.items() if v is not None}
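
Call sites must now pass the model explicitly; the hard-coded "gpt-3.5-turbo" fallback in post_data is gone. A minimal sketch, following the non-awaited OpenAIApi.create() form used by ChatCompleteService in this diff; the api_id "default", the prompts, and message_log are placeholders.

async def complete_example(message_log):
    openai_api = OpenAIApi.create("default")       # api_id, resolved against the chatcomplete.* config section
    response = await openai_api.chat_complete(
        "What changed in this release?",           # question
        "You are a helpful wiki assistant.",       # system prompt
        "gpt-3.5-turbo",                           # model is now a required argument (was the removed default)
        message_log,
    )
    return response["message"], response["total_tokens"]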

@ -1,7 +1,7 @@
from __future__ import annotations
from copy import deepcopy
import time
from lib.config import Config
from libs.config import Config
import asyncio
import random
import threading
@ -22,8 +22,8 @@ class Text2VecEmbeddingQueueTaskInfo(TypedDict):
class Text2VecEmbeddingQueue:
def __init__(self, model: str) -> None:
self.model_name = model
self.embedding_model: SentenceModel = None
self.embedding_model = SentenceModel(self.model_name)
self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {}
self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = []
self.lock = threading.Lock()
@ -32,6 +32,10 @@ class Text2VecEmbeddingQueue:
self.running = False
def post_init(self):
self.embedding_model = SentenceModel(self.model_name)
async def get_embeddings(self, text: str):
task_id = random.randint(0, 1000000000)
with self.lock:
@ -54,6 +58,7 @@ class Text2VecEmbeddingQueue:
await asyncio.sleep(0.01)
def pop_task(self, task_id):
with self.lock:
if task_id in self.task_map:
@ -64,6 +69,7 @@ class Text2VecEmbeddingQueue:
return None
def run(self):
running = True
last_task_time = None
@ -88,15 +94,22 @@ class Text2VecEmbeddingQueue:
else:
time.sleep(0.01)
def start_queue(self):
if not self.running:
self.running = True
self.thread = threading.Thread(target=self.run)
self.thread.start()
class TextEmbeddingService:
instance = None
def __init__(self):
self.tiktoken: TikTokenService = None
self.text2vec_queue: Text2VecEmbeddingQueue = None
self.openai_api: OpenAIApi = None
@staticmethod
async def create() -> TextEmbeddingService:
if TextEmbeddingService.instance is None:
@ -111,21 +124,27 @@ class TextEmbeddingService:
if self.embedding_type == "text2vec":
embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese")
self.embedding_queue = Text2VecEmbeddingQueue(model=embedding_model)
self.text2vec_queue = Text2VecEmbeddingQueue(model=embedding_model)
elif self.embedding_type == "openai":
self.openai_api: OpenAIApi = await OpenAIApi.create()
api_id = Config.get("embedding.api_id")
self.openai_api: OpenAIApi = await OpenAIApi.create(api_id)
await loop.run_in_executor(None, self.embedding_queue.init)
await loop.run_in_executor(None, self.text2vec_queue.post_init)
async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
total_token_usage = 0
for index, doc in enumerate(doc_list):
text = doc["text"]
embedding = await self.embedding_queue.get_embeddings(text)
embedding = await self.text2vec_queue.get_embeddings(text)
doc["embedding"] = embedding
total_token_usage += await self.tiktoken.get_tokens(text)
if on_index_progress is not None:
await on_index_progress(index, len(doc_list))
return (doc_list, total_token_usage)
async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
res_doc_list = deepcopy(doc_list)
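
Loading the SentenceModel now happens in post_init rather than the constructor, so TextEmbeddingService can push it onto an executor thread. A sketch under that assumption; the model name is the default used in the config lookup above, and loop stands for the running event loop.

async def build_text2vec_queue(loop):
    queue = Text2VecEmbeddingQueue(model="shibing624/text2vec-base-chinese")
    await loop.run_in_executor(None, queue.post_init)   # heavy SentenceModel load, kept off the event loop
    queue.start_queue()                                 # background worker thread that drains queued texts
    return await queue.get_embeddings("some text")      # resolves once the worker has produced the embedding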

@ -4,7 +4,7 @@ import pathlib
root_path = pathlib.Path(__file__).parent.parent
sys.path.append(".")
from lib.config import Config
from libs.config import Config
Config.load_config(str(root_path) + "/config.toml")
Config.set("server.debug", True)

@ -16,7 +16,7 @@ async def main():
try:
prepare_res = await chat_complete.prepare_chat_complete(question, None, 1, question_tokens, {
"distance_limit": 0.6,
"distance_limit": 20,
"limit": 10
})

@ -14,14 +14,14 @@ async def main():
print("\r索引进度:%.1f%%" % (current / length * 100), end="", flush=True)
print("")
await embedding_search.update_page_index(on_index_progress)
#await embedding_search.update_page_index(on_index_progress)
print("")
while True:
query = input("请输入要搜索的问题 (.exit 退出)")
if query == ".exit":
break
res, token_usage = await embedding_search.search(query, 5)
res, token_usage = await embedding_search.search(query, 5, False, 0.1)
total_length = 0
if res:
for one in res:

@ -1,5 +1,5 @@
import time
from lib.config import Config
from libs.config import Config
def get_prompt(name: str, type: str, params: dict = {}):
sys_params = {

@ -1,5 +1,5 @@
import asyncio
from lib.noawait import NoAwaitPool
from libs.noawait import NoAwaitPool
debug = False
loop = asyncio.new_event_loop()

@ -5,7 +5,7 @@ from typing import Any, Optional, Dict
from aiohttp import web
import jwt
import uuid
from lib.config import Config
from libs.config import Config
ParamRule = Dict[str, Any]
@ -15,7 +15,7 @@ class ParamInvalidException(web.HTTPBadRequest):
self.param_list = param_list
self.rules = rules
param_list_str = "'" + ("', '".join(param_list)) + "'"
super().__init__(f"Param invalid: {param_list_str}")
super().__init__(reason=f"Param invalid: {param_list_str}")
async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None):
params: dict[str, Any] = {}
