You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

139 lines
4.1 KiB
Python

import sys
import traceback
from config import Config
Config.load_config("config.toml")
from local import loop, noawait
from aiohttp import web
import local
import api.route
import utils.web
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
# Auto create Table
from api.model.base import BaseModel
from api.model.toolkit_ui.page_title import PageTitleModel as _
from api.model.toolkit_ui.conversation import ConversationModel as _
from api.model.chat_complete.conversation import ConversationChunkModel as _
from api.model.chat_complete.bot_persona_category import BotPersonaCategoryModel as _
from api.model.chat_complete.bot_persona import BotPersonaModel as _
from api.model.embedding_search.title_collection import TitleCollectionModel as _
from api.model.embedding_search.title_index import TitleIndexModel as _
from service.tiktoken import TikTokenService
async def index(request: web.Request):
    """Root endpoint: answer with a fixed banner identifying this service."""
    banner = {"message": "Isekai toolkit API"}
    return await utils.web.api_response(1, data=banner, request=request)
@web.middleware
async def error_handler(request, handler):
    """Convert exceptions raised by request handlers into API error responses.

    - ``utils.web.ParamInvalidException`` -> HTTP 400 with code
      ``invalid-params`` and the exception's ``param_list``.
    - ``web.HTTPException`` with an error status (>= 400) -> the same status,
      with code ``http_<status>`` and the exception's reason.
      Non-error HTTP exceptions (status < 400, e.g. ``HTTPFound`` redirects)
      are re-raised untouched so aiohttp sends them with their own headers
      (such as ``Location``) instead of a JSON error body.
    - Any other ``Exception`` -> HTTP 500. A generated error id, the message
      and the traceback are written to stderr. The client receives the error
      id, plus the exception text only when ``local.debug`` is enabled.
    """
    try:
        return await handler(request)
    except utils.web.ParamInvalidException as ex:
        return await utils.web.api_response(
            -1,
            error={"code": "invalid-params", "message": "Invalid params.", "invalid_params": ex.param_list},
            http_status=400,
            request=request,
        )
    except web.HTTPException as ex:
        if ex.status < 400:
            # Redirects / success responses raised as exceptions are not
            # errors; wrapping them would drop headers like Location.
            raise
        return await utils.web.api_response(
            -1,
            error={"code": f"http_{ex.status}", "message": ex.reason},
            http_status=ex.status,
            request=request,
        )
    except Exception as ex:
        error_id = utils.web.generate_uuid()
        err_msg = f"Server error [{error_id}]: {ex}"
        print(err_msg, file=sys.stderr)
        traceback.print_exc()
        print(f"End of error [{error_id}]", file=sys.stderr)
        # Exception text can contain internal details (queries, hosts, paths);
        # only expose it to clients in debug mode. The error id is always
        # returned so the client-side report can be matched to the stderr log.
        if getattr(local, "debug", False):
            client_msg = err_msg
        else:
            client_msg = f"Server error [{error_id}]"
        return await utils.web.api_response(
            -1,
            error={"code": "internal-server-error", "message": client_msg},
            http_status=500,
            request=request,
        )
async def init_mw_api(app: web.Application):
    """Startup hook: connect to the MediaWiki API and optionally log in as a bot.

    When both ``mediawiki.bot_username`` and ``mediawiki.bot_password`` are
    configured, a bot login is attempted. A failed login is reported on
    stderr together with its reason but does not abort startup. The site
    metadata is then fetched and the site name and ``user`` field are printed.

    Args:
        app: the aiohttp application (required by the signal signature; unused).
    """
    mw_api = MediaWikiApi.create()
    bot_username = Config.get("mediawiki.bot_username", "", str)
    bot_password = Config.get("mediawiki.bot_password", "", str)
    if bot_username and bot_password:
        try:
            await mw_api.robot_login(bot_username, bot_password)
        except Exception as e:
            # Best-effort: keep starting up, but surface why the login failed.
            print(
                f"Cannot login to Robot account, please check config. ({e})",
                file=sys.stderr,
            )
    site_meta = await mw_api.get_site_meta()
    print(
        "Connected to Wiki %s, Robot username: %s"
        % (site_meta["sitename"], site_meta["user"])
    )
async def init_database(app: web.Application):
    """Startup hook: obtain the database service and create the model tables.

    Runs ``BaseModel.metadata.create_all`` inside a transaction on the
    service's engine. The model classes imported at the top of this module
    (under "Auto create Table") are what register tables on that metadata.
    """
    db_service = await DatabaseService.create(app)
    print("Database connected.")
    async with db_service.engine.begin() as connection:
        await connection.run_sync(BaseModel.metadata.create_all)
async def close_database(app: web.Application):
    """Cleanup hook: close the database service associated with *app*.

    NOTE(review): relies on ``DatabaseService.create(app)`` handing back the
    instance opened at startup rather than a fresh one — confirm.
    """
    await (await DatabaseService.create(app)).close()
async def init_tiktoken(app: web.Application):
    """Startup hook (registered when chatcomplete is enabled): load tiktoken."""
    pending_load = TikTokenService.create()
    await pending_load
    print("Tiktoken model loaded.")
async def stop_noawait_pool(app: web.Application):
    """Shutdown hook: await ``noawait.end()`` to stop the noawait pool."""
    finishing = noawait.end()
    await finishing
if __name__ == "__main__":
    # Script entry point: build the aiohttp app, register lifecycle hooks for
    # the optional subsystems that are configured, mount routes and serve.
    local.debug = Config.get("server.debug", False, bool)

    app = web.Application(
        middlewares=[
            error_handler,
        ]
    )

    if Config.get("database.host"):
        app.on_startup.append(init_database)
        app.on_cleanup.append(close_database)

    if Config.get("mediawiki.api_endpoint"):
        app.on_startup.append(init_mw_api)

    if Config.get("chatcomplete.enabled"):
        app.on_startup.append(init_tiktoken)

    app.on_shutdown.append(stop_noawait_pool)

    app.router.add_route("*", "/", index)
    api.route.init(app)

    # Bind address is configurable; the default keeps the previous behaviour
    # of listening on all interfaces. Set server.host (e.g. "127.0.0.1") to
    # avoid exposing the API publicly.
    server_host = Config.get("server.host", "0.0.0.0", str)
    # NOTE(review): port is read from the top-level "port" key while other
    # server settings live under "server." — confirm this is intended.
    server_port = Config.get("port", 8144, int)
    web.run_app(
        app,
        host=server_host,
        port=server_port,
        loop=loop,
    )