完善错误处理程序

master
落雨楓 2 years ago
parent 81562cba4d
commit 80ad7223c1

@ -13,7 +13,7 @@ from aiohttp import web
from sqlalchemy import select from sqlalchemy import select
from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool from noawait import NoAwaitPool
from service.chat_complete import ChatCompleteService, ChatCompleteServiceResponse from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteService, ChatCompleteServiceResponse
from service.database import DatabaseService from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
@ -282,7 +282,7 @@ class ChatComplete:
traceback.print_exc() traceback.print_exc()
return await utils.web.api_response(-1, error={ return await utils.web.api_response(-1, error={
"code": "chat-complete-error", "code": "internal-server-error",
"message": err_msg "message": err_msg
}, http_status=500, request=request) }, http_status=500, request=request)
@ -441,12 +441,20 @@ class ChatComplete:
"code": "no-enough-points", "code": "no-enough-points",
"message": error_msg "message": error_msg
}, http_status=403, request=request) }, http_status=403, request=request)
except ChatCompleteQuestionTooLongException as e:
error_msg = "Question too long."
return await utils.web.api_response(-1, error={
"code": "question-too-long",
"limit": e.tokens_limit,
"current": e.tokens_current,
"message": error_msg
}, http_status=400, request=request)
except Exception as e: except Exception as e:
err_msg = f"Error while processing chat complete request: {e}" err_msg = f"Error while processing chat complete request: {e}"
traceback.print_exc() traceback.print_exc()
return await utils.web.api_response(-1, error={ return await utils.web.api_response(-1, error={
"code": "chat-complete-error", "code": "internal-server-error",
"message": err_msg "message": err_msg
}, http_status=500, request=request) }, http_status=500, request=request)
@ -505,7 +513,7 @@ class ChatComplete:
'status': -1, 'status': -1,
'message': str(task.error), 'message': str(task.error),
'error': { 'error': {
'code': "internal-error", 'code': "internal-server-error",
'info': str(task.error), 'info': str(task.error),
}, },
}) })
@ -548,7 +556,7 @@ class ChatComplete:
'status': -1, 'status': -1,
'message': str(err), 'message': str(err),
'error': { 'error': {
'code': "internal-error", 'code': "internal-server-error",
'info': str(err), 'info': str(err),
}, },
}) })

@ -1,4 +1,7 @@
import sys
import traceback
from config import Config from config import Config
Config.load_config("config.toml") Config.load_config("config.toml")
from local import loop, noawait from local import loop, noawait
@ -21,8 +24,46 @@ from api.model.embedding_search.title_index import TitleIndexModel as _
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
async def index(request: web.Request): async def index(request: web.Request):
return await utils.web.api_response(1, data={"message": "Isekai toolkit API"}, request=request) return await utils.web.api_response(
1, data={"message": "Isekai toolkit API"}, request=request
)
@web.middleware
async def error_handler(request, handler):
try:
response = await handler(request)
return response
except utils.web.ParamInvalidException as ex:
return await utils.web.api_response(
-1,
error={"code": "invalid-params", "message": "Invalid params.", "invalid_params": ex.param_list},
http_status=400,
request=request,
)
except web.HTTPException as ex:
return await utils.web.api_response(
-1,
error={"code": f"http_{ex.status}", "message": ex.reason},
http_status=ex.status,
request=request,
)
except Exception as ex:
error_id = utils.web.generate_uuid()
err_msg = f"Server error [{error_id}]: {ex}"
print(err_msg, file=sys.stderr)
traceback.print_exc()
print(f"End of error [{error_id}]", file=sys.stderr)
return await utils.web.api_response(
-1,
error={"code": "internal-server-error", "message": err_msg},
http_status=500,
request=request,
)
async def init_mw_api(app: web.Application): async def init_mw_api(app: web.Application):
mw_api = MediaWikiApi.create() mw_api = MediaWikiApi.create()
@ -38,41 +79,60 @@ async def init_mw_api(app: web.Application):
site_meta = await mw_api.get_site_meta() site_meta = await mw_api.get_site_meta()
print("Connected to Wiki %s, Robot username: %s" % (site_meta["sitename"], site_meta["user"])) print(
"Connected to Wiki %s, Robot username: %s"
% (site_meta["sitename"], site_meta["user"])
)
async def init_database(app: web.Application): async def init_database(app: web.Application):
dbs = await DatabaseService.create(app) dbs = await DatabaseService.create(app)
print("Database connected.") print("Database connected.")
async with dbs.engine.begin() as conn: async with dbs.engine.begin() as conn:
await conn.run_sync(BaseModel.metadata.create_all) await conn.run_sync(BaseModel.metadata.create_all)
async def close_database(app: web.Application):
dbs = await DatabaseService.create(app)
await dbs.close()
async def init_tiktoken(app: web.Application): async def init_tiktoken(app: web.Application):
await TikTokenService.create() await TikTokenService.create()
print("Tiktoken model loaded.") print("Tiktoken model loaded.")
async def stop_noawait_pool(app: web.Application): async def stop_noawait_pool(app: web.Application):
await noawait.end() await noawait.end()
if __name__ == '__main__': if __name__ == "__main__":
local.debug = Config.get("server.debug", False, bool) local.debug = Config.get("server.debug", False, bool)
app = web.Application() app = web.Application(
middlewares=[
error_handler,
]
)
if Config.get("database.host"): if Config.get("database.host"):
app.on_startup.append(init_database) app.on_startup.append(init_database)
app.on_cleanup.append(close_database)
if Config.get("mediawiki.api_endpoint"): if Config.get("mediawiki.api_endpoint"):
app.on_startup.append(init_mw_api) app.on_startup.append(init_mw_api)
if Config.get("chatcomplete.enabled"): if Config.get("chatcomplete.enabled"):
app.on_startup.append(init_tiktoken) app.on_startup.append(init_tiktoken)
app.on_shutdown.append(stop_noawait_pool) app.on_shutdown.append(stop_noawait_pool)
app.router.add_route('*', '/', index) app.router.add_route("*", "/", index)
api.route.init(app) api.route.init(app)
server_port = Config.get("port", 8144, int) server_port = Config.get("port", 8144, int)
web.run_app(app, host='0.0.0.0', port=server_port, loop=loop) web.run_app(
app,
host="0.0.0.0",
port=server_port,
loop=loop,
)

@ -25,6 +25,12 @@ from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
class ChatCompleteQuestionTooLongException(Exception):
def __init__(self, tokens_limit: int, tokens_current: int):
super().__init__(f"Question too long: {tokens_current} > {tokens_limit}")
self.tokens_limit = tokens_limit
self.tokens_current = tokens_current
class ChatCompleteServicePrepareResponse(TypedDict): class ChatCompleteServicePrepareResponse(TypedDict):
extract_doc: list extract_doc: list
question_tokens: int question_tokens: int
@ -139,7 +145,7 @@ class ChatCompleteService:
self.question_tokens > max_input_tokens self.question_tokens > max_input_tokens
): ):
# If the question is too long, we need to truncate it # If the question is too long, we need to truncate it
raise web.HTTPRequestEntityTooLarge(max_input_tokens, self.question_tokens) raise ChatCompleteQuestionTooLongException(max_input_tokens, self.question_tokens)
if self.conversation_info is not None: if self.conversation_info is not None:
self.bot_id = self.conversation_info.extra.get("bot_id") or "default" self.bot_id = self.conversation_info.extra.get("bot_id") or "default"

@ -46,4 +46,8 @@ class DatabaseService:
engine = create_async_engine(get_dsn(), echo=local.debug) engine = create_async_engine(get_dsn(), echo=local.debug)
self.engine = engine self.engine = engine
self.create_session = async_sessionmaker(engine, expire_on_commit=False) self.create_session = async_sessionmaker(engine, expire_on_commit=False)
async def close(self):
await self.engine.dispose()
await self.pool.close()

@ -15,15 +15,7 @@ class ParamInvalidException(web.HTTPBadRequest):
self.param_list = param_list self.param_list = param_list
self.rules = rules self.rules = rules
param_list_str = "'" + ("', '".join(param_list)) + "'" param_list_str = "'" + ("', '".join(param_list)) + "'"
super().__init__(f"Param invalid: {param_list_str}", super().__init__(f"Param invalid: {param_list_str}")
content_type="application/json",
body=json.dumps({
"status": -1,
"error": {
"code": self.code,
"message": f"Param invalid: {param_list_str}"
}
}))
async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None): async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None):
params: dict[str, Any] = {} params: dict[str, Any] = {}

Loading…
Cancel
Save