Recreate project
commit e21a28a85f
@@ -0,0 +1,144 @@
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
config.py
|
@@ -0,0 +1,326 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from aiohttp import WSMsgType, web
|
||||
from sqlalchemy import select
|
||||
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
|
||||
from service.chat_complete import ChatCompleteService
|
||||
from service.database import DatabaseService
|
||||
from service.mediawiki_api import MediaWikiApi
|
||||
from service.tiktoken import TikTokenService
|
||||
import utils.web
|
||||
|
||||
|
||||
class ChatCompleteWebSocketController:
|
||||
def __init__(self, request: web.Request):
|
||||
self.request = request
|
||||
self.ws = None
|
||||
self.db = None
|
||||
self.chat_complete = None
|
||||
|
||||
self.closed = False
|
||||
|
||||
self.refreshed_time = 0
|
||||
|
||||
async def run(self):
|
||||
self.ws = web.WebSocketResponse()
|
||||
await self.ws.prepare(self.request)
|
||||
self.refreshed_time = time.time()
|
||||
|
||||
self.db = await DatabaseService.create(self.request.app)
|
||||
|
||||
self.query = self.request.query
|
||||
if self.request.get("caller") == "user":
|
||||
user_id = self.request.get("user")
|
||||
else:
|
||||
user_id = self.query.get("user_id")
|
||||
title = self.query.get("title")
|
||||
|
||||
# create heartbeat task
|
||||
asyncio.ensure_future(self._timeout_task())
|
||||
|
||||
async for msg in self.ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
event = data.get('event')
|
||||
self.refreshed_time = time.time()
|
||||
if event == 'chatcomplete':
|
||||
asyncio.ensure_future(self._chatcomplete(data))
|
||||
if event == 'ping':
|
||||
await self.ws.send_json({
|
||||
'event': 'pong'
|
||||
})
|
||||
except Exception as e:
|
||||
print(e)
|
||||
traceback.print_exc()
|
||||
await self.ws.send_json({
|
||||
'event': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
print('ws connection closed with exception %s' %
|
||||
self.ws.exception())
|
||||
|
||||
async def _timeout_task(self):
|
||||
while not self.closed:
|
||||
if time.time() - self.refreshed_time > 30:
|
||||
self.closed = True
|
||||
await self.ws.close()
|
||||
return
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _chatcomplete(self, params: dict):
|
||||
question = params.get("question")
|
||||
conversation_id = params.get("conversation_id")
|
||||
|
||||
|
||||
class ChatComplete:
|
||||
@staticmethod
|
||||
@utils.web.token_auth
|
||||
async def get_conversation_chunk_list(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"user_id": {
|
||||
"required": False,
|
||||
"type": int
|
||||
},
|
||||
"conversation_id": {
|
||||
"required": True,
|
||||
"type": int
|
||||
}
|
||||
})
|
||||
|
||||
if request.get("caller") == "user":
|
||||
user_id = request.get("user")
|
||||
else:
|
||||
user_id = params.get("user_id")
|
||||
|
||||
conversation_id = params.get("conversation_id")
|
||||
|
||||
db = await DatabaseService.create(request.app)
|
||||
|
||||
async with db.create_session() as session:
|
||||
stmt = select(ConversationModel).where(
|
||||
ConversationModel.id == conversation_id)
|
||||
|
||||
conversation_data = await session.scalar(stmt)
|
||||
|
||||
if conversation_data is None:
|
||||
return await utils.web.api_response(-1, error={
|
||||
"code": "conversation-not-found",
|
||||
"message": "Conversation not found."
|
||||
}, http_status=404, request=request)
|
||||
|
||||
if conversation_data.user_id != user_id:
|
||||
return await utils.web.api_response(-1, error={
|
||||
"code": "permission-denied",
|
||||
"message": "Permission denied."
|
||||
}, http_status=403, request=request)
|
||||
|
||||
stmt = select(ConversationChunkModel).with_only_columns([ConversationChunkModel.id, ConversationChunkModel.updated_at]) \
|
||||
.where(ConversationChunkModel.conversation_id == conversation_id).order_by(ConversationChunkModel.id.asc())
|
||||
|
||||
conversation_chunk_result = await session.scalars(stmt)
|
||||
|
||||
conversation_chunk_list = []
|
||||
|
||||
for result in conversation_chunk_result:
|
||||
conversation_chunk_list.append({
|
||||
"id": result.id,
|
||||
"updated_at": result.updated_at
|
||||
})
|
||||
|
||||
return await utils.web.api_response(1, conversation_chunk_list, request=request)
|
||||
|
||||
@staticmethod
|
||||
@utils.web.token_auth
|
||||
async def get_conversation_chunk(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"user_id": {
|
||||
"required": False,
|
||||
"type": int,
|
||||
},
|
||||
"chunk_id": {
|
||||
"required": True,
|
||||
"type": int,
|
||||
},
|
||||
})
|
||||
|
||||
if request.get("caller") == "user":
|
||||
user_id = request.get("user")
|
||||
else:
|
||||
user_id = params.get("user_id")
|
||||
|
||||
chunk_id = params.get("chunk_id")
|
||||
|
||||
dbs = await DatabaseService.create(request.app)
|
||||
async with dbs.create_session() as session:
|
||||
stmt = select(ConversationChunkModel).where(
|
||||
ConversationChunkModel.id == chunk_id)
|
||||
|
||||
conversation_data = await session.scalar(stmt)
|
||||
|
||||
if conversation_data is None:
|
||||
return await utils.web.api_response(-1, error={
|
||||
"code": "conversation-chunk-not-found",
|
||||
"message": "Conversation chunk not found."
|
||||
}, http_status=404, request=request)
|
||||
|
||||
if conversation_data.conversation.user_id != user_id:
|
||||
return await utils.web.api_response(-1, error={
|
||||
"code": "permission-denied",
|
||||
"message": "Permission denied."
|
||||
}, http_status=403, request=request)
|
||||
|
||||
return await utils.web.api_response(1, conversation_data.__dict__, request=request)
|
||||
|
||||
@staticmethod
|
||||
@utils.web.token_auth
|
||||
async def get_tokens(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"question": {
|
||||
"type": str,
|
||||
"required": True
|
||||
}
|
||||
})
|
||||
|
||||
question = params.get("question")
|
||||
|
||||
tiktoken = await TikTokenService.create()
|
||||
tokens = await tiktoken.get_tokens(question)
|
||||
|
||||
return await utils.web.api_response(1, {"tokens": tokens}, request=request)
|
||||
|
||||
@staticmethod
|
||||
@utils.web.token_auth
|
||||
async def chat_complete(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"title": {
|
||||
"type": str,
|
||||
"required": True,
|
||||
},
|
||||
"question": {
|
||||
"type": str,
|
||||
"required": True,
|
||||
},
|
||||
"conversation_id": {
|
||||
"type": int,
|
||||
"required": False,
|
||||
},
|
||||
"extra_limit": {
|
||||
"type": int,
|
||||
"required": False,
|
||||
"default": 10,
|
||||
},
|
||||
"in_collection": {
|
||||
"type": bool,
|
||||
"required": False,
|
||||
"default": False,
|
||||
},
|
||||
})
|
||||
|
||||
user_id = request.get("user")
|
||||
caller = request.get("caller")
|
||||
|
||||
page_title = params.get("title")
|
||||
question = params.get("question")
|
||||
conversation_id = params.get("conversation_id")
|
||||
|
||||
extra_limit = params.get("extra_limit")
|
||||
in_collection = params.get("in_collection")
|
||||
|
||||
dbs = await DatabaseService.create(request.app)
|
||||
tiktoken = await TikTokenService.create()
|
||||
mwapi = MediaWikiApi.create()
|
||||
if utils.web.is_websocket(request):
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
try:
|
||||
async with ChatCompleteService(dbs, page_title) as chat_complete_service:
|
||||
if await chat_complete_service.page_index_exists():
|
||||
tokens = await tiktoken.get_tokens(question)
|
||||
|
||||
transaction_id = None
|
||||
if request.get("caller") == "user":
|
||||
transaction_id = await mwapi.chat_complete_start_transaction(user_id, "chatcomplete", tokens, extra_limit)
|
||||
|
||||
async def on_message(text: str):
|
||||
# Send the answer text to the client prefixed with "+" to mark it as a content chunk;
# wrapping every chunk in JSON would make the packets roughly 10x larger.
|
||||
await ws.send_str("+" + text)
|
||||
|
||||
async def on_extracted_doc(doc: list):
|
||||
await ws.send_json({
|
||||
'event': 'extract_doc',
|
||||
'status': 1,
|
||||
'doc': doc
|
||||
})
|
||||
|
||||
try:
|
||||
chat_res = await chat_complete_service \
|
||||
.chat_complete(question, on_message, on_extracted_doc,
|
||||
conversation_id=conversation_id, user_id=user_id, embedding_search={
|
||||
"limit": extra_limit,
|
||||
"in_collection": in_collection,
|
||||
})
|
||||
await ws.send_json({
|
||||
'event': 'done',
|
||||
'status': 1,
|
||||
**chat_res,
|
||||
})
|
||||
|
||||
if transaction_id:
    result = await mwapi.chat_complete_end_transaction(transaction_id, chat_res["total_tokens"])
|
||||
except Exception as e:
|
||||
err_msg = f"Error while processing chat complete request: {e}"
|
||||
traceback.print_exc()
|
||||
|
||||
if not ws.closed:
|
||||
await ws.send_json({
|
||||
'event': 'error',
|
||||
'status': -1,
|
||||
'message': err_msg,
|
||||
'error': {
|
||||
'code': 'internal_error',
|
||||
'title': page_title,
|
||||
},
|
||||
})
|
||||
if transaction_id:
    result = await mwapi.chat_complete_cancel_transaction(transaction_id, error=err_msg)
|
||||
else:
|
||||
await ws.send_json({
|
||||
'event': 'error',
|
||||
'status': -2,
|
||||
'message': "Page index not found.",
|
||||
'error': {
|
||||
'code': 'page_not_found',
|
||||
'title': page_title,
|
||||
},
|
||||
})
|
||||
|
||||
# websocket closed
|
||||
except Exception as e:
|
||||
err_msg = f"Error while processing chat complete request: {e}"
|
||||
traceback.print_exc()
|
||||
|
||||
if not ws.closed:
|
||||
await ws.send_json({
|
||||
'event': 'error',
|
||||
'status': -1,
|
||||
'message': err_msg,
|
||||
'error': {
|
||||
'code': 'internal_error',
|
||||
'title': page_title,
|
||||
},
|
||||
})
|
||||
finally:
|
||||
if not ws.closed:
|
||||
await ws.close()
|
||||
else:
|
||||
return await utils.web.api_response(-1, request=request, error={
|
||||
"code": "protocol-mismatch",
|
||||
"message": "Protocol mismatch, websocket request expected."
|
||||
}, http_status=400)
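
For reference, a minimal client-side sketch of the streaming protocol implemented above (hedged: the route path and query-parameter transport are assumptions, since the route registration is not part of this diff). Plain answer text arrives prefixed with "+", while control frames such as extract_doc, done and error arrive as JSON:

import asyncio
import json
from urllib.parse import urlencode

import aiohttp


async def ask(base_url: str, title: str, question: str):
    # Hypothetical route "/chatcomplete"; adjust to the actual route registration.
    url = f"{base_url}/chatcomplete?" + urlencode({"title": title, "question": question})
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(url) as ws:
            async for msg in ws:
                if msg.type != aiohttp.WSMsgType.TEXT:
                    break
                if msg.data.startswith("+"):
                    # Streamed answer chunk: "+" prefix, no JSON wrapping.
                    print(msg.data[1:], end="", flush=True)
                    continue
                frame = json.loads(msg.data)
                if frame.get("event") == "extract_doc":
                    print("\n[extracted docs]", frame.get("doc"))
                elif frame.get("event") in ("done", "error"):
                    print("\n[%s]" % frame.get("event"), frame)
                    break

# asyncio.run(ask("http://localhost:8080", "Some page", "What does this page cover?"))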
|
@@ -0,0 +1,229 @@
|
||||
import sys
|
||||
import traceback
|
||||
from aiohttp import web
|
||||
from service.database import DatabaseService
|
||||
from service.embedding_search import EmbeddingSearchService
|
||||
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
|
||||
import utils.web
|
||||
|
||||
class EmbeddingSearch:
|
||||
@staticmethod
|
||||
@utils.web.token_auth
|
||||
async def index_page(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"title": {
|
||||
"required": True,
|
||||
},
|
||||
})
|
||||
|
||||
page_title = params.get('title')
|
||||
|
||||
mwapi = MediaWikiApi.create()
|
||||
db = await DatabaseService.create(request.app)
|
||||
# Detect whether this is a WebSocket request
|
||||
if utils.web.is_websocket(request):
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
try:
|
||||
transaction_id = None
|
||||
|
||||
async with EmbeddingSearchService(db, page_title) as embedding_search:
|
||||
if await embedding_search.should_update_page_index():
|
||||
if request.get("caller") == "user":
|
||||
user_id = request.get("user")
|
||||
transaction_id = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage")
|
||||
|
||||
await embedding_search.prepare_update_index()
|
||||
|
||||
async def on_progress(current, total):
|
||||
|
||||
await ws.send_json({
|
||||
'event': 'progress',
|
||||
'current': current,
|
||||
'total': total
|
||||
})
|
||||
|
||||
token_usage = await embedding_search.update_page_index(on_progress)
|
||||
await ws.send_json({
|
||||
'event': 'done',
|
||||
'status': 1,
|
||||
'index_updated': True
|
||||
})
|
||||
|
||||
if transaction_id:
    await mwapi.chat_complete_end_transaction(transaction_id, token_usage)
|
||||
else:
|
||||
await ws.send_json({
|
||||
'event': 'done',
|
||||
'status': 1,
|
||||
'index_updated': False
|
||||
})
|
||||
except MediaWikiPageNotFoundException:
|
||||
error_msg = "Page \"%s\" not found." % page_title
|
||||
await ws.send_json({
|
||||
'event': 'error',
|
||||
'status': -2,
|
||||
'message': error_msg,
|
||||
'error': {
|
||||
'code': 'page_not_found',
|
||||
'title': page_title,
|
||||
},
|
||||
})
|
||||
if transaction_id:
    await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)
|
||||
except MediaWikiApiException as e:
|
||||
error_msg = "MediaWiki API error: %s" % str(e)
|
||||
print(error_msg, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
await ws.send_json({
|
||||
'event': 'error',
|
||||
'status': -3,
|
||||
'message': error_msg,
|
||||
'error': {
|
||||
'code': e.code,
|
||||
'info': e.info,
|
||||
},
|
||||
})
|
||||
if transaction_id:
    await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(error_msg, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
await ws.send_json({
|
||||
'event': 'error',
|
||||
'status': -1,
|
||||
'message': error_msg
|
||||
})
|
||||
if transaction_id:
    await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)
|
||||
finally:
|
||||
await ws.close()
|
||||
else:
|
||||
try:
|
||||
transaction_id = None
|
||||
|
||||
async with EmbeddingSearchService(db, page_title) as embedding_search:
|
||||
if await embedding_search.should_update_page_index():
|
||||
if request.get("caller") == "user":
|
||||
user_id = request.get("user")
|
||||
transaction_id = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage")
|
||||
|
||||
await embedding_search.prepare_update_index()
|
||||
|
||||
token_usage = await embedding_search.update_page_index()
|
||||
|
||||
if transaction_id:
    result = await mwapi.chat_complete_end_transaction(transaction_id, token_usage)
|
||||
|
||||
return await utils.web.api_response(1, {"data_indexed": True})
|
||||
else:
|
||||
return await utils.web.api_response(1, {"data_indexed": False})
|
||||
except MediaWikiPageNotFoundException:
|
||||
error_msg = "Page \"%s\" not found." % page_title
|
||||
if transaction_id:
    await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)
|
||||
|
||||
return await utils.web.api_response(-2, error={
|
||||
"code": "page-not-found",
|
||||
"title": page_title,
|
||||
"message": error_msg
|
||||
}, http_status=404)
|
||||
except MediaWikiApiException as e:
|
||||
error_msg = "MediaWiki API error: %s" % e.info
|
||||
|
||||
print(error_msg, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
if transaction_id:
    await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)
|
||||
|
||||
return await utils.web.api_response(-3, error={
|
||||
"code": "mediawiki-api-error",
|
||||
"info": e.info,
|
||||
"message": error_msg
|
||||
}, http_status=500)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
print(error_msg, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
if transaction_id:
    await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)
|
||||
|
||||
return await utils.web.api_response(-1, error={
|
||||
"code": "internal-server-error",
|
||||
"message": error_msg
|
||||
}, http_status=500)
|
||||
|
||||
@staticmethod
|
||||
@utils.web.token_auth
|
||||
async def search(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"title": {
|
||||
"required": True,
|
||||
},
|
||||
"query": {
|
||||
"required": True
|
||||
},
|
||||
"limit": {
|
||||
"required": False,
|
||||
"type": int,
|
||||
"default": 5
|
||||
},
|
||||
"incollection": {
|
||||
"required": False,
|
||||
"type": bool,
|
||||
"default": False
|
||||
},
|
||||
"distancelimit": {
|
||||
"required": False,
|
||||
"type": float,
|
||||
"default": 0.6
|
||||
},
|
||||
})
|
||||
|
||||
page_title = params.get('title')
|
||||
query = params.get('query')
|
||||
limit = params.get('limit')
|
||||
in_collection = params.get('incollection')
|
||||
distance_limit = params.get('distancelimit')
|
||||
|
||||
limit = min(limit, 10)
|
||||
|
||||
db = await DatabaseService.create(request.app)
|
||||
|
||||
try:
|
||||
async with EmbeddingSearchService(db, page_title) as embedding_search:
|
||||
results, token_usage = await embedding_search.search(query, limit, in_collection, distance_limit)
|
||||
except MediaWikiPageNotFoundException:
|
||||
error_msg = "Page \"%s\" not found." % page_title
|
||||
|
||||
return await utils.web.api_response(-2, error={
|
||||
"code": "page-not-found",
|
||||
"title": page_title,
|
||||
"message": error_msg
|
||||
}, request=request, http_status=404)
|
||||
except MediaWikiApiException as e:
|
||||
error_msg = "MediaWiki API error: %s" % e.info
|
||||
|
||||
print(error_msg, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
return await utils.web.api_response(-3, error={
|
||||
"code": "mediawiki-api-error",
|
||||
"info": e.info,
|
||||
"message": error_msg
|
||||
}, request=request, http_status=500)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
print(error_msg, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
return await utils.web.api_response(-1, error={
|
||||
"code": "internal-server-error",
|
||||
"message": error_msg
|
||||
}, request=request, http_status=500)
|
||||
return await utils.web.api_response(1, data={"results": results, "token_usage": token_usage}, request=request)
|
@@ -0,0 +1,36 @@
|
||||
from aiohttp import web
|
||||
import utils.web
|
||||
import utils.text
|
||||
from extend.hangul_romanize import Transliter
|
||||
from extend.hangul_romanize.rule import academic
|
||||
|
||||
class Hanja:
|
||||
@staticmethod
|
||||
def convertToRomaja(hanja: str):
|
||||
transliter = Transliter(academic)
|
||||
segList = utils.text.splitAscii(hanja)
|
||||
sentenceList = []
|
||||
for seg in segList:
|
||||
if seg == " ":
|
||||
sentenceList.append("-")
|
||||
elif utils.text.isAscii(seg):
|
||||
if utils.text.isAsciiPunc(seg):
|
||||
sentenceList.append(seg)
|
||||
else:
|
||||
sentenceList.append([seg])
|
||||
else:
|
||||
roma = transliter.translit(seg)
|
||||
sentenceList.append(roma.split(" "))
|
||||
return sentenceList
|
||||
|
||||
@staticmethod
|
||||
async def hanja2roma(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"sentence": {
|
||||
"required": True,
|
||||
},
|
||||
})
|
||||
sentence = params.get('sentence')
|
||||
|
||||
data = Hanja.convertToRomaja(sentence)
|
||||
return await utils.web.api_response(1, data, request=request)
|
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import web
|
||||
import os.path as path
|
||||
import jieba
|
||||
import jieba.posseg as pseg
|
||||
from pypinyin import pinyin, Style
|
||||
import utils.text
|
||||
import utils.web
|
||||
|
||||
jieba.initialize()
|
||||
userDict = path.dirname(path.dirname(path.dirname(__file__))) + "/data/userDict.txt"
|
||||
if path.exists(userDict):
|
||||
jieba.load_userdict(userDict)
|
||||
|
||||
|
||||
class Hanzi:
|
||||
@staticmethod
|
||||
def filterJiebaTag(segList: list[str]):
|
||||
ret = []
|
||||
for word, flag in segList:
|
||||
if flag[0] == "u" and (word == "得" or word == "地"):
|
||||
ret.append("的")
|
||||
else:
|
||||
ret.append(word)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def convertToPinyin(sentence: str):
|
||||
sentence = utils.text.replaceCJKPunc(sentence).replace(' ', '-')
|
||||
segList = Hanzi.filterJiebaTag(pseg.cut(sentence))
|
||||
sentenceList = []
|
||||
pinyinGroup = []
|
||||
for seg in segList:
|
||||
if utils.text.isAscii(seg):
|
||||
if utils.text.isAsciiPunc(seg):
|
||||
if len(pinyinGroup) > 0:
|
||||
sentenceList.append(pinyinGroup)
|
||||
pinyinGroup = []
|
||||
sentenceList.append(seg)
|
||||
else:
|
||||
if len(pinyinGroup) > 0:
|
||||
sentenceList.append(pinyinGroup)
|
||||
pinyinGroup = []
|
||||
sentenceList.append([seg])
|
||||
else:
|
||||
sentencePinyin = []
|
||||
for one in pinyin(seg, style=Style.NORMAL):
|
||||
sentencePinyin.append(one[0])
|
||||
pinyinGroup.append(sentencePinyin)
|
||||
if len(pinyinGroup) > 0:
|
||||
sentenceList.append(pinyinGroup)
|
||||
|
||||
return sentenceList
|
||||
|
||||
@staticmethod
|
||||
async def hanziToPinyin(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"sentence": {
|
||||
"required": True,
|
||||
},
|
||||
})
|
||||
sentence = params.get('sentence')
|
||||
|
||||
data = Hanzi.convertToPinyin(sentence)
|
||||
return await utils.web.api_response(1, data, request=request)
|
||||
|
||||
@staticmethod
|
||||
async def splitHanzi(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"sentence": {
|
||||
"required": True,
|
||||
},
|
||||
})
|
||||
sentence = params.get("sentence")
|
||||
|
||||
segList = list(pseg.cut(sentence))
|
||||
data = []
|
||||
for word, flag in segList:
|
||||
data.append({"word": word, "flag": flag})
|
||||
return await utils.web.api_response(1, data)
|
@@ -0,0 +1,302 @@
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from aiohttp import web
|
||||
from sqlalchemy import select
|
||||
from api.model.toolkit_ui.conversation import ConversationHelper
|
||||
from api.model.toolkit_ui.page_title import PageTitleHelper
|
||||
from service.database import DatabaseService
|
||||
from service.event import EventService
|
||||
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
|
||||
import utils.web
|
||||
|
||||
|
||||
class Index:
|
||||
@staticmethod
|
||||
@utils.web.token_auth
|
||||
async def update_title_info(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"title": {
|
||||
"required": True,
|
||||
}
|
||||
})
|
||||
|
||||
title = params.get("title")
|
||||
|
||||
mwapi = MediaWikiApi.create()
|
||||
db = await DatabaseService.create(request.app)
|
||||
async with PageTitleHelper(db) as page_title_helper:
|
||||
title_info = await page_title_helper.find_by_title(title)
|
||||
|
||||
if title_info is not None and time.time() - title_info.updated_at < 60:
|
||||
return await utils.web.api_response(1, {
|
||||
"cached": True,
|
||||
"title": title_info.title,
|
||||
"page_id": title_info.page_id,
|
||||
}, request=request)
|
||||
|
||||
# Load page info from MediaWiki API
|
||||
try:
|
||||
page_info = await mwapi.get_page_info(title)
|
||||
page_id = page_info.get("pageid")
|
||||
real_title = page_info.get("title")
|
||||
|
||||
if title_info is None:
|
||||
title_info = await page_title_helper.add(page_id, real_title)
|
||||
else:
|
||||
title_info.page_id = page_id
|
||||
title_info.title = real_title
|
||||
await page_title_helper.update(title_info)
|
||||
|
||||
return await utils.web.api_response(1, {
|
||||
"cached": False,
|
||||
"title": real_title,
|
||||
"page_id": page_id
|
||||
}, request=request)
|
||||
except MediaWikiPageNotFoundException:
|
||||
error_msg = "Page \"%s\" not found." % title
|
||||
|
||||
return await utils.web.api_response(-2, error={
|
||||
"code": "page-not-found",
|
||||
"message": error_msg
|
||||
}, request=request, http_status=404)
|
||||
except MediaWikiApiException as e:
|
||||
error_msg = "MediaWiki API error: %s" % e.info
|
||||
|
||||
print(error_msg, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
return await utils.web.api_response(-3, error={
|
||||
"code": "mediawiki-api-error",
|
||||
"message": error_msg
|
||||
}, request=request, http_status=500)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
print(error_msg, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
return await utils.web.api_response(-1, error={
|
||||
"code": "internal-server-error",
|
||||
"message": error_msg
|
||||
}, request=request, http_status=500)
|
||||
|
||||
@staticmethod
|
||||
@utils.web.token_auth
|
||||
async def get_conversation_list(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"user_id": {
|
||||
"required": False,
|
||||
"type": int
|
||||
},
|
||||
"title": {
|
||||
"required": True,
|
||||
},
|
||||
"module": {
|
||||
"required": False
|
||||
}
|
||||
})
|
||||
|
||||
if request.get("caller") == "user":
|
||||
user_id = request.get("user")
|
||||
else:
|
||||
user_id = params.get("user_id")
|
||||
|
||||
page_title = params.get("title")
|
||||
module = params.get("module")
|
||||
|
||||
db = await DatabaseService.create(request.app)
|
||||
async with PageTitleHelper(db) as page_title_helper, ConversationHelper(db) as conversation_helper:
|
||||
page_id = await page_title_helper.get_page_id_by_title(page_title)
|
||||
if page_id is None:
|
||||
return await utils.web.api_response(-2, error={
|
||||
"code": "page-not-found",
|
||||
"message": "Page not found.",
|
||||
}, request=request, http_status=404)
|
||||
|
||||
conversation_list = await conversation_helper.get_conversation_list(user_id, module=module, page_id=page_id)
|
||||
|
||||
conversation_result = []
|
||||
|
||||
for result in conversation_list:
|
||||
conversation_result.append({
|
||||
"id": result.id,
|
||||
"module": result.module,
|
||||
"title": result.title,
|
||||
"thumbnail": result.thumbnail,
|
||||
"rev_id": result.rev_id,
|
||||
"updated_at": result.updated_at,
|
||||
"pinned": result.pinned,
|
||||
"extra": result.extra,
|
||||
})
|
||||
|
||||
return await utils.web.api_response(1, {
|
||||
"conversations": conversation_result
|
||||
}, request=request)
|
||||
|
||||
@staticmethod
|
||||
@utils.web.token_auth
|
||||
async def get_conversation_info(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"id": {
|
||||
"required": True,
|
||||
"type": int
|
||||
}
|
||||
})
|
||||
|
||||
conversation_id = params.get("id")
|
||||
|
||||
db = await DatabaseService.create(request.app)
|
||||
async with ConversationHelper(db) as conversation_helper:
|
||||
conversation_info = await conversation_helper.find_by_id(conversation_id)
|
||||
|
||||
if conversation_info is None:
|
||||
return await utils.web.api_response(-2, error={
|
||||
"code": "conversation-not-found",
|
||||
"message": "Conversation not found.",
|
||||
}, request=request, http_status=404)
|
||||
|
||||
if request.get("caller") == "user" and int(request.get("user")) != conversation_info.user_id:
|
||||
return await utils.web.api_response(-3, error={
|
||||
"code": "permission-denied",
|
||||
"message": "Permission denied."
|
||||
}, request=request, http_status=403)
|
||||
|
||||
conversation_result = {
|
||||
"id": conversation_info.id,
|
||||
"module": conversation_info.module,
|
||||
"title": conversation_info.title,
|
||||
"thumbnail": conversation_info.thumbnail,
|
||||
"rev_id": conversation_info.rev_id,
|
||||
"updated_at": conversation_info.updated_at,
|
||||
"pinned": conversation_info.pinned,
|
||||
"extra": conversation_info.extra,
|
||||
}
|
||||
|
||||
return await utils.web.api_response(1, conversation_result, request=request)
|
||||
|
||||
@staticmethod
|
||||
@utils.web.token_auth
|
||||
async def remove_conversation(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"id": {
|
||||
"required": True,
|
||||
"type": int
|
||||
}
|
||||
})
|
||||
|
||||
conversation_id = params.get("id")
|
||||
|
||||
db = await DatabaseService.create(request.app)
|
||||
async with ConversationHelper(db) as conversation_helper:
|
||||
conversation_info = await conversation_helper.find_by_id(conversation_id)
|
||||
|
||||
if conversation_info is None:
|
||||
return await utils.web.api_response(-2, error={
|
||||
"code": "conversation-not-found",
|
||||
"message": "Conversation not found."
|
||||
}, request=request, http_status=404)
|
||||
|
||||
if request.get("caller") == "user" and int(request.get("user")) != conversation_info.user_id:
|
||||
return await utils.web.api_response(-3, error={
|
||||
"code": "permission-denied",
|
||||
"message": "Permission denied."
|
||||
}, request=request, http_status=403)
|
||||
|
||||
await conversation_helper.remove(conversation_info)
|
||||
|
||||
# Notify other modules that the conversation was removed
|
||||
events = EventService.create()
|
||||
events.emit("conversation/removed", {
|
||||
"conversation": conversation_info,
|
||||
"dbs": db,
|
||||
"app": request.app,
|
||||
})
|
||||
events.emit("conversation/removed/" + conversation_info.module, {
|
||||
"conversation": conversation_info,
|
||||
"dbs": db,
|
||||
"app": request.app,
|
||||
})
|
||||
|
||||
return await utils.web.api_response(1, request=request)
|
||||
|
||||
@staticmethod
|
||||
@utils.web.token_auth
|
||||
async def set_conversation_pinned(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"id": {
|
||||
"required": True,
|
||||
"type": int
|
||||
},
|
||||
"pinned": {
|
||||
"required": True,
|
||||
"type": bool
|
||||
}
|
||||
})
|
||||
|
||||
conversation_id = params.get("id")
|
||||
pinned = params.get("pinned")
|
||||
|
||||
db = await DatabaseService.create(request.app)
|
||||
async with ConversationHelper(db) as conversation_helper:
|
||||
conversation_info = await conversation_helper.find_by_id(conversation_id)
|
||||
|
||||
if conversation_info is None:
|
||||
return await utils.web.api_response(-2, error={
|
||||
"code": "conversation-not-found",
|
||||
"message": "Conversation not found."
|
||||
}, request=request, http_status=404)
|
||||
|
||||
if request.get("caller") == "user" and int(request.get("user")) != conversation_info.user_id:
|
||||
return await utils.web.api_response(-3, error={
|
||||
"code": "permission-denied",
|
||||
"message": "Permission denied."
|
||||
}, request=request, http_status=403)
|
||||
|
||||
conversation_info.pinned = pinned
|
||||
await conversation_helper.update(conversation_info)
|
||||
|
||||
return await utils.web.api_response(1, request=request)
|
||||
|
||||
@staticmethod
|
||||
@utils.web.token_auth
|
||||
async def get_user_info(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"user_id": {
|
||||
"required": False,
|
||||
"type": int
|
||||
}
|
||||
})
|
||||
|
||||
if request.get("caller") == "user":
|
||||
user_id = request.get("user")
|
||||
else:
|
||||
user_id = params.get("user_id")
|
||||
|
||||
mwapi = MediaWikiApi.create()
|
||||
|
||||
try:
|
||||
user_info = await mwapi.chat_complete_user_info(user_id)
|
||||
return await utils.web.api_response(1, user_info, request=request)
|
||||
except MediaWikiPageNotFoundException as e:
|
||||
return await utils.web.api_response(-2, error={
|
||||
"code": "user-not-found",
|
||||
"message": "User not found."
|
||||
}, request=request, http_status=404)
|
||||
except MediaWikiApiException as e:
|
||||
err_str = "MediaWiki API error: %s" % e.info
|
||||
print(err_str, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
return await utils.web.api_response(-3, error={
|
||||
"code": "mediawiki-api-error",
|
||||
"info": e.info,
|
||||
"message": err_str
|
||||
}, request=request, http_status=500)
|
||||
except Exception as e:
|
||||
err_str = str(e)
|
||||
print(err_str, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
return await utils.web.api_response(-1, error={
|
||||
"code": "internal-server-error",
|
||||
"message": err_str
|
||||
}, request=request, http_status=500)
|
@@ -0,0 +1,32 @@
|
||||
from aiohttp import web
|
||||
import utils.web
|
||||
import utils.text
|
||||
from extend.kanji_to_romaji import kanji_to_romaji
|
||||
|
||||
class Kanji:
|
||||
@staticmethod
|
||||
def convertToRomaji(kanji: str):
|
||||
segList = utils.text.splitAscii(kanji)
|
||||
sentenceList = []
|
||||
for seg in segList:
|
||||
if utils.text.isAscii(seg):
|
||||
if utils.text.isAsciiPunc(seg):
|
||||
sentenceList.append(seg)
|
||||
else:
|
||||
sentenceList.append([seg])
|
||||
else:
|
||||
romaji = kanji_to_romaji(seg)
|
||||
sentenceList.append(romaji.split(" "))
|
||||
return sentenceList
|
||||
|
||||
@staticmethod
|
||||
async def kanji2romaji(request: web.Request):
|
||||
params = await utils.web.get_param(request, {
|
||||
"sentence": {
|
||||
"required": True,
|
||||
},
|
||||
})
|
||||
sentence = params.get('sentence')
|
||||
|
||||
data = Kanji.convertToRomaji(sentence)
|
||||
return await utils.web.api_response(1, data, request=request)
|
@@ -0,0 +1,4 @@
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
class BaseModel(DeclarativeBase):
|
||||
pass
|
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import mapped_column, relationship, Mapped
|
||||
|
||||
from api.model.base import BaseModel
|
||||
from api.model.toolkit_ui.conversation import ConversationModel
|
||||
from service.database import DatabaseService
|
||||
from service.event import EventService
|
||||
|
||||
class ConversationChunkModel(BaseModel):
|
||||
__tablename__ = "chat_complete_conversation_chunk"
|
||||
|
||||
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
conversation_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("chat_complete_conversation.id"), index=True)
|
||||
message_data: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True)
|
||||
tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0)
|
||||
updated_at: Mapped[int] = mapped_column(sqlalchemy.TIMESTAMP, index=True)
|
||||
|
||||
class ConversationChunkHelper:
|
||||
def __init__(self, dbs: DatabaseService):
|
||||
self.dbs = dbs
|
||||
self.initialized = False
|
||||
|
||||
async def __aenter__(self):
|
||||
if not self.initialized:
|
||||
self.create_session = self.dbs.create_session
|
||||
self.session = self.dbs.create_session()
|
||||
await self.session.__aenter__()
|
||||
self.initialized = True
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.session.__aexit__(exc_type, exc, tb)
|
||||
pass
|
||||
|
||||
async def add(self, conversation_id: int, message_data: list, tokens: int):
|
||||
async with self.create_session() as session:
|
||||
chunk = ConversationChunkModel(
|
||||
conversation_id=conversation_id,
|
||||
message_data=message_data,
|
||||
tokens=tokens,
|
||||
updated_at=sqlalchemy.func.current_timestamp()
|
||||
)
|
||||
session.add(chunk)
|
||||
await session.commit()
|
||||
await session.refresh(chunk)
|
||||
return chunk
|
||||
|
||||
async def update(self, chunk: ConversationChunkModel):
|
||||
chunk.updated_at = sqlalchemy.func.current_timestamp()
|
||||
chunk = await self.session.merge(chunk)
|
||||
await self.session.commit()
|
||||
return chunk
|
||||
|
||||
async def update_message_log(self, chunk_id: int, message_data: list, tokens: int):
|
||||
stmt = update(ConversationChunkModel).where(ConversationChunkModel.id == chunk_id) \
|
||||
.values(message_data=message_data, tokens=tokens, updated_at=sqlalchemy.func.current_timestamp())
|
||||
await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
|
||||
async def get_newest_chunk(self, conversation_id: int):
|
||||
stmt = sqlalchemy.select(ConversationChunkModel) \
|
||||
.where(ConversationChunkModel.conversation_id == conversation_id) \
|
||||
.order_by(ConversationChunkModel.id.desc()) \
|
||||
.limit(1)
|
||||
return await self.session.scalar(stmt)
|
||||
|
||||
async def remove(self, id: int):
|
||||
stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id == id)
|
||||
await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
|
||||
async def remove_by_conversation_id(self, conversation_id: int):
|
||||
stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.conversation_id == conversation_id)
|
||||
await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
|
||||
async def on_conversation_removed(event):
|
||||
if "conversation" in event:
|
||||
conversation_info = event["conversation"]
|
||||
conversation_id = conversation_info["id"]
|
||||
await ConversationChunkHelper(event["dbs"]).remove_by_conversation_id(conversation_id)
|
||||
|
||||
EventService.create().add_listener("conversation/removed/chatcomplete", on_conversation_removed)
|
@@ -0,0 +1,292 @@
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
|
||||
import asyncpg
|
||||
from api.model.base import BaseModel
|
||||
import config
|
||||
import numpy as np
|
||||
import sqlalchemy
|
||||
from sqlalchemy import select, update, delete
|
||||
from sqlalchemy.orm import mapped_column, Mapped
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from pgvector.asyncpg import register_vector
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
from service.database import DatabaseService
|
||||
|
||||
class PageIndexModel(BaseModel):
|
||||
__abstract__ = True
|
||||
|
||||
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
|
||||
sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
|
||||
text: Mapped[str] = mapped_column(sqlalchemy.Text)
|
||||
text_len: Mapped[int] = mapped_column(sqlalchemy.Integer)
|
||||
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
|
||||
markdown: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=True)
|
||||
markdown_len: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
|
||||
temp_doc_session_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
|
||||
|
||||
class PageIndexHelper:
|
||||
columns = [
|
||||
"id",
|
||||
"page_id"
|
||||
"sha1",
|
||||
"text",
|
||||
"text_len",
|
||||
"markdown",
|
||||
"markdown_len",
|
||||
"embedding",
|
||||
"temp_doc_session_id"
|
||||
]
|
||||
|
||||
def __init__(self, dbs: DatabaseService, collection_id: int, page_id: Optional[int]):
|
||||
self.dbs = dbs
|
||||
self.collection_id = collection_id
|
||||
self.page_id = page_id if page_id is not None else -1
|
||||
self.table_name = "embedding_search_page_index_%s" % str(collection_id)
|
||||
self.initialized = False
|
||||
self.table_initialized = False
|
||||
|
||||
"""
|
||||
Initialize table
|
||||
"""
|
||||
async def __aenter__(self):
|
||||
if self.initialized:
|
||||
return
|
||||
|
||||
self.dbpool = self.dbs.pool.acquire()
|
||||
self.dbi = await self.dbpool.__aenter__()
|
||||
|
||||
await register_vector(self.dbi)
|
||||
|
||||
self.initialized = True
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.dbpool.__aexit__(exc_type, exc, tb)
|
||||
|
||||
async def table_exists(self):
|
||||
exists = await self.dbi.fetchval("""SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = $1
|
||||
);""", self.table_name, column=0)
|
||||
|
||||
return bool(exists)
|
||||
|
||||
async def init_table(self):
|
||||
if self.table_initialized:
|
||||
return
|
||||
|
||||
# create table if not exists
|
||||
if not await self.table_exists():
|
||||
await self.dbi.execute(("""CREATE TABLE IF NOT EXISTS /*_*/ (
|
||||
id SERIAL PRIMARY KEY,
|
||||
page_id INTEGER NOT NULL,
|
||||
sha1 VARCHAR(40) NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
text_len INTEGER NOT NULL,
|
||||
embedding VECTOR(%d) NOT NULL,
|
||||
markdown TEXT NULL,
|
||||
markdown_len INTEGER NULL,
|
||||
temp_doc_session_id INTEGER NULL
|
||||
);
|
||||
CREATE INDEX /*_*/_page_id_idx ON /*_*/ (page_id);
|
||||
CREATE INDEX /*_*/_sha1_idx ON /*_*/ (sha1);
|
||||
CREATE INDEX /*_*/_temp_doc_session_id_idx ON /*_*/ (temp_doc_session_id);
|
||||
""" % config.EMBEDDING_VECTOR_SIZE).replace("/*_*/", self.table_name))
|
||||
|
||||
self.table_initialized = True
|
||||
|
||||
async def create_embedding_index(self):
|
||||
await self.dbi.execute("CREATE INDEX IF NOT EXISTS /*_*/_embedding_idx ON /*_*/ USING ivfflat (embedding vector_cosine_ops);"
|
||||
.replace("/*_*/", self.table_name))
|
||||
|
||||
def sha1_doc(self, doc: list):
|
||||
for item in doc:
|
||||
if "sha1" not in item or not item["sha1"]:
|
||||
sha1 = hashlib.sha1(item["text"].encode("utf-8")).hexdigest()
|
||||
item["sha1"] = sha1
|
||||
|
||||
async def get_indexed_sha1(self, with_temporary: bool = True, in_collection: bool = False):
|
||||
indexed_sha1_list = []
|
||||
sql = "SELECT sha1 FROM %s" % (self.table_name)
|
||||
|
||||
where = []
|
||||
params = []
|
||||
|
||||
if not with_temporary:
|
||||
where.append("temp_doc_session_id IS NULL")
|
||||
|
||||
if not in_collection:
|
||||
params.append(self.page_id)
|
||||
where.append("page_id = $%d" % len(params))
|
||||
|
||||
if len(where) > 0:
|
||||
sql += " WHERE " + (" AND ".join(where))
|
||||
|
||||
ret = await self.dbi.fetch(sql, *params)
|
||||
|
||||
for row in ret:
|
||||
indexed_sha1_list.append(row[0])
|
||||
return indexed_sha1_list
|
||||
|
||||
async def get_unindexed_doc(self, doc: list, with_temporary: bool = True):
|
||||
indexed_sha1_list = await self.get_indexed_sha1(with_temporary)
|
||||
self.sha1_doc(doc)
|
||||
|
||||
should_index = []
|
||||
for item in doc:
|
||||
if item["sha1"] not in indexed_sha1_list:
|
||||
should_index.append(item)
|
||||
|
||||
return should_index
|
||||
|
||||
async def remove_outdated_doc(self, doc: list):
|
||||
await self.clear_temp()
|
||||
|
||||
indexed_sha1_list = await self.get_indexed_sha1(False)
|
||||
self.sha1_doc(doc)
|
||||
|
||||
doc_sha1_list = [item["sha1"] for item in doc]
|
||||
|
||||
should_remove = []
|
||||
for sha1 in indexed_sha1_list:
|
||||
if sha1 not in doc_sha1_list:
|
||||
should_remove.append(sha1)
|
||||
|
||||
if len(should_remove) > 0:
|
||||
await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)" % (self.table_name),
|
||||
self.page_id, should_remove)
|
||||
|
||||
async def index_doc(self, doc: list):
|
||||
need_create_index = False
|
||||
|
||||
indexed_persist_sha1_list = []
|
||||
indexed_temp_sha1_list = []
|
||||
|
||||
ret = await self.dbi.fetch("SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1" % (self.table_name),
|
||||
self.page_id)
|
||||
for row in ret:
|
||||
if row[1]:
|
||||
indexed_temp_sha1_list.append(row[0])
|
||||
else:
|
||||
indexed_persist_sha1_list.append(row[0])
|
||||
|
||||
# Create index when no indexed document
|
||||
if len(indexed_persist_sha1_list) == 0:
|
||||
need_create_index = True
|
||||
|
||||
self.sha1_doc(doc)
|
||||
|
||||
doc_sha1_list = []
|
||||
|
||||
should_index = []
|
||||
should_persist = []
|
||||
should_remove = []
|
||||
for item in doc:
|
||||
doc_sha1_list.append(item["sha1"])
|
||||
|
||||
if item["sha1"] in indexed_temp_sha1_list:
|
||||
should_persist.append(item["sha1"])
|
||||
elif item["sha1"] not in indexed_persist_sha1_list:
|
||||
should_index.append(item)
|
||||
|
||||
for sha1 in indexed_persist_sha1_list:
|
||||
if sha1 not in doc_sha1_list:
|
||||
should_remove.append(sha1)
|
||||
|
||||
if len(should_index) > 0:
|
||||
await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);""" % (self.table_name),
|
||||
[(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"]) for item in should_index])
|
||||
|
||||
if len(should_persist) > 0:
|
||||
await self.dbi.executemany("UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2" % (self.table_name),
|
||||
[(self.page_id, sha1) for sha1 in should_persist])
|
||||
|
||||
if need_create_index:
|
||||
await self.create_embedding_index()
|
||||
|
||||
"""
|
||||
Add temporary document to the index
|
||||
"""
|
||||
async def index_temp_doc(self, doc: list, temp_doc_session_id: int):
|
||||
indexed_sha1_list = []
|
||||
indexed_temp_sha1_list = []
|
||||
doc_sha1_list = []
|
||||
|
||||
sql = "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1 AND (temp_doc_session_id = $2 OR temp_doc_session_id IS NULL)" % (
|
||||
self.table_name)
|
||||
ret = await self.dbi.fetch(sql, self.page_id, temp_doc_session_id)
|
||||
for row in ret:
|
||||
indexed_sha1_list.append(row[0])
|
||||
if row[1]:
|
||||
indexed_temp_sha1_list.append(row[0])
|
||||
|
||||
self.sha1_doc(doc)
|
||||
|
||||
should_index = []
|
||||
should_remove = []
|
||||
|
||||
for item in doc:
|
||||
sha1 = item["sha1"]
|
||||
doc_sha1_list.append(sha1)
|
||||
if sha1 not in indexed_sha1_list:
|
||||
should_index.append(item)
|
||||
|
||||
for sha1 in indexed_temp_sha1_list:
|
||||
if sha1 not in doc_sha1_list:
|
||||
should_remove.append(sha1)
|
||||
|
||||
if len(should_index) > 0:
|
||||
await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8);""" % (self.table_name),
|
||||
[(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"], temp_doc_session_id) for item in should_index])
|
||||
|
||||
if len(should_remove) > 0:
|
||||
await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND temp_doc_session_id = $2 AND sha1 = ANY($3)" % (self.table_name),
|
||||
self.page_id, temp_doc_session_id, should_remove)
|
||||
|
||||
"""
|
||||
Search for text by cosine similarity
|
||||
"""
|
||||
async def search_text_embedding(self, embedding: np.ndarray, in_collection: bool = False, limit: int = 10):
|
||||
if in_collection:
|
||||
return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
|
||||
FROM %s
|
||||
ORDER BY distance ASC
|
||||
LIMIT %d""" % (self.table_name, limit), embedding)
|
||||
else:
|
||||
return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
|
||||
FROM %s
|
||||
WHERE page_id = $2
|
||||
ORDER BY distance ASC
|
||||
LIMIT %d""" % (self.table_name, limit), embedding, self.page_id)
|
||||
|
||||
"""
|
||||
Clear temporary index
|
||||
"""
|
||||
async def clear_temp(self, in_collection: bool = False, temp_doc_session_id: int = None):
|
||||
sql = "DELETE FROM %s" % (self.table_name)
|
||||
|
||||
where = []
|
||||
params = []
|
||||
|
||||
if not in_collection:
|
||||
params.append(self.page_id)
|
||||
where.append("page_id = $%d" % len(params))
|
||||
|
||||
if temp_doc_session_id:
|
||||
params.append(temp_doc_session_id)
|
||||
where.append("temp_doc_session_id = $%d" % len(params))
|
||||
else:
|
||||
where.append("temp_doc_session_id IS NOT NULL")
|
||||
|
||||
if len(where) > 0:
|
||||
sql += " WHERE " + (" AND ".join(where))
|
||||
|
||||
await self.dbi.execute(sql, *params)
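
A rough usage sketch for the helper above, assuming a DatabaseService instance and some embedding function that returns vectors of config.EMBEDDING_VECTOR_SIZE dimensions (the embedding provider is not part of this file and is purely illustrative):

import numpy as np

from service.database import DatabaseService


async def index_and_search(dbs: DatabaseService, collection_id: int, page_id: int, embed):
    # `embed` is an assumed callable: str -> np.ndarray of EMBEDDING_VECTOR_SIZE floats.
    doc = [
        {"text": "First paragraph.", "markdown": "First paragraph.",
         "embedding": embed("First paragraph.")},
        {"text": "Second paragraph.", "markdown": "Second paragraph.",
         "embedding": embed("Second paragraph.")},
    ]

    async with PageIndexHelper(dbs, collection_id, page_id) as index:
        await index.init_table()
        await index.remove_outdated_doc(doc)   # drop persisted chunks no longer in the document
        await index.index_doc(doc)             # insert new chunks, deduplicated by sha1

        query_vector: np.ndarray = embed("a question about the page")
        # Nearest chunks first; in_collection=False restricts the search to this page_id.
        return await index.search_text_embedding(query_vector, in_collection=False, limit=5)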
|
@@ -0,0 +1,63 @@
|
||||
from typing import Optional, Union
|
||||
import sqlalchemy
|
||||
from sqlalchemy import select, update, delete
|
||||
from sqlalchemy.orm import mapped_column, Mapped
|
||||
|
||||
from api.model.base import BaseModel
|
||||
from service.database import DatabaseService
|
||||
|
||||
class TitleCollectionModel(BaseModel):
|
||||
__tablename__ = "embedding_search_title_collection"
|
||||
|
||||
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
|
||||
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
|
||||
|
||||
class TitleCollectionHelper:
|
||||
def __init__(self, dbs: DatabaseService):
|
||||
self.dbs = dbs
|
||||
self.initialized = False
|
||||
|
||||
async def __aenter__(self):
|
||||
if not self.initialized:
|
||||
self.create_session = self.dbs.create_session
|
||||
self.session = self.dbs.create_session()
|
||||
await self.session.__aenter__()
|
||||
self.initialized = True
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.session.__aexit__(exc_type, exc, tb)
|
||||
pass
|
||||
|
||||
async def add(self, title: str, page_id: Optional[int] = None) -> Union[int, bool]:
|
||||
stmt = select(TitleCollectionModel.id).where(TitleCollectionModel.title == title)
|
||||
result = await self.session.scalar(stmt)
|
||||
|
||||
if result is None:
|
||||
obj = TitleCollectionModel(title=title, page_id=page_id)
|
||||
self.session.add(obj)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(obj)
|
||||
return obj.id
|
||||
|
||||
return False
|
||||
|
||||
async def set_page_id(self, title: str, page_id: Optional[str] = None):
|
||||
stmt = update(TitleCollectionModel).where(TitleCollectionModel.title == title).values(page_id=page_id)
|
||||
await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
|
||||
async def remove(self, title: str):
|
||||
stmt = delete(TitleCollectionModel).where(TitleCollectionModel.title == title)
|
||||
await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
|
||||
async def find_by_title(self, title: str):
|
||||
stmt = select(TitleCollectionModel).where(TitleCollectionModel.title == title)
|
||||
return await self.session.scalar(stmt)
|
||||
|
||||
async def find_by_page_id(self, page_id: int):
|
||||
stmt = select(TitleCollectionModel).where(TitleCollectionModel.page_id == page_id)
|
||||
return await self.session.scalar(stmt)
|
@@ -0,0 +1,158 @@
|
||||
import hashlib
|
||||
import asyncpg
|
||||
import numpy as np
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from pgvector.asyncpg import register_vector
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import mapped_column, relationship, Mapped
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
import config
|
||||
from api.model.base import BaseModel
|
||||
from service.database import DatabaseService
|
||||
|
||||
class TitleIndexModel(BaseModel):
|
||||
__tablename__ = "embedding_search_title_index"
|
||||
|
||||
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
sha1: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
|
||||
title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
|
||||
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
|
||||
collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
|
||||
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
|
||||
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE), index=True)
|
||||
|
||||
class TitleIndexHelper:
|
||||
__tablename__ = "embedding_search_title_index"
|
||||
|
||||
columns = [
|
||||
"id",
|
||||
"sha1",
|
||||
"title",
|
||||
"page_id",
|
||||
"collection_id",
|
||||
"rev_id",
|
||||
"embedding",
|
||||
]
|
||||
|
||||
def __init__(self, dbs: DatabaseService):
|
||||
self.dbs = dbs
|
||||
self.initialized = False
|
||||
|
||||
async def __aenter__(self):
|
||||
if not self.initialized:
|
||||
self.dbpool = self.dbs.pool.acquire()
|
||||
self.dbi = await self.dbpool.__aenter__()
|
||||
|
||||
await register_vector(self.dbi)
|
||||
|
||||
self.initialized = True
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.dbpool.__aexit__(exc_type, exc, tb)
|
||||
|
||||
def get_columns(self, exclude=[]):
|
||||
if len(exclude) == 0:
|
||||
return ", ".join(self.columns)
|
||||
|
||||
return ", ".join([col for col in self.columns if col not in exclude])
|
||||
|
||||
"""
|
||||
Add a title to the index
|
||||
"""
|
||||
async def add(self, title: str, page_id: int, rev_id: int, collection_id: int, embedding: np.ndarray):
|
||||
title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
|
||||
ret = await self.dbi.fetchrow("SELECT * FROM embedding_search_title_index WHERE sha1 = $1", title_sha1)
|
||||
|
||||
if ret is None:
|
||||
new_id = await self.dbi.fetchval("""INSERT INTO embedding_search_title_index
|
||||
(sha1, title, page_id, rev_id, collection_id, embedding)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id""",
|
||||
title_sha1, title, page_id, rev_id, collection_id, embedding, column=0)
|
||||
return new_id
|
||||
|
||||
return False
|
||||
|
||||
"""
|
||||
Remove a title from the index
|
||||
"""
|
||||
async def remove(self, title: str):
|
||||
title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
|
||||
await self.dbi.execute("DELETE FROM embedding_search_title_index WHERE sha1 = $1", title_sha1)
|
||||
|
||||
"""
|
||||
Update the indexed revision id of a title
|
||||
"""
|
||||
async def update_rev_id(self, page_id: int, rev_id: int):
|
||||
await self.dbi.execute("UPDATE embedding_search_title_index SET rev_id = $1 WHERE page_id = $2", rev_id, page_id)
|
||||
|
||||
"""
|
||||
Update title data
|
||||
"""
|
||||
async def update_title_data(self, page_id: int, title: str, rev_id: int, collection_id: int, embedding: np.ndarray):
|
||||
if collection_id is None:
    collection_id = page_id
|
||||
|
||||
await self.dbi.execute("""UPDATE embedding_search_title_index
|
||||
SET title = $1, rev_id = $2, collection_id = $3, embedding = $4
|
||||
WHERE page_id = $5""",
|
||||
title, rev_id, collection_id, embedding, page_id)
|
||||
|
||||
"""
|
||||
Search for titles by cosine similarity
|
||||
"""
|
||||
async def search_title_embedding(self, embedding: np.ndarray, limit: int = 10):
|
||||
ret = await self.dbi.fetch("""SELECT %s, embedding <-> $1 AS distance
|
||||
FROM embedding_search_title_index
|
||||
ORDER BY distance ASC
|
||||
LIMIT %d""" % (self.get_columns(exclude=['embedding']), limit),
|
||||
embedding)
|
||||
|
||||
return ret
|
||||
|
||||
"""
|
||||
Find a title in the index
|
||||
"""
|
||||
async def find_by_title(self, title: str, with_embedding=False):
|
||||
title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
|
||||
|
||||
if with_embedding:
|
||||
columns = self.get_columns()
|
||||
else:
|
||||
columns = self.get_columns(exclude=["embedding"])
|
||||
|
||||
ret = await self.dbi.fetchrow(
|
||||
"SELECT %s FROM embedding_search_title_index WHERE sha1 = $1" % columns,
|
||||
title_sha1
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
async def find_by_page_id(self, page_id: int, with_embedding=False):
|
||||
if with_embedding:
|
||||
columns = self.get_columns()
|
||||
else:
|
||||
columns = self.get_columns(exclude=["embedding"])
|
||||
|
||||
ret = await self.dbi.fetchrow(
|
||||
"SELECT %s FROM embedding_search_title_index WHERE page_id = $1" % columns,
|
||||
page_id
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
async def find_by_collection_id(self, collection_id: int, with_embedding=False):
|
||||
if with_embedding:
|
||||
columns = self.get_columns()
|
||||
else:
|
||||
columns = self.get_columns(exclude=["embedding"])
|
||||
|
||||
ret = await self.dbi.fetch(
|
||||
"SELECT %s FROM embedding_search_title_index WHERE collection_id = $1" % columns,
|
||||
collection_id
|
||||
)
|
||||
|
||||
return ret
|
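# --- Illustrative usage sketch (not part of the original file) ---
# A minimal example of how TitleIndexHelper is presumably used. It assumes the
# module-level config / numpy imports this file already relies on, and the
# random vector only stands in for a real embedding computed elsewhere.
async def _title_index_demo(dbs: DatabaseService):
    fake_embedding = np.random.rand(config.EMBEDDING_VECTOR_SIZE).astype(np.float32)

    async with TitleIndexHelper(dbs) as title_index:
        # insert a title, then fetch the nearest titles by vector distance
        await title_index.add("Example page", page_id=1, rev_id=1,
                              collection_id=1, embedding=fake_embedding)
        rows = await title_index.search_title_embedding(fake_embedding, limit=5)
        return [row["title"] for row in rows]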
@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
import sqlalchemy
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import mapped_column, Mapped
|
||||
|
||||
from api.model.base import BaseModel
|
||||
from service.database import DatabaseService
|
||||
|
||||
|
||||
class ConversationModel(BaseModel):
|
||||
__tablename__ = "toolkit_ui_conversation"
|
||||
|
||||
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
|
||||
module: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
|
||||
title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True)
|
||||
thumbnail: Mapped[str] = mapped_column(sqlalchemy.Text(), nullable=True)
|
||||
page_id: Mapped[int] = mapped_column(
|
||||
sqlalchemy.Integer, index=True, nullable=True)
|
||||
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
|
||||
updated_at: Mapped[int] = mapped_column(
|
||||
sqlalchemy.TIMESTAMP, index=True, server_default=sqlalchemy.func.now())
|
||||
pinned: Mapped[bool] = mapped_column(
|
||||
sqlalchemy.Boolean, default=False, index=True)
|
||||
extra: Mapped[dict] = mapped_column(sqlalchemy.JSON, default={})
|
||||
|
||||
|
||||
class ConversationHelper:
|
||||
def __init__(self, dbs: DatabaseService):
|
||||
self.dbs = dbs
|
||||
self.initialized = False
|
||||
|
||||
async def __aenter__(self):
|
||||
if not self.initialized:
|
||||
self.create_session = self.dbs.create_session
|
||||
self.session = self.dbs.create_session()
|
||||
await self.session.__aenter__()
|
||||
|
||||
self.initialized = True
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.session.__aexit__(exc_type, exc, tb)
|
||||
pass
|
||||
|
||||
async def add(self, user_id: int, module: str, title: Optional[str] = None, page_id: Optional[int] = None, rev_id: Optional[int] = None, extra: Optional[dict] = None):
|
||||
obj = ConversationModel(user_id=user_id, module=module, title=title,
|
||||
page_id=page_id, rev_id=rev_id, updated_at=sqlalchemy.func.current_timestamp())
|
||||
|
||||
if extra is not None:
|
||||
obj.extra = extra
|
||||
|
||||
self.session.add(obj)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(obj)
|
||||
return obj
|
||||
|
||||
async def refresh_updated_at(self, conversation_id: int):
|
||||
stmt = update(ConversationModel).where(ConversationModel.id ==
|
||||
conversation_id).values(updated_at=sqlalchemy.func.current_timestamp())
|
||||
await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
|
||||
async def update(self, obj: ConversationModel):
|
||||
await self.session.merge(obj)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(obj)
|
||||
return obj
|
||||
|
||||
async def get_conversation_list(self, user_id: int, module: Optional[str] = None, page_id: Optional[int] = None) -> List[ConversationModel]:
|
||||
stmt = sqlalchemy.select(ConversationModel) \
|
||||
.where(ConversationModel.user_id == user_id)
|
||||
|
||||
if module is not None:
|
||||
stmt = stmt.where(ConversationModel.module == module)
|
||||
|
||||
if page_id is not None:
|
||||
stmt = stmt.where(ConversationModel.page_id == page_id)
|
||||
|
||||
stmt = stmt.order_by(ConversationModel.pinned.desc(),
|
||||
ConversationModel.updated_at.desc())
|
||||
|
||||
return await self.session.scalars(stmt)
|
||||
|
||||
async def find_by_id(self, conversation_id: int):
|
||||
async with self.create_session() as session:
|
||||
stmt = sqlalchemy.select(ConversationModel).where(
|
||||
ConversationModel.id == conversation_id)
|
||||
return await session.scalar(stmt)
|
||||
|
||||
async def remove(self, conversation_id: int):
|
||||
stmt = sqlalchemy.delete(ConversationModel).where(
|
||||
ConversationModel.id == conversation_id)
|
||||
await self.session.execute(stmt)
|
||||
await self.session.commit()
|
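# --- Illustrative usage sketch (not part of the original file) ---
# How ConversationHelper is presumably driven from a request handler; user_id=1
# and the "chatcomplete" module name are placeholder values.
async def _conversation_demo(dbs: DatabaseService):
    async with ConversationHelper(dbs) as helper:
        conversation = await helper.add(user_id=1, module="chatcomplete", title="Demo conversation")
        await helper.refresh_updated_at(conversation.id)

        conversations = await helper.get_conversation_list(user_id=1, module="chatcomplete")
        return [c.id for c in conversations]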
@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
import datetime
|
||||
|
||||
from typing import Optional
|
||||
import sqlalchemy
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import mapped_column, Mapped
|
||||
|
||||
from api.model.base import BaseModel
|
||||
from service.database import DatabaseService
|
||||
|
||||
|
||||
class PageTitleModel(BaseModel):
|
||||
__tablename__ = "toolkit_ui_page_title"
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
|
||||
title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True)
|
||||
updated_at: Mapped[int] = mapped_column(
|
||||
sqlalchemy.TIMESTAMP, index=True, server_default=sqlalchemy.func.now())
|
||||
|
||||
|
||||
class PageTitleHelper:
|
||||
def __init__(self, dbs: DatabaseService):
|
||||
self.dbs = dbs
|
||||
self.initialized = False
|
||||
|
||||
async def __aenter__(self):
|
||||
if not self.initialized:
|
||||
self.create_session = self.dbs.create_session
|
||||
self.session = self.dbs.create_session()
|
||||
await self.session.__aenter__()
|
||||
|
||||
self.initialized = True
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.session.__aexit__(exc_type, exc, tb)
|
||||
pass
|
||||
|
||||
async def find_by_page_id(self, page_id: int):
|
||||
stmt = select(PageTitleModel).where(PageTitleModel.page_id == page_id)
|
||||
return await self.session.scalar(stmt)
|
||||
|
||||
async def find_by_title(self, title: str):
|
||||
stmt = select(PageTitleModel).where(PageTitleModel.title == title)
|
||||
return await self.session.scalar(stmt)
|
||||
|
||||
async def get_page_id_by_title(self, title: str):
|
||||
obj = await self.find_by_title(title)
|
||||
if obj is None:
|
||||
return None
|
||||
return obj.page_id
|
||||
|
||||
async def should_update(self, title: str):
|
||||
title_info = await self.find_by_title(title)
|
||||
if title_info is None:
|
||||
return True
|
||||
if title_info.updated_at < (datetime.datetime.now() - datetime.timedelta(days=7)):
    return True
return False
|
||||
|
||||
async def add(self, page_id: int, title: Optional[str] = None):
|
||||
obj = PageTitleModel(page_id=page_id, title=title, updated_at=sqlalchemy.func.current_timestamp())
|
||||
|
||||
self.session.add(obj)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(obj)
|
||||
return obj
|
||||
|
||||
async def set_title(self, page_id: int, title: Optional[str] = None):
|
||||
stmt = update(PageTitleModel).where(
|
||||
PageTitleModel.page_id == page_id).values(title=title, updated_at=sqlalchemy.func.current_timestamp())
|
||||
await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
|
||||
async def update(self, obj: PageTitleModel, ignore_updated_at: bool = False):
|
||||
if not ignore_updated_at:
    obj.updated_at = sqlalchemy.func.current_timestamp()
obj = await self.session.merge(obj)
await self.session.commit()
await self.session.refresh(obj)
return obj
|
@ -0,0 +1,33 @@
|
||||
from aiohttp import web
|
||||
from api.controller.ChatComplete import ChatComplete
|
||||
|
||||
from api.controller.Hanzi import Hanzi
|
||||
from api.controller.Index import Index
|
||||
from api.controller.Kanji import Kanji
|
||||
from api.controller.Hanja import Hanja
|
||||
from api.controller.EmbeddingSearch import EmbeddingSearch
|
||||
|
||||
def init(app: web.Application):
|
||||
app.router.add_routes([
|
||||
web.route('*', '/hanzi/pinyin/', Hanzi.hanziToPinyin),
|
||||
web.route('*', '/hanzi/split/', Hanzi.splitHanzi),
|
||||
|
||||
web.route('*', '/kanji/romaji/', Kanji.kanji2romaji),
|
||||
|
||||
web.route('*', '/hanja/romaja/', Hanja.hanja2roma),
|
||||
|
||||
web.route('*', '/title/info', Index.update_title_info),
|
||||
web.route('*', '/user/info', Index.get_user_info),
|
||||
web.route('*', '/conversations', Index.get_conversation_list),
|
||||
web.route('*', '/conversation/info', Index.get_conversation_info),
|
||||
web.route('POST', '/conversation/remove', Index.remove_conversation),
|
||||
web.route('DELETE', '/conversation/remove', Index.remove_conversation),
|
||||
web.route('POST', '/conversation/set_pinned', Index.set_conversation_pinned),
|
||||
|
||||
web.route('*', '/embedding_search/index_page', EmbeddingSearch.index_page),
|
||||
web.route('*', '/embedding_search/search', EmbeddingSearch.search),
|
||||
|
||||
web.route('*', '/chatcomplete/conversation_chunks', ChatComplete.get_conversation_chunk_list),
|
||||
web.route('*', r'/chatcomplete/conversation_chunk/{id:\d+}', ChatComplete.get_conversation_chunk),
|
||||
web.route('*', '/chatcomplete/message', ChatComplete.chat_complete),
|
||||
])
|
@ -0,0 +1,10 @@
|
||||
异世界 100 n
|
||||
克苏鲁 20 n
|
||||
恐怖谷 20 n
|
||||
扶她 20 n
|
||||
汉山 20 n
|
||||
明美 20 n
|
||||
驱魔 20 n
|
||||
驱魔人 20 n
|
||||
轻小说 2000 n
|
||||
曦月 20 n
|
@ -0,0 +1,2 @@
|
||||
|
||||
from .core import Transliter # noqa
|
@ -0,0 +1,89 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
try:
|
||||
unicode(0)
|
||||
except NameError:
|
||||
# py3
|
||||
unicode = str
|
||||
unichr = chr
|
||||
|
||||
|
||||
class Syllable(object):
|
||||
"""Hangul syllable interface"""
|
||||
|
||||
MIN = ord('가')
|
||||
MAX = ord('힣')
|
||||
|
||||
def __init__(self, char=None, code=None):
|
||||
if char is None and code is None:
|
||||
raise TypeError('__init__ takes char or code as a keyword argument (not given)')
|
||||
if char is not None and code is not None:
|
||||
raise TypeError('__init__ takes char or code as a keyword argument (both given)')
|
||||
if char:
|
||||
code = ord(char)
|
||||
if not self.MIN <= code <= self.MAX:
|
||||
raise TypeError('__init__ expected Hangul syllable but {0} not in [{1}..{2}]'.format(code, self.MIN, self.MAX))
|
||||
self.code = code
|
||||
|
||||
@property
|
||||
def index(self):
|
||||
return self.code - self.MIN
|
||||
|
||||
@property
|
||||
def initial(self):
|
||||
return self.index // 588
|
||||
|
||||
@property
|
||||
def vowel(self):
|
||||
return (self.index // 28) % 21
|
||||
|
||||
@property
|
||||
def final(self):
|
||||
return self.index % 28
|
||||
|
||||
@property
|
||||
def char(self):
|
||||
return unichr(self.code)
|
||||
|
||||
def __unicode__(self):
|
||||
return self.char
|
||||
|
||||
def __repr__(self):
|
||||
return '''<Syllable({}({}),{}({}),{}({}),{}({}))>'''.format(
|
||||
self.code, self.char, self.initial, '', self.vowel, '', self.final, '')
|
||||
|
||||
|
||||
class Transliter(object):
|
||||
"""General transliting interface"""
|
||||
|
||||
def __init__(self, rule):
|
||||
self.rule = rule
|
||||
|
||||
def translit(self, text):
|
||||
"""Translit text to romanized text
|
||||
|
||||
:param text: Unicode string or unicode character iterator
|
||||
"""
|
||||
result = []
|
||||
pre = None, None
|
||||
now = None, None
|
||||
for c in text:
|
||||
try:
|
||||
post = c, Syllable(c)
|
||||
except TypeError:
|
||||
post = c, None
|
||||
|
||||
if now[0] is not None:
|
||||
out = self.rule(now, pre=pre, post=post)
|
||||
if out is not None:
|
||||
result.append(out)
|
||||
|
||||
pre = now
|
||||
now = post
|
||||
|
||||
if now is not None:
|
||||
out = self.rule(now, pre=pre, post=(None, None))
|
||||
if out is not None:
|
||||
result.append(out)
|
||||
|
||||
return ''.join(result)
|
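# --- Worked example (added for illustration) ---
# A precomposed Hangul syllable encodes its jamo as
#   index = initial * 588 + vowel * 28 + final
# which is exactly what the properties above unpack. For '한' (U+D55C):
#   index   = 0xD55C - 0xAC00    = 10588
#   initial = 10588 // 588       = 18   (ㅎ)
#   vowel   = (10588 // 28) % 21 = 0    (ㅏ)
#   final   = 10588 % 28         = 4    (ㄴ)
def _syllable_demo():
    s = Syllable('한')
    assert (s.initial, s.vowel, s.final) == (18, 0, 4)
    return s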
@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
REVISED_INITIALS = 'g', 'kk', 'n', 'd', 'tt', 'l', 'm', 'b', 'pp', 's', 'ss', '', 'j', 'jj', 'ch', 'k', 't', 'p', 'h'
|
||||
REVISED_VOWELS = 'a', 'ae', 'ya', 'yae', 'eo', 'e', 'yeo', 'ye', 'o', 'wa', 'wae', 'oe', 'yo', 'u', 'wo', 'we', 'wi', 'yu', 'eu', 'ui', 'i'
|
||||
REVISED_FINALS = '', 'g', 'kk', 'gs', 'n', 'nj', 'nh', 'd', 'l', 'lg', 'lm', 'lb', 'ls', 'lt', 'lp', 'lh', 'm', 'b', 'bs', 's', 'ss', 'ng', 'j', 'ch', 'k', 't', 'p', 'h'
|
||||
|
||||
|
||||
def academic_ambiguous_patterns():
|
||||
import itertools
|
||||
result = set()
|
||||
for final, initial in itertools.product(REVISED_FINALS, REVISED_INITIALS):
|
||||
check = False
|
||||
combined = final + initial
|
||||
for i in range(len(combined)):
|
||||
head, tail = combined[:i], combined[i:]
|
||||
if head in REVISED_FINALS and tail in REVISED_INITIALS:
|
||||
if not check:
|
||||
check = True
|
||||
else:
|
||||
result.add(combined)
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
ACADEMIC_AMBIGUOUS_PATTERNS = academic_ambiguous_patterns()
|
||||
|
||||
|
||||
def academic(now, pre, **options):
|
||||
"""Rule for academic translition."""
|
||||
c, s = now
|
||||
if not s:
|
||||
return c
|
||||
|
||||
ps = pre[1] if pre else None
|
||||
|
||||
marker = False
|
||||
if ps:
|
||||
if s.initial == 11:
|
||||
marker = True
|
||||
elif ps and (REVISED_FINALS[ps.final] + REVISED_INITIALS[s.initial]) in ACADEMIC_AMBIGUOUS_PATTERNS:
|
||||
marker = True
|
||||
|
||||
r = u''
|
||||
if marker:
|
||||
r += '-'
|
||||
r += REVISED_INITIALS[s.initial] + REVISED_VOWELS[s.vowel] + REVISED_FINALS[s.final]
|
||||
return r
|
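# --- Illustrative usage sketch (not part of the original file) ---
# academic() is the rule callback expected by Transliter.translit in core.py.
# The relative import below is an assumption about the package layout; with the
# tables above, "한글" comes out as "han-geul" (the hyphen is the ambiguity marker).
def _academic_demo():
    from .core import Transliter  # assumed package layout

    transliter = Transliter(academic)
    return transliter.translit("한글")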
@ -0,0 +1 @@
|
||||
0.0.1
|
@ -0,0 +1,5 @@
|
||||
from .kanji_to_romaji_module import convert_hiragana_to_katakana, convert_katakana_to_hiragana, \
    translate_to_romaji, translate_soukon, translate_long_vowel, translate_soukon_ch, kanji_to_romaji

__all__ = ["convert_hiragana_to_katakana", "convert_katakana_to_hiragana",
           "translate_to_romaji", "translate_soukon",
           "translate_long_vowel", "translate_soukon_ch", "kanji_to_romaji"]
|
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -0,0 +1,120 @@
|
||||
{
|
||||
"ぁ": "a",
|
||||
"あ": "a",
|
||||
"ぃ": "i",
|
||||
"い": "i",
|
||||
"ぅ": "u",
|
||||
"う": "u",
|
||||
"ぇ": "e",
|
||||
"え": "e",
|
||||
"ぉ": "o",
|
||||
"お": "o",
|
||||
"か": "ka",
|
||||
"が": "ga",
|
||||
"き": "ki",
|
||||
"きゃ": "kya",
|
||||
"きゅ": "kyu",
|
||||
"きょ": "kyo",
|
||||
"ぎ": "gi",
|
||||
"ぎゃ": "gya",
|
||||
"ぎゅ": "gyu",
|
||||
"ぎょ": "gyo",
|
||||
"く": "ku",
|
||||
"ぐ": "gu",
|
||||
"け": "ke",
|
||||
"げ": "ge",
|
||||
"こ": "ko",
|
||||
"ご": "go",
|
||||
"さ": "sa",
|
||||
"ざ": "za",
|
||||
"し": "shi",
|
||||
"しゃ": "sha",
|
||||
"しゅ": "shu",
|
||||
"しょ": "sho",
|
||||
"じ": "ji",
|
||||
"じゃ": "ja",
|
||||
"じゅ": "ju",
|
||||
"じょ": "jo",
|
||||
"す": "su",
|
||||
"ず": "zu",
|
||||
"せ": "se",
|
||||
"ぜ": "ze",
|
||||
"そ": "so",
|
||||
"ぞ": "zo",
|
||||
"た": "ta",
|
||||
"だ": "da",
|
||||
"ち": "chi",
|
||||
"ちゃ": "cha",
|
||||
"ちゅ": "chu",
|
||||
"ちょ": "cho",
|
||||
"ぢ": "ji",
|
||||
"つ": "tsu",
|
||||
"づ": "zu",
|
||||
"て": "te",
|
||||
"で": "de",
|
||||
"と": "to",
|
||||
"ど": "do",
|
||||
"な": "na",
|
||||
"に": "ni",
|
||||
"にゃ": "nya",
|
||||
"にゅ": "nyu",
|
||||
"にょ": "nyo",
|
||||
"ぬ": "nu",
|
||||
"ね": "ne",
|
||||
"の": "no",
|
||||
"は": "ha",
|
||||
"ば": "ba",
|
||||
"ぱ": "pa",
|
||||
"ひ": "hi",
|
||||
"ひゃ": "hya",
|
||||
"ひゅ": "hyu",
|
||||
"ひょ": "hyo",
|
||||
"び": "bi",
|
||||
"びゃ": "bya",
|
||||
"びゅ": "byu",
|
||||
"びょ": "byo",
|
||||
"ぴ": "pi",
|
||||
"ぴゃ": "pya",
|
||||
"ぴゅ": "pyu",
|
||||
"ぴょ": "pyo",
|
||||
"ふ": "fu",
|
||||
"ぶ": "bu",
|
||||
"ぷ": "pu",
|
||||
"へ": "he",
|
||||
"べ": "be",
|
||||
"ぺ": "pe",
|
||||
"ほ": "ho",
|
||||
"ぼ": "bo",
|
||||
"ぽ": "po",
|
||||
"ま": "ma",
|
||||
"み": "mi",
|
||||
"みゃ": "mya",
|
||||
"みゅ": "myu",
|
||||
"みょ": "myo",
|
||||
"む": "mu",
|
||||
"め": "me",
|
||||
"も": "mo",
|
||||
"や": "ya",
|
||||
"ゆ": "yu",
|
||||
"よ": "yo",
|
||||
"ら": "ra",
|
||||
"り": "ri",
|
||||
"りゃ": "rya",
|
||||
"りゅ": "ryu",
|
||||
"りょ": "ryo",
|
||||
"る": "ru",
|
||||
"れ": "re",
|
||||
"ろ": "ro",
|
||||
"ゎ": "wa",
|
||||
"わ": "wa",
|
||||
"ゐ": "wi",
|
||||
"ゑ": "we",
|
||||
"を": " wo ",
|
||||
"ん": "n",
|
||||
"ゔ": "vu",
|
||||
"ゕ": "ka",
|
||||
"ゖ": "ke",
|
||||
"ゝ": "iteration_mark",
|
||||
"ゞ": "voiced_iteration_mark",
|
||||
"ゟ": "yori"
|
||||
}
|
File diff suppressed because it is too large
@ -0,0 +1,18 @@
|
||||
{
|
||||
"今日": {
|
||||
"w_type": "noun",
|
||||
"romaji": "kyou"
|
||||
},
|
||||
"明日": {
|
||||
"w_type": "noun",
|
||||
"romaji": "ashita"
|
||||
},
|
||||
"本": {
|
||||
"w_type": "noun",
|
||||
"romaji": "hon"
|
||||
},
|
||||
"中": {
|
||||
"w_type": "noun",
|
||||
"romaji": "naka"
|
||||
}
|
||||
}
|
@ -0,0 +1,78 @@
|
||||
{
|
||||
"朝日奈丸佳": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Asahina Madoka"
|
||||
},
|
||||
"高海千歌": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Takami Chika"
|
||||
},
|
||||
"鏡音レン": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Kagamine Len"
|
||||
},
|
||||
"鏡音リン": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Kagamine Rin"
|
||||
},
|
||||
"逢坂大河": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Aisaka Taiga"
|
||||
},
|
||||
"水樹奈々": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Mizuki Nana"
|
||||
},
|
||||
"桜内梨子": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Sakurauchi Riko"
|
||||
},
|
||||
"山吹沙綾": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Yamabuki Saaya"
|
||||
},
|
||||
"初音ミク": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Hatsune Miku"
|
||||
},
|
||||
"渡辺曜": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Watanabe You"
|
||||
},
|
||||
"原由実": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Hara Yumi"
|
||||
},
|
||||
"北宇治": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Kita Uji"
|
||||
},
|
||||
"六本木": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Roppongi"
|
||||
},
|
||||
"久美子": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Kumiko"
|
||||
},
|
||||
"政宗": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Masamune"
|
||||
},
|
||||
"小林": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Kobayashi"
|
||||
},
|
||||
"奥寺": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Okudera"
|
||||
},
|
||||
"佐藤": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Satou"
|
||||
},
|
||||
"玲子": {
|
||||
"w_type": "noun",
|
||||
"romaji": "Reiko"
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
@ -0,0 +1,644 @@
|
||||
# coding=utf-8
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
try:
|
||||
# noinspection PyPackageRequirements
|
||||
import simplejson as json
|
||||
except ImportError:
|
||||
import json
|
||||
|
||||
from .models import UnicodeRomajiMapping
|
||||
from .models import KanjiBlock
|
||||
from .models import Particle
|
||||
|
||||
PATH_TO_MODULE = os.path.dirname(__file__)
|
||||
JP_MAPPINGS_PATH = os.path.join(PATH_TO_MODULE, "jp_mappings")
|
||||
|
||||
hiragana_iter_mark = "ゝ"
|
||||
hiragana_voiced_iter_mark = "ゞ"
|
||||
katakana_iter_mark = "ヽ"
|
||||
katakana_voiced_iter_mark = "ヾ"
|
||||
kanji_iteration_mark = "々"
|
||||
|
||||
hirgana_soukon_unicode_char = "っ"
|
||||
katakana_soukon_unicode_char = "ッ"
|
||||
katakana_long_vowel_mark = "ー"
|
||||
|
||||
|
||||
def load_kana_mappings_dict():
|
||||
kana_romaji_mapping = {}
|
||||
for f in os.listdir(JP_MAPPINGS_PATH):
|
||||
if os.path.splitext(f)[1] == ".json" and "kanji" not in f:
|
||||
with open(os.path.join(JP_MAPPINGS_PATH, f), encoding='utf-8') as data_file:
|
||||
kana_romaji_mapping.update(json.load(data_file))
|
||||
return kana_romaji_mapping
|
||||
|
||||
|
||||
def load_kanji_mappings_dict():
|
||||
"""
|
||||
read through all json files that contain "kanji" in filename
|
||||
load json data from files to kanji_romaji_mapping dictionary
|
||||
if the key(kanji char) has already been added to kanji_romaji_mapping then create "other_readings" key
|
||||
"other_readings" will consist of w_type for its key and the new romaji reading for it
|
||||
e.g:
|
||||
{"係り":
|
||||
'w_type': 'noun',
|
||||
'romaji': 'kakari',
|
||||
{'other_readings': {'godan verb stem': 'kakawari'}
|
||||
}
|
||||
:return: dict - kanji to romaji mapping
|
||||
"""
|
||||
|
||||
kanji_romaji_mapping = {}
|
||||
f_list = os.listdir(JP_MAPPINGS_PATH)
|
||||
for f in f_list[:]: # shift all conjugated files to end, lower priority for verb stems
|
||||
if "conjugated" in f:
|
||||
f_list.remove(f)
|
||||
f_list.append(f)
|
||||
|
||||
for f in f_list:
|
||||
if os.path.splitext(f)[1] == ".json" and "kanji" in f:
|
||||
with open(os.path.join(JP_MAPPINGS_PATH, f), encoding='utf-8') as data_file:
|
||||
data_file_dict = json.load(data_file)
|
||||
for k in data_file_dict.keys():
|
||||
if k in kanji_romaji_mapping and \
|
||||
data_file_dict[k]["w_type"] != kanji_romaji_mapping[k]["w_type"]:
|
||||
# if "other_readings" in kanji_romaji_mapping[k] and \
|
||||
# data_file_dict[k]["w_type"] in kanji_romaji_mapping[k]["other_readings"]:
|
||||
# raise
|
||||
|
||||
if "other_readings" not in kanji_romaji_mapping[k]:
|
||||
kanji_romaji_mapping[k]["other_readings"] = {}
|
||||
|
||||
kanji_romaji_mapping[k]["other_readings"][data_file_dict[k]["w_type"]] = \
|
||||
data_file_dict[k]["romaji"]
|
||||
else:
|
||||
kanji_romaji_mapping[k] = data_file_dict[k]
|
||||
return kanji_romaji_mapping
|
||||
|
||||
|
||||
def _convert_hira_kata_char(hira_or_kata_char, h_to_k=True):
|
||||
"""
|
||||
shift the unicode code point by 0x60 to convert between hiragana and katakana
e.g hiragana u3041 -> 0x3041 + 0x60 = 0x30A1 -> katakana u30A1

:param hira_or_kata_char: unicode hiragana or katakana character
:return: converted hiragana or katakana depending on h_to_k value
"""
offset = 0x60 if h_to_k else -0x60
return chr(ord(hira_or_kata_char) + offset)
|
||||
|
||||
|
||||
def convert_hiragana_to_katakana(hiragana):
|
||||
converted_str = ""
|
||||
|
||||
for c in hiragana:
|
||||
if is_hiragana(c) or c in [hiragana_iter_mark, hiragana_voiced_iter_mark, hirgana_soukon_unicode_char]:
|
||||
converted_str += _convert_hira_kata_char(c)
|
||||
else:
|
||||
converted_str += c
return converted_str
|
||||
|
||||
|
||||
def convert_katakana_to_hiragana(katakana):
|
||||
converted_str = ""
|
||||
|
||||
for c in katakana:
|
||||
if is_katakana(c) or c in [katakana_iter_mark, katakana_voiced_iter_mark,
|
||||
katakana_soukon_unicode_char]:
|
||||
converted_str += _convert_hira_kata_char(c, h_to_k=False)
|
||||
else:
|
||||
converted_str += c
return converted_str
|
||||
|
||||
|
||||
def is_hiragana(c):
|
||||
hiragana_starting_unicode = "\u3041"
|
||||
hiragana_ending_unicode = "\u3096"
|
||||
return c not in [hiragana_iter_mark, hiragana_voiced_iter_mark, hirgana_soukon_unicode_char] and \
|
||||
hiragana_starting_unicode <= c <= hiragana_ending_unicode
|
||||
|
||||
|
||||
def is_katakana(c):
|
||||
katakana_starting_unicode = "\u30A1"
|
||||
katakana_ending_unicode = "\u30F6"
|
||||
return c not in [katakana_iter_mark, katakana_voiced_iter_mark,
|
||||
katakana_soukon_unicode_char, katakana_long_vowel_mark] and \
|
||||
katakana_starting_unicode <= c <= katakana_ending_unicode
|
||||
|
||||
|
||||
def is_kanji(c):
|
||||
cjk_start_range = "\u4E00"
|
||||
cjk_end_range = "\u9FD5"
|
||||
if isinstance(c, KanjiBlock):
|
||||
return True
|
||||
else:
|
||||
return c != kanji_iteration_mark and cjk_start_range <= c <= cjk_end_range
|
||||
|
||||
|
||||
def get_char_type(c):
|
||||
"""
|
||||
determine type of passed character by checking if it belongs in a certain unicode range
|
||||
:param c: kana or kanji character
|
||||
:return: type of character
|
||||
"""
|
||||
char_type = None
|
||||
if is_hiragana(c):
|
||||
char_type = "hiragana"
|
||||
elif is_katakana(c):
|
||||
char_type = "katakana"
|
||||
elif is_kanji(c):
|
||||
char_type = "kanji"
|
||||
|
||||
return char_type
|
||||
|
||||
|
||||
def translate_particles(kana_list):
|
||||
"""
|
||||
try to find particles which are in hiragana and turn them into Particle objects
|
||||
Particle will provide spacing and will be translated in to appropriate romaji (e.g wa instead of ha for は)
|
||||
|
||||
rules (varies depending on the hiragana char):
|
||||
char between two KanjiBlocks(that can be nouns) then assume to be a particle
|
||||
e.g: 私は嬉 -> KanjiBlock(私), は, KanjiBlock(嬉) -> は is particle use wa instead of ha
|
||||
type(Kanji, Hiragana, Katakana) changes adjacent to the char
|
||||
e.g: アパートへくる -> ト, へ, く -> katakana, へ, hiragana -> へ is a particle, use e instead of he
|
||||
char is last char and previous char is a noun
|
||||
e.g: 会いました友達に -> KanjiBlock(友達) which is a noun, に
|
||||
|
||||
:param kana_list: list of kana characters and KanjiBlock objects
|
||||
:return: None; update the kana_list that is passed
|
||||
"""
|
||||
def is_noun(k_block):
|
||||
return hasattr(k_block, "w_type") and ("noun" in k_block.w_type or "pronoun" in k_block.w_type)
|
||||
|
||||
def type_changes(p, n):
|
||||
if get_char_type(p) is not None and get_char_type(n) is not None:
|
||||
return get_char_type(p) != get_char_type(n)
|
||||
else:
|
||||
return False
|
||||
|
||||
def particle_imm_follows(prev_c_, valid_prev_particles):
|
||||
"""
|
||||
check if prev_c is a Particle object
|
||||
check that prev_c is one of the valid_prev_particles
|
||||
e.g: wa particle can't be followed by wa particle again but ni particle can be followed by wa.
|
||||
:param prev_c_: previous character compared to current character in the iteration
|
||||
:param valid_prev_particles: list of previous particles that can be followed by current character.
|
||||
:return:
|
||||
"""
|
||||
return isinstance(prev_c_, Particle) and prev_c_ in valid_prev_particles
|
||||
|
||||
no_hira_char = "\u306E"
|
||||
ha_hira_char = "\u306F"
|
||||
he_hira_char = "\u3078"
|
||||
to_hira_char = "\u3068"
|
||||
ni_hira_char = "\u306B"
|
||||
de_hira_char = "\u3067"
|
||||
mo_hira_char = "\u3082"
|
||||
ga_hira_char = "\u304C"
|
||||
|
||||
no_prtcle = Particle("no")
|
||||
wa_prtcle = Particle("wa")
|
||||
e_prtcle = Particle("e")
|
||||
to_prtcle = Particle("to")
|
||||
ni_prtcle = Particle("ni")
|
||||
de_prtcle = Particle("de")
|
||||
mo_prtcle = Particle("mo")
|
||||
ga_prtcle = Particle("ga")
|
||||
|
||||
for i in range(1, len(kana_list)):
|
||||
is_last_char = False
|
||||
prev_c = kana_list[i - 1]
|
||||
if i == len(kana_list) - 1:
|
||||
is_last_char = True
|
||||
next_c = ""
|
||||
else:
|
||||
next_c = kana_list[i + 1]
|
||||
|
||||
if kana_list[i] == no_hira_char:
|
||||
if (is_noun(prev_c) and is_noun(next_c)) or \
|
||||
type_changes(prev_c, next_c) or \
|
||||
(is_noun(prev_c) and is_last_char):
|
||||
kana_list[i] = no_prtcle
|
||||
|
||||
elif kana_list[i] == ha_hira_char:
|
||||
if (is_noun(prev_c) and isinstance(next_c, KanjiBlock)) or \
|
||||
type_changes(prev_c, next_c) or \
|
||||
particle_imm_follows(prev_c, [e_prtcle, to_prtcle, ni_prtcle, de_prtcle]) or \
|
||||
(is_noun(prev_c) and is_last_char):
|
||||
kana_list[i] = wa_prtcle
|
||||
|
||||
elif kana_list[i] == mo_hira_char:
|
||||
if (is_noun(prev_c) and isinstance(next_c, KanjiBlock)) or \
|
||||
type_changes(prev_c, next_c) or \
|
||||
particle_imm_follows(prev_c, [ni_prtcle, de_prtcle]) or \
|
||||
(is_noun(prev_c) and is_last_char):
|
||||
kana_list[i] = mo_prtcle
|
||||
|
||||
elif kana_list[i] in [he_hira_char, to_hira_char, ni_hira_char, de_hira_char, ga_hira_char] and \
        ((is_noun(prev_c) and isinstance(next_c, KanjiBlock)) or
         type_changes(prev_c, next_c) or
         (is_noun(prev_c) and is_last_char)):
|
||||
|
||||
if kana_list[i] == he_hira_char:
|
||||
kana_list[i] = e_prtcle
|
||||
|
||||
elif kana_list[i] == to_hira_char:
|
||||
kana_list[i] = to_prtcle
|
||||
|
||||
elif kana_list[i] == ni_hira_char:
|
||||
kana_list[i] = ni_prtcle
|
||||
|
||||
elif kana_list[i] == de_hira_char:
|
||||
kana_list[i] = de_prtcle
|
||||
|
||||
elif kana_list[i] == ga_hira_char:
|
||||
kana_list[i] = ga_prtcle
|
||||
|
||||
|
||||
def translate_kanji_iteration_mark(kana_list):
|
||||
"""
|
||||
translate kanji_iteration_mark: 々
|
||||
e.g:
|
||||
在々: zaizai
|
||||
:param kana_list: unicode consisting of kana and kanji chars
|
||||
:return: unicode with kanji iteration marks translated
|
||||
"""
|
||||
prev_c = ""
|
||||
for i in range(0, len(kana_list)):
|
||||
if kana_list[i] == kanji_iteration_mark:
|
||||
kana_list[i] = prev_c.romaji.strip()
|
||||
prev_c = kana_list[i]
|
||||
|
||||
|
||||
def get_type_if_verb_stem(curr_chars):
|
||||
"""
|
||||
get verb type for given verb stem. verb types can be ichidan, godan or None.
|
||||
No stem for irregulars
|
||||
:param curr_chars: kanji chars that is a verb stem
|
||||
:return: type of verb stem
|
||||
"""
|
||||
v_type = None
|
||||
|
||||
if "verb stem" in UnicodeRomajiMapping.kanji_mapping[curr_chars]["w_type"]:
|
||||
v_type = UnicodeRomajiMapping.kanji_mapping[curr_chars]["w_type"]
|
||||
|
||||
elif "other_readings" in UnicodeRomajiMapping.kanji_mapping[curr_chars]:
|
||||
if "godan verb stem" in UnicodeRomajiMapping.kanji_mapping[curr_chars]["other_readings"]:
|
||||
v_type = "godan verb"
|
||||
elif "ichidan verb stem" in UnicodeRomajiMapping.kanji_mapping[curr_chars]["other_readings"]:
|
||||
v_type = "ichidan verb"
|
||||
|
||||
return v_type
|
||||
|
||||
|
||||
def check_for_verb_stem_ending(kana_list, curr_chars, start_pos, char_len):
|
||||
"""
|
||||
if the given curr_chars has a verb stem reading then try to match it with an one of the listed verb endings
|
||||
otherwise return/use its .romaji property
|
||||
|
||||
e.g:
|
||||
kana_list = [KanjiBlock(灯り), ま, し, た]
|
||||
curr_chars = 灯り can be verb stem reading
|
||||
try and match 灯り with an ending within kana_list
|
||||
灯り + ました matches
|
||||
romaji is tomori + mashita (this modifies kana_list to remove matched ending)
|
||||
kana_list = [tomorimashita]
|
||||
|
||||
kana_list = [KanjiBlock(灯り), を, 見ます]
|
||||
curr_chars = 灯り can be verb stem reading
|
||||
try and match 灯り with an ending within kana_list
|
||||
no matching ending
|
||||
romaji is akari
|
||||
kana_list = [akari, を, 見ます]
|
||||
|
||||
:param kana_list:
|
||||
:param curr_chars: KanjiBlock current characters to parse out of entire kana_list
|
||||
:param start_pos:
|
||||
:param char_len:
|
||||
:return: ending kanji, ending romaji; both will be None if ending not found
|
||||
"""
|
||||
endings = OrderedDict({})
|
||||
endings["ませんでした"] = "masen deshita"
|
||||
endings["ませんで"] = "masende"
|
||||
endings["なさるな"] = "nasaruna"
|
||||
endings["なかった"] = "nakatta"
|
||||
endings["れて"] = "rete"
|
||||
endings["ましょう"] = "masho"
|
||||
endings["ました"] = "mashita"
|
||||
endings["まして"] = "mashite"
|
||||
endings["ません"] = "masen"
|
||||
endings["ないで"] = "naide"
|
||||
endings["なさい"] = "nasai"
|
||||
endings["ます"] = "mas"
|
||||
endings["よう"] = "yo" # ichidan
|
||||
endings["ない"] = "nai"
|
||||
endings["た"] = "ta" # ichidan
|
||||
endings["て"] = "te" # ichidan
|
||||
endings["ろ"] = "ro" # ichidan
|
||||
endings["う"] = ""
|
||||
|
||||
dict_entry = None
|
||||
|
||||
if "verb stem" in UnicodeRomajiMapping.kanji_mapping[curr_chars]["w_type"]:
|
||||
dict_entry = UnicodeRomajiMapping.kanji_mapping[curr_chars]
|
||||
|
||||
elif "other_readings" in UnicodeRomajiMapping.kanji_mapping[curr_chars]:
|
||||
|
||||
if "godan verb stem" in UnicodeRomajiMapping.kanji_mapping[curr_chars]["other_readings"]:
|
||||
dict_entry = {
|
||||
"romaji": UnicodeRomajiMapping.kanji_mapping[curr_chars]["other_readings"]["godan verb stem"]
|
||||
}
|
||||
elif "ichidan verb stem" in UnicodeRomajiMapping.kanji_mapping[curr_chars]["other_readings"]:
|
||||
dict_entry = {
|
||||
"romaji": UnicodeRomajiMapping.kanji_mapping[curr_chars]["other_readings"]["ichidan verb stem"]
|
||||
}
|
||||
e_k = None
|
||||
e_r = None
|
||||
if dict_entry is not None:
|
||||
for e in endings.keys():
|
||||
possible_conj = curr_chars + e
|
||||
actual_conj = "".join(kana_list[start_pos: (start_pos + char_len + len(e))])
|
||||
if possible_conj == actual_conj:
|
||||
e_k = e
|
||||
e_r = endings[e] + " "
|
||||
break
|
||||
|
||||
return e_k, e_r
|
||||
|
||||
|
||||
def has_non_verb_stem_reading(curr_chars):
|
||||
"""
|
||||
check if curr_chars has an alternative reading aside from the verb stem
|
||||
:param curr_chars: unicode kanji chars to check
|
||||
:return: true/false depending on if curr_chars has a verb stem reading
|
||||
"""
|
||||
res = False
|
||||
|
||||
if "verb stem" not in UnicodeRomajiMapping.kanji_mapping[curr_chars]["w_type"]:
|
||||
res = True
|
||||
|
||||
elif "other_readings" in UnicodeRomajiMapping.kanji_mapping[curr_chars]:
|
||||
if any(["verb stem" not in ork
|
||||
for ork in UnicodeRomajiMapping.kanji_mapping[curr_chars]["other_readings"].keys()]):
|
||||
res = True
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def get_verb_stem_romaji(verb_stem_kanji):
|
||||
"""
|
||||
find romaji for verb stem within kanji_mapping
|
||||
:param verb_stem_kanji: unicode verb stem kanji
|
||||
:return: romaji for verb stem kanji
|
||||
"""
|
||||
romaji = None
|
||||
if "verb stem" in UnicodeRomajiMapping.kanji_mapping[verb_stem_kanji]["w_type"]:
|
||||
romaji = UnicodeRomajiMapping.kanji_mapping[verb_stem_kanji]["romaji"]
|
||||
elif "other_readings" in UnicodeRomajiMapping.kanji_mapping[verb_stem_kanji]:
|
||||
for k in UnicodeRomajiMapping.kanji_mapping[verb_stem_kanji]["other_readings"].keys():
|
||||
if "verb stem" in k:
|
||||
romaji = UnicodeRomajiMapping.kanji_mapping[verb_stem_kanji]["other_readings"][k]
|
||||
break
|
||||
|
||||
return romaji
|
||||
|
||||
|
||||
def prepare_kanjiblocks(kchar_list):
|
||||
"""
|
||||
create and replace matched Kanji characters that are within kanji_mapping with KanjiBlock
|
||||
KanjiBlock will be used for spacing and particle translation later
|
||||
if the kanji found is a verb stem then try to find an ending to match it with what's in kchar_list
|
||||
:param kchar_list: list containing kana and kanji characters
|
||||
:return: kchar_list with all found Kanji characters turned in to KanjiBlock objects
|
||||
"""
|
||||
if len(UnicodeRomajiMapping.kanji_mapping) == 0:
|
||||
UnicodeRomajiMapping.kanji_mapping = load_kanji_mappings_dict()
|
||||
|
||||
max_char_len = len(kchar_list)
|
||||
kana_list = list(kchar_list)
|
||||
|
||||
start_pos = 0
|
||||
while start_pos < max_char_len:
|
||||
char_len = len(kana_list) - start_pos
|
||||
while char_len > 0:
|
||||
curr_chars = "".join(kana_list[start_pos: (start_pos + char_len)])
|
||||
if curr_chars in UnicodeRomajiMapping.kanji_mapping:
|
||||
verb_stem_type = get_type_if_verb_stem(curr_chars)
|
||||
ending_match_found = False
|
||||
if verb_stem_type is not None:
|
||||
ending_kana, ending_romaji = check_for_verb_stem_ending(kana_list, curr_chars, start_pos, char_len)
|
||||
if ending_kana is not None and ending_romaji is not None:
|
||||
ending_match_found = True
|
||||
conjugated_val = {
|
||||
"romaji": get_verb_stem_romaji(curr_chars) + ending_romaji,
|
||||
"w_type": "conjugated " + verb_stem_type
|
||||
}
|
||||
|
||||
for i in range(start_pos + char_len - 1 + len(ending_kana), start_pos - 1, -1):
|
||||
del kana_list[i]
|
||||
|
||||
kana_list.insert(start_pos,
|
||||
KanjiBlock(curr_chars + ending_kana, conjugated_val))
|
||||
|
||||
if ending_match_found is False and has_non_verb_stem_reading(curr_chars):
|
||||
for i in range(start_pos + char_len - 1, start_pos - 1, -1):
|
||||
del kana_list[i]
|
||||
kana_list.insert(start_pos,
|
||||
KanjiBlock(curr_chars, UnicodeRomajiMapping.kanji_mapping[curr_chars]))
|
||||
char_len -= 1
|
||||
start_pos += 1
|
||||
return kana_list
|
||||
|
||||
|
||||
def translate_kanji(kana_list):
|
||||
i = 0
|
||||
while i < len(kana_list):
|
||||
if type(kana_list[i]) == KanjiBlock:
|
||||
kana_list[i] = kana_list[i].romaji
|
||||
i += 1
|
||||
|
||||
kana = "".join(kana_list)
|
||||
return kana
|
||||
|
||||
|
||||
def prep_kanji(kana):
|
||||
kana_list = list(kana)
|
||||
if any([is_kanji(k) for k in kana]):
|
||||
kana_list = prepare_kanjiblocks(kana)
|
||||
translate_kanji_iteration_mark(kana_list)
|
||||
|
||||
return kana_list
|
||||
|
||||
|
||||
def translate_to_romaji(kana):
|
||||
"""
|
||||
translate hiragana, katakana, typographic, and full/half-width latin characters
|
||||
:param kana: unicode kana(+kanji) characters
|
||||
:return: kana translated to romaji, with typographic and full/half-width latin characters normalized
|
||||
"""
|
||||
if len(UnicodeRomajiMapping.kana_mapping) == 0:
|
||||
UnicodeRomajiMapping.kana_mapping = load_kana_mappings_dict()
|
||||
|
||||
max_char_len = 2
|
||||
|
||||
for char_len in range(max_char_len, 0, -1):
|
||||
start_pos = 0
|
||||
while start_pos < len(kana) - char_len + 1:
|
||||
curr_chars = kana[start_pos: (start_pos + char_len)]
|
||||
if curr_chars in UnicodeRomajiMapping.kana_mapping:
|
||||
kana = kana.replace(curr_chars, UnicodeRomajiMapping.kana_mapping[curr_chars], 1)
|
||||
if len(UnicodeRomajiMapping.kana_mapping[curr_chars]) == 0:
|
||||
start_pos -= 1
|
||||
start_pos += 1
|
||||
|
||||
while " " in kana:
|
||||
kana = kana.replace(" ", " ")
|
||||
kana = kana.strip()
|
||||
|
||||
lines = kana.split("\n")
|
||||
for i in range(0, len(lines)):
|
||||
lines[i] = lines[i].strip()
|
||||
kana = "\n".join(lines)
|
||||
return kana
|
||||
|
||||
|
||||
def translate_soukon(partial_kana):
|
||||
"""
|
||||
translate both hiragana and katakana soukon: っ, ッ; repeats next consonant
|
||||
e.g:
|
||||
ちょっと will be choっto by the time it is passed to this method and then becomes chotto
|
||||
:param partial_kana: partially translated kana with base kana chars already translated to romaji
|
||||
:return: partial kana with soukon translated
|
||||
"""
|
||||
prev_char = ""
|
||||
|
||||
for c in reversed(partial_kana):
|
||||
if c == hirgana_soukon_unicode_char or c == katakana_soukon_unicode_char: # assuming that soukon can't be last
|
||||
partial_kana = prev_char[0].join(partial_kana.rsplit(c, 1))
|
||||
prev_char = c
|
||||
return partial_kana
|
||||
|
||||
|
||||
def translate_long_vowel(partial_kana):
|
||||
"""
|
||||
translate katakana long vowel ー; repeats previous vowel
|
||||
e.g:
|
||||
メール will be meーru by the time it is passed to this method and then becomes meeru
|
||||
:param partial_kana: partially translated kana with base kana chars already translated to romaji
|
||||
:return: partial kana with long vowel translated
|
||||
"""
|
||||
prev_c = ""
|
||||
for c in partial_kana:
|
||||
if c == katakana_long_vowel_mark:
|
||||
if prev_c[-1] in list("aeio"):
|
||||
partial_kana = partial_kana.replace(c, prev_c[-1], 1)
|
||||
else:
|
||||
partial_kana = partial_kana.replace(c, "", 1)
|
||||
prev_c = c
|
||||
return partial_kana
|
||||
|
||||
|
||||
def translate_soukon_ch(kana):
|
||||
"""
|
||||
if soukon(mini-tsu) is followed by chi then soukon romaji becomes 't' sound
|
||||
e.g: ko-soukon-chi -> kotchi instead of kocchi
|
||||
:param kana:
|
||||
:return:
|
||||
"""
|
||||
|
||||
prev_char = ""
|
||||
hiragana_chi_unicode_char = "\u3061"
|
||||
katakana_chi_unicode_char = "\u30C1"
|
||||
partial_kana = kana
|
||||
for c in reversed(kana):
|
||||
if c == hirgana_soukon_unicode_char or c == katakana_soukon_unicode_char: # assuming that soukon can't be last
|
||||
if prev_char == hiragana_chi_unicode_char or prev_char == katakana_chi_unicode_char:
|
||||
partial_kana = "t".join(partial_kana.rsplit(c, 1))
|
||||
prev_char = c
|
||||
return partial_kana
|
||||
|
||||
|
||||
def _translate_dakuten_equivalent_char(kana_char):
|
||||
dakuten_mapping = {
|
||||
"か": "が", "き": "ぎ", "く": "ぐ", "け": "げ", "こ": "ご",
|
||||
"さ": "ざ", "し": "じ", "す": "ず", "せ": "ぜ", "そ": "ぞ",
|
||||
"た": "だ", "ち": "ぢ", "つ": "づ", "て": "で", "と": "ど",
|
||||
"は": "ば", "ひ": "び", "ふ": "ぶ", "へ": "べ", "ほ": "ぼ",
|
||||
"タ": "ダ", "チ": "ヂ", "ツ": "ヅ", "テ": "デ", "ト": "ド",
|
||||
"カ": "ガ", "キ": "ギ", "ク": "グ", "ケ": "ゲ", "コ": "ゴ",
|
||||
"サ": "ザ", "シ": "ジ", "ス": "ズ", "セ": "ゼ", "ソ": "ゾ",
|
||||
"ハ": "バ", "ヒ": "ビ", "フ": "ブ", "ヘ": "ベ", "ホ": "ボ"
|
||||
}
|
||||
|
||||
dakuten_equiv = ""
|
||||
if kana_char in dakuten_mapping:
|
||||
dakuten_equiv = dakuten_mapping[kana_char]
|
||||
|
||||
return dakuten_equiv
|
||||
|
||||
|
||||
def translate_dakuten_equivalent(kana_char):
|
||||
"""
|
||||
translate hiragana and katakana character to their dakuten equivalent
|
||||
e.g:
|
||||
ヒ: ビ
|
||||
く: ぐ
|
||||
み: ""
|
||||
:param kana_char: unicode kana char
|
||||
:return: dakuten equivalent if it exists otherwise empty string
|
||||
"""
|
||||
return _translate_dakuten_equivalent_char(kana_char)
|
||||
|
||||
|
||||
def translate_kana_iteration_mark(kana):
|
||||
"""
|
||||
translate hiragana and katakana iteration marks: ゝ, ゞ, ヽ, ヾ
|
||||
e.g:
|
||||
こゝ: koko
|
||||
タヾ: tada
|
||||
かゞみち: kagaみち
|
||||
:param kana: unicode consisting of kana chars
|
||||
:return: unicode with kana iteration marks translated
|
||||
"""
|
||||
prev_char = ""
|
||||
partial_kana = kana
|
||||
for c in kana:
|
||||
if c == hiragana_iter_mark or c == katakana_iter_mark:
|
||||
partial_kana = prev_char.join(partial_kana.split(c, 1))
|
||||
elif c == hiragana_voiced_iter_mark or c == katakana_voiced_iter_mark:
|
||||
partial_kana = translate_dakuten_equivalent(prev_char).join(partial_kana.split(c, 1))
|
||||
else:
|
||||
prev_char = c
|
||||
return partial_kana
|
||||
|
||||
|
||||
def kanji_to_romaji(kana):
|
||||
pk = translate_kana_iteration_mark(kana)
|
||||
pk = translate_soukon_ch(pk)
|
||||
pk_list = prep_kanji(pk)
|
||||
translate_particles(pk_list)
|
||||
pk = translate_kanji(pk_list)
|
||||
pk = translate_to_romaji(pk)
|
||||
pk = translate_soukon(pk)
|
||||
r = translate_long_vowel(pk)
|
||||
return r.replace("\\\\", "\\")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
print(kanji_to_romaji("".join(sys.argv[1:]).encode("ascii", "backslashreplace").decode("unicode-escape")))
|
||||
else:
|
||||
print("Missing Kanji/Kana character argument\n" \
|
||||
"e.g: kanji_to_romaji.py \u30D2")
|
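# --- Illustrative usage sketch (not part of the original module) ---
# Expected outputs follow the bundled mapping files and the docstrings above;
# exact spacing may differ for longer mixed kanji/kana input.
def _kanji_to_romaji_demo():
    print(kanji_to_romaji("ちょっと"))  # chotto  (soukon doubles the next consonant)
    print(kanji_to_romaji("メール"))    # meeru   (long vowel mark repeats the vowel)
    print(kanji_to_romaji("今日"))      # kyou    (noun reading from the kanji mapping)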
@ -0,0 +1,29 @@
|
||||
class KanjiBlock(str):
|
||||
def __new__(cls, *args, **kwargs):
|
||||
obj = str.__new__(cls, "@")
|
||||
kanji = args[0]
|
||||
kanji_dict = args[1]
|
||||
|
||||
obj.kanji = kanji
|
||||
if len(kanji) == 1:
|
||||
obj.romaji = " " + kanji_dict["romaji"]
|
||||
else:
|
||||
if "verb stem" in kanji_dict["w_type"]:
|
||||
obj.romaji = " " + kanji_dict["romaji"]
|
||||
else:
|
||||
obj.romaji = " " + kanji_dict["romaji"] + " "
|
||||
|
||||
if "other_readings" in kanji_dict:
|
||||
obj.w_type = [kanji_dict["w_type"]]
|
||||
obj.w_type.extend(
|
||||
[k for k in kanji_dict["other_readings"].keys()]
|
||||
)
|
||||
else:
|
||||
obj.w_type = kanji_dict["w_type"]
|
||||
return obj
|
||||
|
||||
def __repr__(self):
|
||||
return self.kanji.encode("unicode_escape").decode("ascii")
|
||||
|
||||
def __str__(self):
|
||||
return self.romaji
|
@ -0,0 +1,6 @@
|
||||
class Particle(str):
|
||||
def __new__(cls, *args, **kwargs):
|
||||
particle_str = args[0]
|
||||
obj = str.__new__(cls, " " + particle_str + " ")
|
||||
obj.pname = particle_str
|
||||
return obj
|
@ -0,0 +1,4 @@
|
||||
# noinspection PyClassHasNoInit
|
||||
class UnicodeRomajiMapping: # caching
|
||||
kana_mapping = {}
|
||||
kanji_mapping = {}
|
@ -0,0 +1,5 @@
|
||||
from .UnicodeRomajiMapping import UnicodeRomajiMapping
|
||||
from .KanjiBlock import KanjiBlock
|
||||
from .Particle import Particle
|
||||
|
||||
__all__ = ["UnicodeRomajiMapping", "KanjiBlock", "Particle"]
|
Binary file not shown.
@ -0,0 +1,40 @@
|
||||
import asyncio
|
||||
import asyncpg
|
||||
import config
|
||||
import os
|
||||
|
||||
conn = None
|
||||
|
||||
class Install:
|
||||
dbi: asyncpg.Connection
|
||||
|
||||
async def run(self):
|
||||
self.dbi = await asyncpg.connect(**config.DATABASE)
|
||||
args = os.sys.argv
|
||||
if "--force" in args:
|
||||
await self.drop_table()
|
||||
|
||||
await self.create_table()
|
||||
|
||||
async def drop_table(self):
|
||||
await self.dbi.execute("DROP TABLE IF EXISTS embedding_search_title_index;")
|
||||
print("Table dropped")
|
||||
|
||||
async def create_table(self):
|
||||
await self.dbi.execute("""
|
||||
CREATE TABLE embedding_search_title_index (
|
||||
id SERIAL PRIMARY KEY,
|
||||
sha1 VARCHAR(40) NOT NULL UNIQUE,
|
||||
title TEXT NOT NULL,
page_id INT8 NOT NULL,
collection_id INT8 NOT NULL,
rev_id INT8 NOT NULL,
embedding VECTOR(%d) NOT NULL
|
||||
);
|
||||
""" % (config.EMBEDDING_VECTOR_SIZE))
|
||||
await self.dbi.execute("CREATE INDEX embedding_search_title_index_embedding_idx ON embedding_search_title_index USING ivfflat (embedding vector_cosine_ops);")
|
||||
print("Table created")
|
||||
|
||||
if __name__ == "__main__":
|
||||
install = Install()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(install.run())
|
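# --- How this installer is presumably invoked (illustrative) ---
# A plain run creates the pgvector-backed table and its ivfflat cosine index;
# --force drops any existing table first. The file name install.py is an
# assumption about where this script lives.
#
#   python install.py            # create embedding_search_title_index
#   python install.py --force    # drop and recreate it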
@ -0,0 +1,16 @@
|
||||
aiohttp==3.8.4
|
||||
jieba==0.42.1
|
||||
pypinyin==0.37.0
|
||||
simplejson==3.17.0
|
||||
beautifulsoup4==4.11.2
|
||||
markdownify==0.11.6
|
||||
asyncpg==0.27.0
|
||||
aiofiles==23.1.0
|
||||
pgvector==0.1.6
|
||||
websockets==11.0
|
||||
PyJWT==2.6.0
|
||||
asyncpg-stubs==0.27.0
|
||||
sqlalchemy==2.0.9
|
||||
aiohttp-sse-client2==0.3.0
|
||||
OpenCC==1.1.6
|
||||
event-emitter-asyncio==1.0.4
|
@ -0,0 +1,51 @@
|
||||
import asyncio
|
||||
from typing import TypedDict
|
||||
from aiohttp import web
|
||||
import asyncpg
|
||||
import config
|
||||
import api.route
|
||||
import utils.web
|
||||
from service.database import DatabaseService
|
||||
from service.mediawiki_api import MediaWikiApi
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
|
||||
from service.tiktoken import TikTokenService
|
||||
|
||||
async def index(request: web.Request):
|
||||
return utils.web.api_response(1, data={"message": "Isekai toolkit API"}, request=request)
|
||||
|
||||
async def init_mw_api(app: web.Application):
|
||||
mw_api = MediaWikiApi.create()
|
||||
if config.MW_BOT_LOGIN_USERNAME and config.MW_BOT_LOGIN_PASSWORD:
|
||||
await mw_api.robot_login(config.MW_BOT_LOGIN_USERNAME, config.MW_BOT_LOGIN_PASSWORD)
|
||||
|
||||
site_meta = await mw_api.get_site_meta()
|
||||
|
||||
print("Connected to Wiki %s, Robot username: %s" % (site_meta["sitename"], site_meta["user"]))
|
||||
|
||||
async def init_database(app: web.Application):
|
||||
dbs = await DatabaseService.create(app)
|
||||
print("Database connected.")
|
||||
|
||||
async def init_tiktoken(app: web.Application):
|
||||
await TikTokenService.create()
|
||||
print("Tiktoken model loaded.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
app = web.Application()
|
||||
|
||||
if config.DATABASE:
|
||||
app.on_startup.append(init_database)
|
||||
|
||||
if config.MW_API:
|
||||
app.on_startup.append(init_mw_api)
|
||||
|
||||
if config.OPENAI_TOKEN:
|
||||
app.on_startup.append(init_tiktoken)
|
||||
|
||||
app.router.add_route('*', '/', index)
|
||||
api.route.init(app)
|
||||
web.run_app(app, host='0.0.0.0', port=config.PORT, loop=loop)
|
||||
|
@ -0,0 +1,238 @@
|
||||
from __future__ import annotations
|
||||
import traceback
|
||||
from typing import Optional, Tuple, TypedDict
|
||||
from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationChunkModel
|
||||
import sys
|
||||
from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel
|
||||
|
||||
import config
|
||||
import utils.config
|
||||
|
||||
from aiohttp import web
|
||||
from api.model.embedding_search.title_collection import TitleCollectionModel
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from service.database import DatabaseService
|
||||
from service.embedding_search import EmbeddingSearchArgs, EmbeddingSearchService
|
||||
from service.mediawiki_api import MediaWikiApi
|
||||
from service.openai_api import OpenAIApi
|
||||
from service.tiktoken import TikTokenService
|
||||
|
||||
|
||||
class ChatCompleteServiceResponse(TypedDict):
|
||||
message: str
|
||||
message_tokens: int
|
||||
total_tokens: int
|
||||
finish_reason: str
|
||||
conversation_id: int
|
||||
delta_data: dict
|
||||
|
||||
|
||||
class ChatCompleteService:
|
||||
def __init__(self, dbs: DatabaseService, title: str):
|
||||
self.dbs = dbs
|
||||
|
||||
self.title = title
|
||||
self.base_title = title.split("/")[0]
|
||||
|
||||
self.embedding_search = EmbeddingSearchService(dbs, title)
|
||||
self.conversation_helper = ConversationHelper(dbs)
|
||||
self.conversation_chunk_helper = ConversationChunkHelper(dbs)
|
||||
|
||||
self.conversation_info: Optional[ConversationModel] = None
|
||||
self.conversation_chunk: Optional[ConversationChunkModel] = None
|
||||
|
||||
self.tiktoken: TikTokenService = None
|
||||
|
||||
self.mwapi = MediaWikiApi.create()
|
||||
self.openai_api = OpenAIApi.create()
|
||||
|
||||
async def __aenter__(self):
|
||||
self.tiktoken = await TikTokenService.create()
|
||||
|
||||
await self.embedding_search.__aenter__()
|
||||
await self.conversation_helper.__aenter__()
|
||||
await self.conversation_chunk_helper.__aenter__()
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.embedding_search.__aexit__(exc_type, exc, tb)
|
||||
await self.conversation_helper.__aexit__(exc_type, exc, tb)
|
||||
await self.conversation_chunk_helper.__aexit__(exc_type, exc, tb)
|
||||
|
||||
async def page_index_exists(self):
|
||||
return await self.embedding_search.page_index_exists(False)
|
||||
|
||||
async def get_question_tokens(self, question: str):
|
||||
return await self.tiktoken.get_tokens(question)
|
||||
|
||||
async def chat_complete(self, question: str, on_message: Optional[callable] = None, on_extracted_doc: Optional[callable] = None,
|
||||
conversation_id: Optional[str] = None, user_id: Optional[int] = None, question_tokens: Optional[int] = None,
|
||||
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServiceResponse:
|
||||
if user_id is not None:
|
||||
user_id = int(user_id)
|
||||
|
||||
self.conversation_info = None
|
||||
if conversation_id is not None:
|
||||
conversation_id = int(conversation_id)
|
||||
self.conversation_info = await self.conversation_helper.find_by_id(conversation_id)
|
||||
|
||||
delta_data = {}
|
||||
|
||||
self.conversation_chunk = None
|
||||
message_log = []
|
||||
if self.conversation_info is not None:
|
||||
if self.conversation_info.user_id != user_id:
|
||||
raise web.HTTPUnauthorized()
|
||||
|
||||
self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(conversation_id)
|
||||
|
||||
# If the conversation is too long, we need to make a summary
|
||||
if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS:
|
||||
summary, tokens = await self.make_summary(self.conversation_chunk.message_data)
|
||||
new_message_log = [
|
||||
{"role": "summary", "content": summary, "tokens": tokens}
|
||||
]
|
||||
|
||||
self.conversation_chunk = await self.conversation_chunk_helper.add(conversation_id, new_message_log, tokens)
|
||||
|
||||
delta_data["conversation_chunk_id"] = self.conversation_chunk.id
|
||||
|
||||
message_log = []
|
||||
for message in self.conversation_chunk.message_data:
|
||||
message_log.append({
|
||||
"role": message["role"],
|
||||
"content": message["content"],
|
||||
})
|
||||
|
||||
if question_tokens is None:
|
||||
question_tokens = await self.get_question_tokens(question)
|
||||
if (len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS and
|
||||
question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS):
|
||||
# If the question is too long, we need to truncate it
|
||||
raise web.HTTPRequestEntityTooLarge()
|
||||
|
||||
extract_doc = None
|
||||
if embedding_search is not None:
|
||||
extract_doc, token_usage = await self.embedding_search.search(question, **embedding_search)
|
||||
if extract_doc is not None:
|
||||
if on_extracted_doc is not None:
|
||||
await on_extracted_doc(extract_doc)
|
||||
|
||||
question_tokens = token_usage
|
||||
doc_prompt_content = "\n".join(["%d. %s" % (
|
||||
i + 1, doc["markdown"] or doc["text"]) for i, doc in enumerate(extract_doc)])
|
||||
|
||||
doc_prompt = utils.config.get_prompt("extracted_doc", "prompt", {
|
||||
"content": doc_prompt_content})
|
||||
message_log.append({"role": "user", "content": doc_prompt})
|
||||
|
||||
system_prompt = utils.config.get_prompt("chat", "system_prompt")
|
||||
|
||||
# Start chat complete
|
||||
if on_message is not None:
|
||||
response = await self.openai_api.chat_complete_stream(question, system_prompt, message_log, on_message)
|
||||
else:
|
||||
response = await self.openai_api.chat_complete(question, system_prompt, message_log)
|
||||
|
||||
if self.conversation_info is None:
|
||||
# Create a new conversation
|
||||
message_log_list = [
|
||||
{"role": "user", "content": question, "tokens": question_tokens},
|
||||
{"role": "assistant",
|
||||
"content": response["message"], "tokens": response["message_tokens"]},
|
||||
]
|
||||
title = None
|
||||
try:
|
||||
title, token_usage = await self.make_title(message_log_list)
|
||||
delta_data["title"] = title
|
||||
except Exception as e:
|
||||
title = config.CHATCOMPLETE_DEFAULT_CONVERSATION_TITLE
|
||||
print(str(e), file=sys.stderr)
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
|
||||
total_token_usage = question_tokens + response["message_tokens"]
|
||||
|
||||
title_info = self.embedding_search.title_info
|
||||
self.conversation_info = await self.conversation_helper.add(user_id, "chatcomplete", page_id=title_info["page_id"], rev_id=title_info["rev_id"], title=title)
|
||||
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_info.id, message_log_list, total_token_usage)
|
||||
else:
|
||||
# Update the conversation chunk
|
||||
await self.conversation_helper.refresh_updated_at(conversation_id)
|
||||
|
||||
self.conversation_chunk.message_data.append(
|
||||
{"role": "user", "content": question, "tokens": question_tokens})
|
||||
self.conversation_chunk.message_data.append(
|
||||
{"role": "assistant", "content": response["message"], "tokens": response["message_tokens"]})
|
||||
flag_modified(self.conversation_chunk, "message_data")
|
||||
self.conversation_chunk.tokens += question_tokens + \
|
||||
response["message_tokens"]
|
||||
|
||||
await self.conversation_chunk_helper.update(self.conversation_chunk)
|
||||
|
||||
return ChatCompleteServiceResponse(
|
||||
message=response["message"],
|
||||
message_tokens=response["message_tokens"],
|
||||
total_tokens=response["total_tokens"],
|
||||
finish_reason=response["finish_reason"],
|
||||
conversation_id=self.conversation_info.id,
|
||||
delta_data=delta_data
|
||||
)
|
||||
|
||||
async def set_latest_point_cost(self, point_cost: int) -> bool:
|
||||
if self.conversation_chunk is None:
|
||||
return False
|
||||
|
||||
if len(self.conversation_chunk.message_data) == 0:
|
||||
return False
|
||||
|
||||
for i in range(len(self.conversation_chunk.message_data) - 1, -1, -1):
|
||||
if self.conversation_chunk.message_data[i]["role"] == "assistant":
|
||||
self.conversation_chunk.message_data[i]["point_cost"] = point_cost
|
||||
flag_modified(self.conversation_chunk, "message_data")
|
||||
await self.conversation_chunk_helper.update(self.conversation_chunk)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def make_summary(self, message_log_list: list) -> tuple[str, int]:
|
||||
chat_log: list[str] = []
|
||||
|
||||
for message_data in message_log_list:
|
||||
if message_data["role"] == 'summary':
|
||||
chat_log.append(message_data["content"])
|
||||
elif message_data["role"] == 'assistant':
|
||||
chat_log.append(
|
||||
f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}')
|
||||
else:
|
||||
chat_log.append(f'User: {message_data["content"]}')
|
||||
|
||||
chat_log_str = '\n'.join(chat_log)
|
||||
|
||||
summary_system_prompt = utils.config.get_prompt(
|
||||
"summary", "system_prompt")
|
||||
summary_prompt = utils.config.get_prompt(
|
||||
"summary", "prompt", {"content": chat_log_str})
|
||||
|
||||
response = await self.openai_api.chat_complete(summary_prompt, summary_system_prompt)
|
||||
|
||||
return response["message"], response["message_tokens"]
|
||||
|
||||
async def make_title(self, message_log_list: list) -> tuple[str, int]:
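        # Same idea as make_summary, but uses the "title" prompts to generate a short
        # conversation title; returns (title_text, completion_tokens).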
|
||||
chat_log: list[str] = []
|
||||
for message_data in message_log_list:
|
||||
if message_data["role"] == 'assistant':
|
||||
chat_log.append(
|
||||
f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}')
|
||||
elif message_data["role"] == 'user':
|
||||
chat_log.append(f'User: {message_data["content"]}')
|
||||
|
||||
chat_log_str = '\n'.join(chat_log)
|
||||
|
||||
title_system_prompt = utils.config.get_prompt("title", "system_prompt")
|
||||
title_prompt = utils.config.get_prompt(
|
||||
"title", "prompt", {"content": chat_log_str})
|
||||
|
||||
response = await self.openai_api.chat_complete(title_prompt, title_system_prompt)
|
||||
return response["message"], response["message_tokens"]
|
@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from urllib.parse import quote_plus
|
||||
from aiohttp import web
|
||||
import asyncpg
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||||
import config
|
||||
|
||||
def get_dsn():
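    # Build the SQLAlchemy asyncpg DSN from config.DATABASE, URL-escaping the credentials.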
|
||||
return "postgresql+asyncpg://%s:%s@%s:%s/%s" % (
|
||||
quote_plus(config.DATABASE["user"]),
|
||||
quote_plus(config.DATABASE["password"]),
|
||||
config.DATABASE["host"],
|
||||
config.DATABASE["port"],
|
||||
quote_plus(config.DATABASE["database"]))
|
||||
|
||||
class DatabaseService:
|
||||
instance = None
|
||||
|
||||
@staticmethod
|
||||
async def create(app: web.Application = None) -> DatabaseService:
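        # Singleton factory: without an aiohttp app a process-wide instance is reused,
        # otherwise the instance is created once and cached on app["database"].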
|
||||
if app is None:
|
||||
if DatabaseService.instance is None:
|
||||
DatabaseService.instance = DatabaseService()
|
||||
await DatabaseService.instance.init()
|
||||
return DatabaseService.instance
|
||||
else:
|
||||
if "database" not in app:
|
||||
instance = DatabaseService()
|
||||
await instance.init()
|
||||
app["database"] = instance
|
||||
|
||||
return app["database"]
|
||||
|
||||
def __init__(self):
|
||||
self.pool: asyncpg.pool.Pool = None
|
||||
self.engine: AsyncEngine = None
|
||||
self.create_session: async_sessionmaker[AsyncSession] = None
|
||||
|
||||
async def init(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
self.pool = asyncpg.create_pool(**config.DATABASE, loop=loop)
|
||||
await self.pool.__aenter__()
|
||||
|
||||
engine = create_async_engine(get_dsn(), echo=config.DEBUG)
|
||||
self.engine = engine
|
||||
self.create_session = async_sessionmaker(engine, expire_on_commit=False)
|
@ -0,0 +1,226 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, TypedDict
|
||||
from api.model.embedding_search.title_collection import TitleCollectionHelper, TitleCollectionModel
|
||||
from api.model.embedding_search.title_index import TitleIndexHelper
|
||||
from api.model.embedding_search.page_index import PageIndexHelper
|
||||
from service.database import DatabaseService
|
||||
from service.mediawiki_api import MediaWikiApi
|
||||
from service.openai_api import OpenAIApi
|
||||
from service.tiktoken import TikTokenService
|
||||
from utils.wiki import getWikiSentences
|
||||
|
||||
class EmbeddingSearchArgs(TypedDict):
|
||||
limit: Optional[int]
|
||||
in_collection: Optional[bool]
|
||||
distance_limit: Optional[float]
|
||||
|
||||
class EmbeddingSearchService:
|
||||
def __init__(self, dbs: DatabaseService, title: str):
|
||||
self.dbs = dbs
|
||||
|
||||
self.title = title
|
||||
self.base_title = title.split("/")[0]
|
||||
|
||||
self.title_index = TitleIndexHelper(dbs)
|
||||
self.title_collection = TitleCollectionHelper(dbs)
|
||||
self.page_index: PageIndexHelper = None
|
||||
|
||||
self.tiktoken: TikTokenService = None
|
||||
|
||||
self.mwapi = MediaWikiApi.create()
|
||||
self.openai_api = OpenAIApi.create()
|
||||
|
||||
self.page_id: int = None
|
||||
self.collection_id: int = None
|
||||
|
||||
self.title_info: dict = None
|
||||
self.collection_info: TitleCollectionModel = None
|
||||
|
||||
self.page_info: dict = None
|
||||
self.unindexed_docs: list = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.tiktoken = await TikTokenService.create()
|
||||
|
||||
await self.title_index.__aenter__()
|
||||
await self.title_collection.__aenter__()
|
||||
|
||||
self.title_info = await self.title_index.find_by_title(self.title)
|
||||
if self.title_info is not None:
|
||||
self.page_id = self.title_info["page_id"]
|
||||
self.collection_id = self.title_info["collection_id"]
|
||||
|
||||
self.page_index = PageIndexHelper(
|
||||
self.dbs, self.collection_id, self.page_id)
|
||||
await self.page_index.__aenter__()
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.title_index.__aexit__(exc_type, exc, tb)
|
||||
await self.title_collection.__aexit__(exc_type, exc, tb)
|
||||
|
||||
if self.page_index is not None:
|
||||
await self.page_index.__aexit__(exc_type, exc, tb)
|
||||
|
||||
async def page_index_exists(self, check_table = True):
|
||||
if check_table:
|
||||
return self.page_index and await self.page_index.table_exists()
|
||||
else:
|
||||
return self.page_index is not None
|
||||
|
||||
async def load_page_info(self, reload=False):
|
||||
if self.page_info is None or reload:
|
||||
self.page_info = await self.mwapi.get_page_info(self.title)
|
||||
|
||||
async def should_update_page_index(self):
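        # The index is considered up to date only when the stored title and rev_id both
        # match the live page info fetched from MediaWiki.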
|
||||
await self.load_page_info()
|
||||
|
||||
if (self.title_info is not None and await self.page_index_exists() and
|
||||
self.title_info["title"] == self.page_info["title"] and self.title_info["rev_id"] == self.page_info["lastrevid"]):
|
||||
# Not changed
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def prepare_update_index(self):
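        # Ensure the title collection and per-page index table exist, parse the page,
        # split it into sentences and collect the ones that still need embeddings.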
|
||||
# Check rev_id
|
||||
await self.load_page_info()
|
||||
|
||||
if not await self.should_update_page_index():
|
||||
return False
|
||||
|
||||
self.page_id = self.page_info["pageid"]
|
||||
|
||||
# Create collection
|
||||
self.collection_info = await self.title_collection.find_by_title(self.base_title)
|
||||
if self.collection_info is None:
|
||||
self.collection_id = await self.title_collection.add(self.base_title)
|
||||
if self.collection_id is None:
|
||||
raise Exception("Failed to create title collection")
|
||||
else:
|
||||
self.collection_id = self.collection_info.id
|
||||
|
||||
self.page_index = PageIndexHelper(
|
||||
self.dbs, self.collection_id, self.page_id)
|
||||
await self.page_index.__aenter__()
|
||||
await self.page_index.init_table()
|
||||
|
||||
page_content = await self.mwapi.parse_page(self.title)
|
||||
|
||||
self.sentences = getWikiSentences(page_content)
|
||||
|
||||
self.unindexed_docs = await self.page_index.get_unindexed_doc(self.sentences, with_temporary=False)
|
||||
|
||||
return True
|
||||
|
||||
async def get_unindexed_tokens(self):
|
||||
if self.unindexed_docs is None:
|
||||
return 0
|
||||
else:
|
||||
tokens = 0
|
||||
for doc in self.unindexed_docs:
|
||||
if "text" in doc:
|
||||
tokens += await self.tiktoken.get_tokens(doc["text"])
|
||||
|
||||
return tokens
|
||||
|
||||
async def update_page_index(self, on_progress=None):
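        # Embed the unindexed sentences in batches, refresh the title embedding and
        # rev_id, and return the total number of OpenAI tokens consumed.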
|
||||
if self.unindexed_docs is None:
|
||||
return False
|
||||
|
||||
total_token_usage = 0
|
||||
|
||||
async def embedding_doc(doc_chunk):
|
||||
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk)
|
||||
await self.page_index.index_doc(doc_chunk)
|
||||
|
||||
return token_usage
|
||||
|
||||
if len(self.unindexed_docs) > 0:
|
||||
if on_progress is not None:
|
||||
await on_progress(0, len(self.unindexed_docs))
|
||||
|
||||
chunk_limit = 500
|
||||
|
||||
chunk_len = 0
|
||||
processed_len = 0
|
||||
doc_chunk = []
|
||||
for doc in self.unindexed_docs:
|
||||
chunk_len += len(doc)
|
||||
|
||||
if chunk_len > chunk_limit:
|
||||
total_token_usage += await embedding_doc(doc_chunk)
|
||||
processed_len += len(doc_chunk)
|
||||
if on_progress is not None:
|
||||
await on_progress(processed_len, len(self.unindexed_docs))
|
||||
|
||||
doc_chunk = []
|
||||
chunk_len = len(doc)
|
||||
|
||||
doc_chunk.append(doc)
|
||||
|
||||
if len(doc_chunk) > 0:
|
||||
total_token_usage += await embedding_doc(doc_chunk)
|
||||
if on_progress is not None:
|
||||
await on_progress(len(self.unindexed_docs), len(self.unindexed_docs))
|
||||
|
||||
await self.page_index.remove_outdated_doc(self.sentences)
|
||||
|
||||
# Update database
|
||||
if self.title_info is None:
|
||||
doc_chunk = [{"text": self.title}]
|
||||
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk)
|
||||
total_token_usage += token_usage
|
||||
|
||||
embedding = doc_chunk[0]["embedding"]
|
||||
|
||||
await self.title_index.add(self.page_info["title"],
|
||||
self.page_id,
|
||||
self.page_info["lastrevid"],
|
||||
self.collection_id,
|
||||
embedding)
|
||||
else:
|
||||
if self.title != self.page_info["title"]:
|
||||
self.title = self.page_info["title"]
|
||||
|
||||
doc_chunk = [{"text": self.title}]
|
||||
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk)
|
||||
total_token_usage += token_usage
|
||||
|
||||
embedding = doc_chunk[0]["embedding"]
|
||||
|
||||
await self.title_index.update_title_data(self.page_id,
|
||||
self.title,
|
||||
self.page_info["lastrevid"],
|
||||
self.collection_id,
|
||||
embedding)
|
||||
else:
|
||||
await self.title_index.update_rev_id(self.title, self.page_info["lastrevid"])
|
||||
|
||||
if (self.collection_info is None or
|
||||
(self.base_title == self.collection_info.title and self.page_id != self.collection_info.page_id)):
|
||||
await self.title_collection.set_page_id(self.base_title, self.page_id)
|
||||
|
||||
return total_token_usage
|
||||
|
||||
async def search(self, query: str, limit: int = 10, in_collection: bool = False, distance_limit: float = 0.6):
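        # Embed the query and run a vector search over the page index, keeping only
        # results closer than distance_limit; returns (results, token_usage).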
|
||||
if self.page_index is None:
|
||||
raise Exception("Page index is not initialized")
|
||||
|
||||
query_doc = [{"text": query}]
|
||||
query_doc, token_usage = await self.openai_api.get_embeddings(query_doc)
|
||||
query_embedding = query_doc[0]["embedding"]
|
||||
|
||||
if query_embedding is None:
|
||||
return [], token_usage
|
||||
|
||||
res = await self.page_index.search_text_embedding(query_embedding, in_collection, limit)
|
||||
if res:
|
||||
filtered = []
|
||||
for one in res:
|
||||
if one["distance"] < distance_limit:
|
||||
filtered.append(dict(one))
|
||||
return filtered, token_usage
|
||||
else:
|
||||
return res, token_usage
|
@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from event_emitter_asyncio.EventEmitter import EventEmitter
|
||||
|
||||
class EventService(EventEmitter):
|
||||
instance: EventService = None
|
||||
|
||||
@staticmethod
|
||||
def create() -> EventService:
|
||||
if EventService.instance is None:
|
||||
EventService.instance = EventService()
|
||||
return EventService.instance
|
@ -0,0 +1,242 @@
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional
|
||||
import aiohttp
|
||||
import config
|
||||
|
||||
class MediaWikiApiException(Exception):
|
||||
def __init__(self, info: str, code: Optional[str] = None) -> None:
|
||||
super().__init__(info)
|
||||
self.info = info
|
||||
self.code = code
|
||||
self.message = self.info
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.info
|
||||
|
||||
class MediaWikiPageNotFoundException(MediaWikiApiException):
|
||||
pass
|
||||
|
||||
class MediaWikiApi:
|
||||
cookie_jar = aiohttp.CookieJar(unsafe=True)
|
||||
|
||||
@staticmethod
|
||||
def create():
|
||||
return MediaWikiApi(config.MW_API)
|
||||
|
||||
def __init__(self, api_url: str):
|
||||
self.api_url = api_url
|
||||
self.login_time = 0
|
||||
self.login_identity = None
|
||||
|
||||
async def get_page_info(self, title: str):
|
||||
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
|
||||
params = {
|
||||
"action": "query",
|
||||
"format": "json",
|
||||
"formatversion": "2",
|
||||
"prop": "info",
|
||||
"titles": title,
|
||||
"inprop": "url"
|
||||
}
|
||||
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
|
||||
data = await resp.json()
|
||||
if "error" in data:
|
||||
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
|
||||
|
||||
if "missing" in data["query"]["pages"][0]:
|
||||
raise MediaWikiPageNotFoundException()
|
||||
|
||||
return data["query"]["pages"][0]
|
||||
|
||||
async def parse_page(self, title: str):
|
||||
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
|
||||
params = {
|
||||
"action": "parse",
|
||||
"format": "json",
|
||||
"formatversion": "2",
|
||||
"prop": "text",
|
||||
"page": title,
|
||||
"disableeditsection": "true",
|
||||
"disabletoc": "true",
|
||||
"disablelimitreport": "true",
|
||||
}
|
||||
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
|
||||
data = await resp.json()
|
||||
if "error" in data:
|
||||
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
|
||||
|
||||
return data["parse"]["text"]
|
||||
|
||||
async def get_site_meta(self):
|
||||
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
|
||||
params = {
|
||||
"action": "query",
|
||||
"format": "json",
|
||||
"formatversion": "2",
|
||||
"meta": "siteinfo|userinfo",
|
||||
"siprop": "general"
|
||||
}
|
||||
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
|
||||
data = await resp.json()
|
||||
if "error" in data:
|
||||
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
|
||||
|
||||
ret = {
|
||||
"sitename": "Unknown",
|
||||
"user": "Anonymous",
|
||||
}
|
||||
if "query" in data:
|
||||
if "general" in data["query"]:
|
||||
ret["sitename"] = data["query"]["general"]["sitename"]
|
||||
if "userinfo" in data["query"]:
|
||||
ret["user"] = data["query"]["userinfo"]["name"]
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
async def get_token(self, token_type: str):
|
||||
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
|
||||
params = {
|
||||
"action": "query",
|
||||
"format": "json",
|
||||
"formatversion": "2",
|
||||
"meta": "tokens",
|
||||
"type": token_type
|
||||
}
|
||||
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
|
||||
data = await resp.json()
|
||||
if "error" in data:
|
||||
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
|
||||
|
||||
return data["query"]["tokens"][token_type + "token"]
|
||||
|
||||
async def robot_login(self, username: str, password: str):
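        # Fetch a login token and perform a bot login; the credentials are kept so
        # refresh_login() can re-authenticate later.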
|
||||
token = await self.get_token("login")
|
||||
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
|
||||
post_data = {
|
||||
"action": "login",
|
||||
"format": "json",
|
||||
"formatversion": "2",
|
||||
"lgname": username,
|
||||
"lgpassword": password,
|
||||
"lgtoken": token,
|
||||
}
|
||||
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
|
||||
data = await resp.json()
|
||||
if "error" in data:
|
||||
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
|
||||
|
||||
if "result" not in data["login"] or data["login"]["result"] != "Success":
|
||||
raise MediaWikiApiException("Login failed")
|
||||
|
||||
self.login_time = time.time()
|
||||
self.login_identity = {
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
|
||||
return True
|
||||
|
||||
async def refresh_login(self):
|
||||
if self.login_identity is None:
|
||||
return False
|
||||
if time.time() - self.login_time > 10:
|
||||
return await self.robot_login(self.login_identity["username"], self.login_identity["password"])
|
||||
|
||||
async def chat_complete_user_info(self, user_id: int):
|
||||
await self.refresh_login()
|
||||
|
||||
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
|
||||
params = {
|
||||
"action": "chatcompletebot",
|
||||
"method": "userinfo",
|
||||
"userid": user_id,
|
||||
"format": "json",
|
||||
"formatversion": "2",
|
||||
}
|
||||
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
|
||||
data = await resp.json()
|
||||
if "error" in data:
|
||||
if data["error"]["code"] == "user-not-found":
|
||||
raise MediaWikiPageNotFoundException(data["error"]["info"], data["error"]["code"])
|
||||
else:
|
||||
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
|
||||
|
||||
return data["chatcompletebot"]["userinfo"]
|
||||
|
||||
async def chat_complete_start_transaction(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> str:
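        # Opens a usage-report transaction on the wiki side (step "start") and returns
        # the transaction id consumed by the end/cancel calls below.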
|
||||
await self.refresh_login()
|
||||
|
||||
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
|
||||
post_data = {
|
||||
"action": "chatcompletebot",
|
||||
"method": "reportusage",
|
||||
"step": "start",
|
||||
"userid": int(user_id),
|
||||
"useraction": user_action,
|
||||
"tokens": int(tokens),
|
||||
"extractlines": int(extractlines),
|
||||
"format": "json",
|
||||
"formatversion": "2",
|
||||
}
|
||||
# Filter out None values
|
||||
post_data = {k: v for k, v in post_data.items() if v is not None}
|
||||
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
|
||||
data = await resp.json()
|
||||
if "error" in data:
|
||||
print(data)
|
||||
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
|
||||
|
||||
return data["chatcompletebot"]["reportusage"]["transactionid"]
|
||||
|
||||
async def chat_complete_end_transaction(self, transaction_id: str, tokens: Optional[int] = None):
|
||||
await self.refresh_login()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
|
||||
post_data = {
|
||||
"action": "chatcompletebot",
|
||||
"method": "reportusage",
|
||||
"step": "end",
|
||||
"transactionid": transaction_id,
|
||||
"tokens": tokens,
|
||||
"format": "json",
|
||||
"formatversion": "2",
|
||||
}
|
||||
# Filter out None values
|
||||
post_data = {k: v for k, v in post_data.items() if v is not None}
|
||||
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
|
||||
data = await resp.json()
|
||||
if "error" in data:
|
||||
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
|
||||
|
||||
return data["chatcompletebot"]["reportusage"]["success"]
|
||||
except Exception as e:
|
||||
print(e, file=sys.stderr)
|
||||
|
||||
async def chat_complete_cancel_transaction(self, transaction_id: str, error: Optional[str] = None):
|
||||
await self.refresh_login()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
|
||||
post_data = {
|
||||
"action": "chatcompletebot",
|
||||
"method": "reportusage",
|
||||
"step": "cancel",
|
||||
"transactionid": transaction_id,
|
||||
"error": error,
|
||||
"format": "json",
|
||||
"formatversion": "2",
|
||||
}
|
||||
# Filter out None values
|
||||
post_data = {k: v for k, v in post_data.items() if v is not None}
|
||||
                async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
|
||||
data = await resp.json()
|
||||
if "error" in data:
|
||||
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
|
||||
|
||||
return data["chatcompletebot"]["reportusage"]["success"]
|
||||
except Exception as e:
|
||||
print(e, file=sys.stderr)
|
@ -0,0 +1,191 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
from typing import TypedDict
|
||||
|
||||
import aiohttp
|
||||
import config
|
||||
import numpy as np
|
||||
from aiohttp_sse_client2 import client as sse_client
|
||||
|
||||
from service.tiktoken import TikTokenService
|
||||
|
||||
class ChatCompleteMessageLog(TypedDict):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
class ChatCompleteResponse(TypedDict):
|
||||
message: str
|
||||
prompt_tokens: int
|
||||
message_tokens: int
|
||||
total_tokens: int
|
||||
finish_reason: str
|
||||
|
||||
class OpenAIApi:
|
||||
@staticmethod
|
||||
def create():
|
||||
return OpenAIApi(config.OPENAI_API or "https://api.openai.com", config.OPENAI_TOKEN)
|
||||
|
||||
def __init__(self, api_url: str, token: str):
|
||||
self.api_url = api_url
|
||||
self.token = token
|
||||
|
||||
async def get_embeddings(self, doc_list: list):
|
||||
token_usage = 0
|
||||
async with aiohttp.ClientSession() as session:
|
||||
text_list = [doc["text"] for doc in doc_list]
|
||||
params = {
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": text_list,
|
||||
}
|
||||
async with session.post(self.api_url + "/v1/embeddings",
|
||||
headers={"Authorization": f"Bearer {self.token}"},
|
||||
json=params,
|
||||
timeout=30,
|
||||
proxy=config.REQUEST_PROXY) as resp:
|
||||
|
||||
data = await resp.json()
|
||||
|
||||
for one_data in data["data"]:
|
||||
embedding = one_data["embedding"]
|
||||
index = one_data["index"]
|
||||
|
||||
if index < len(doc_list):
|
||||
if embedding is not None:
|
||||
embedding = np.array(embedding)
|
||||
doc_list[index]["embedding"] = embedding
|
||||
|
||||
token_usage = int(data["usage"]["total_tokens"])
|
||||
|
||||
return (doc_list, token_usage)
|
||||
|
||||
async def make_message_list(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = []) -> list[ChatCompleteMessageLog]:
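        # Merge any stored summary into the system prompt, keep only user/assistant
        # turns from the history, and append the new question at the end.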
|
||||
summaryContent = None
|
||||
|
||||
messageList: list[ChatCompleteMessageLog] = []
|
||||
for message in conversation:
|
||||
if message["role"] == "summary":
|
||||
summaryContent = message["content"]
|
||||
elif message["role"] == "user" or message["role"] == "assistant":
|
||||
messageList.append(message)
|
||||
|
||||
if summaryContent:
|
||||
system_prompt += "\n\n" + summaryContent
|
||||
|
||||
messageList.insert(0, ChatCompleteMessageLog(role="assistant", content=system_prompt))
|
||||
messageList.append(ChatCompleteMessageLog(role="user", content=question))
|
||||
|
||||
return messageList
|
||||
|
||||
async def chat_complete(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], user = None):
|
||||
messageList = await self.make_message_list(question, system_prompt, conversation)
|
||||
|
||||
params = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": messageList,
|
||||
"user": user,
|
||||
}
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(self.api_url + "/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {self.token}"},
|
||||
json=params,
|
||||
timeout=30,
|
||||
proxy=config.REQUEST_PROXY) as resp:
|
||||
|
||||
data = await resp.json()
|
||||
|
||||
if "choices" in data and len(data["choices"]) > 0:
|
||||
choice = data["choices"][0]
|
||||
|
||||
message = choice["message"]["content"]
|
||||
finish_reason = choice["finish_reason"]
|
||||
|
||||
prompt_tokens = int(data["usage"]["prompt_tokens"])
|
||||
message_tokens = int(data["usage"]["completion_tokens"])
|
||||
total_tokens = int(data["usage"]["total_tokens"])
|
||||
|
||||
return ChatCompleteResponse(message=message,
|
||||
prompt_tokens=prompt_tokens,
|
||||
message_tokens=message_tokens,
|
||||
total_tokens=total_tokens,
|
||||
finish_reason=finish_reason)
|
||||
|
||||
return None
|
||||
|
||||
async def chat_complete_stream(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], on_message = None, user = None):
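        # Streaming variant: counts prompt tokens locally with tiktoken, consumes the
        # SSE stream chunk by chunk and forwards each content delta to on_message.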
|
||||
tiktoken = await TikTokenService.create()
|
||||
|
||||
messageList = await self.make_message_list(question, system_prompt, conversation)
|
||||
|
||||
prompt_tokens = 0
|
||||
for message in messageList:
|
||||
prompt_tokens += await tiktoken.get_tokens(message["content"])
|
||||
|
||||
params = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": messageList,
|
||||
"stream": True,
|
||||
"user": user,
|
||||
}
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
res_message: list[str] = []
|
||||
finish_reason = None
|
||||
|
||||
async with sse_client.EventSource(
|
||||
self.api_url + "/v1/chat/completions",
|
||||
option={
|
||||
"method": "POST"
|
||||
},
|
||||
headers={"Authorization": f"Bearer {self.token}"},
|
||||
json=params,
|
||||
proxy=config.REQUEST_PROXY
|
||||
) as session:
|
||||
async for event in session:
|
||||
"""
|
||||
{"id":"something","object":"chat.completion.chunk","created":1681261845,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}
|
||||
{"id":"something","object":"chat.completion.chunk","created":1681261845,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"Thank"},"index":0,"finish_reason":null}]}
|
||||
{"id":"something","object":"chat.completion.chunk","created":1681261845,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{},"index":0,"finish_reason":"stop"}]}
|
||||
[DONE]
|
||||
"""
|
||||
content_started = False
|
||||
|
||||
if event.data == "[DONE]":
|
||||
break
|
||||
elif event.data[0] == "{" and event.data[-1] == "}":
|
||||
data = json.loads(event.data)
|
||||
if "choices" in data and len(data["choices"]) > 0:
|
||||
choice = data["choices"][0]
|
||||
|
||||
if choice["finish_reason"] is not None:
|
||||
finish_reason = choice["finish_reason"]
|
||||
|
||||
delta_content = choice["delta"]
|
||||
if "content" in delta_content:
|
||||
delta_message: str = delta_content["content"]
|
||||
|
||||
# Skip empty lines before content
|
||||
if not content_started:
|
||||
if delta_message.replace("\n", "") == "":
|
||||
continue
|
||||
else:
|
||||
content_started = True
|
||||
|
||||
res_message.append(delta_message)
|
||||
|
||||
if config.DEBUG:
|
||||
print(delta_message, end="", flush=True)
|
||||
|
||||
if on_message is not None:
|
||||
await on_message(delta_message)
|
||||
|
||||
res_message_str = "".join(res_message)
|
||||
message_tokens = await tiktoken.get_tokens(res_message_str)
|
||||
total_tokens = prompt_tokens + message_tokens
|
||||
|
||||
return ChatCompleteResponse(message=res_message_str,
|
||||
prompt_tokens=prompt_tokens,
|
||||
message_tokens=message_tokens,
|
||||
total_tokens=total_tokens,
|
||||
finish_reason=finish_reason)
|
@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
from aiohttp import web
|
||||
import tiktoken_async
|
||||
|
||||
class TikTokenService:
|
||||
instance = None
|
||||
|
||||
@staticmethod
|
||||
async def create() -> TikTokenService:
|
||||
if TikTokenService.instance is None:
|
||||
TikTokenService.instance = TikTokenService()
|
||||
await TikTokenService.instance.init()
|
||||
return TikTokenService.instance
|
||||
|
||||
def __init__(self):
|
||||
self.enc: tiktoken_async.Encoding = None
|
||||
|
||||
async def init(self):
|
||||
self.enc = await tiktoken_async.encoding_for_model("gpt-3.5-turbo")
|
||||
|
||||
async def get_tokens(self, text: str):
|
||||
encoded = self.enc.encode(text)
|
||||
if encoded:
|
||||
return len(encoded)
|
||||
else:
|
||||
return 0
|
@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
import asyncpg
|
||||
|
||||
class SimpleQueryBuilder:
|
||||
def __init__(self):
|
||||
self._table_name = ""
|
||||
self._select = ["*"]
|
||||
self._where = []
|
||||
self._having = []
|
||||
self._order_by = None
|
||||
self._order_by_desc = False
|
||||
|
||||
def table(self, table_name: str):
|
||||
self._table_name = table_name
|
||||
return self
|
||||
|
||||
def fields(self, fields: list[str]):
|
||||
        self._select = fields
|
||||
return self
|
||||
|
||||
def where(self, where: str, condition: str, param):
|
||||
self._where.append((where, condition, param))
|
||||
return self
|
||||
|
||||
def having(self, having: str, condition: str, param):
|
||||
self._having.append((having, condition, param))
|
||||
return self
|
||||
|
||||
    def build(self):
        sql = "SELECT %s FROM %s" % (", ".join(self._select), self._table_name)

        params = []
        paramsLen = 0

        if len(self._where) > 0:
            conditions = []
            for where, condition, param in self._where:
                params.append(param)
                paramsLen += 1
                conditions.append("%s %s $%d" % (where, condition, paramsLen))
            sql += " WHERE " + " AND ".join(conditions)

        if len(self._having) > 0:
            conditions = []
            for having, condition, param in self._having:
                params.append(param)
                paramsLen += 1
                conditions.append("%s %s $%d" % (having, condition, paramsLen))
            sql += " HAVING " + " AND ".join(conditions)

        if self._order_by is not None:
            sql += " ORDER BY %s %s" % (self._order_by, "DESC" if self._order_by_desc else "ASC")

        return sql, params
|
@ -0,0 +1,443 @@
|
||||
--
|
||||
-- PostgreSQL database dump
|
||||
--
|
||||
|
||||
-- Dumped from database version 15.2 (Ubuntu 15.2-1.pgdg20.04+1)
|
||||
-- Dumped by pg_dump version 15.2 (Ubuntu 15.2-1.pgdg20.04+1)
|
||||
|
||||
SET statement_timeout = 0;
|
||||
SET lock_timeout = 0;
|
||||
SET idle_in_transaction_session_timeout = 0;
|
||||
SET client_encoding = 'UTF8';
|
||||
SET standard_conforming_strings = on;
|
||||
SELECT pg_catalog.set_config('search_path', '', false);
|
||||
SET check_function_bodies = false;
|
||||
SET xmloption = content;
|
||||
SET client_min_messages = warning;
|
||||
SET row_security = off;
|
||||
|
||||
--
|
||||
-- Name: vector; Type: EXTENSION; Schema: -; Owner: -
|
||||
--
|
||||
|
||||
CREATE EXTENSION IF NOT EXISTS vector WITH SCHEMA public;
|
||||
|
||||
|
||||
--
|
||||
-- Name: EXTENSION vector; Type: COMMENT; Schema: -; Owner:
|
||||
--
|
||||
|
||||
COMMENT ON EXTENSION vector IS 'vector data type and ivfflat access method';
|
||||
|
||||
|
||||
SET default_tablespace = '';
|
||||
|
||||
SET default_table_access_method = heap;
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation; Type: TABLE; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE TABLE public.chat_complete_conversation (
|
||||
id integer NOT NULL,
|
||||
user_id integer NOT NULL,
|
||||
title character varying(255) DEFAULT ''::character varying NOT NULL,
|
||||
updated_at timestamp without time zone NOT NULL,
|
||||
pinned boolean DEFAULT false NOT NULL,
|
||||
rev_id bigint NOT NULL
|
||||
);
|
||||
|
||||
|
||||
ALTER TABLE public.chat_complete_conversation OWNER TO hyperzlib;
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_chunk; Type: TABLE; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE TABLE public.chat_complete_conversation_chunk (
|
||||
id integer NOT NULL,
|
||||
conversation_id bigint NOT NULL,
|
||||
message_data text,
|
||||
updated_at timestamp without time zone NOT NULL
|
||||
);
|
||||
|
||||
|
||||
ALTER TABLE public.chat_complete_conversation_chunk OWNER TO hyperzlib;
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_chunk_conversation_id_seq; Type: SEQUENCE; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE SEQUENCE public.chat_complete_conversation_chunk_conversation_id_seq
|
||||
AS integer
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
|
||||
ALTER TABLE public.chat_complete_conversation_chunk_conversation_id_seq OWNER TO hyperzlib;
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_chunk_conversation_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER SEQUENCE public.chat_complete_conversation_chunk_conversation_id_seq OWNED BY public.chat_complete_conversation_chunk.conversation_id;
|
||||
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_chunk_id_seq; Type: SEQUENCE; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE SEQUENCE public.chat_complete_conversation_chunk_id_seq
|
||||
AS integer
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
|
||||
ALTER TABLE public.chat_complete_conversation_chunk_id_seq OWNER TO hyperzlib;
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_chunk_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER SEQUENCE public.chat_complete_conversation_chunk_id_seq OWNED BY public.chat_complete_conversation_chunk.id;
|
||||
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_id_seq; Type: SEQUENCE; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE SEQUENCE public.chat_complete_conversation_id_seq
|
||||
AS integer
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
|
||||
ALTER TABLE public.chat_complete_conversation_id_seq OWNER TO hyperzlib;
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER SEQUENCE public.chat_complete_conversation_id_seq OWNED BY public.chat_complete_conversation.id;
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_page_index; Type: TABLE; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE TABLE public.embedding_search_page_index (
|
||||
id integer NOT NULL,
|
||||
page_id bigint NOT NULL,
|
||||
sha1 character varying(40) NOT NULL,
|
||||
text text NOT NULL,
|
||||
text_len integer NOT NULL,
|
||||
markdown text,
|
||||
markdown_len integer,
|
||||
embedding public.vector(1536) NOT NULL,
|
||||
temp_doc_session_id bigint
|
||||
);
|
||||
|
||||
|
||||
ALTER TABLE public.embedding_search_page_index OWNER TO hyperzlib;
|
||||
|
||||
--
|
||||
-- Name: embedding_search_page_index_id_seq; Type: SEQUENCE; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE SEQUENCE public.embedding_search_page_index_id_seq
|
||||
AS integer
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
|
||||
ALTER TABLE public.embedding_search_page_index_id_seq OWNER TO hyperzlib;
|
||||
|
||||
--
|
||||
-- Name: embedding_search_page_index_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER SEQUENCE public.embedding_search_page_index_id_seq OWNED BY public.embedding_search_page_index.id;
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_page_index_temp_doc_session_id_seq; Type: SEQUENCE; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE SEQUENCE public.embedding_search_page_index_temp_doc_session_id_seq
|
||||
AS integer
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
|
||||
ALTER TABLE public.embedding_search_page_index_temp_doc_session_id_seq OWNER TO hyperzlib;
|
||||
|
||||
--
|
||||
-- Name: embedding_search_page_index_temp_doc_session_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER SEQUENCE public.embedding_search_page_index_temp_doc_session_id_seq OWNED BY public.embedding_search_page_index.temp_doc_session_id;
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_temp_doc_session; Type: TABLE; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE TABLE public.embedding_search_temp_doc_session (
|
||||
id integer NOT NULL,
|
||||
user_id bigint NOT NULL,
|
||||
expired_at timestamp without time zone NOT NULL
|
||||
);
|
||||
|
||||
|
||||
ALTER TABLE public.embedding_search_temp_doc_session OWNER TO hyperzlib;
|
||||
|
||||
--
|
||||
-- Name: embedding_search_temp_doc_session_id_seq; Type: SEQUENCE; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE SEQUENCE public.embedding_search_temp_doc_session_id_seq
|
||||
AS integer
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
|
||||
ALTER TABLE public.embedding_search_temp_doc_session_id_seq OWNER TO hyperzlib;
|
||||
|
||||
--
|
||||
-- Name: embedding_search_temp_doc_session_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER SEQUENCE public.embedding_search_temp_doc_session_id_seq OWNED BY public.embedding_search_temp_doc_session.id;
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_title_index; Type: TABLE; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE TABLE public.embedding_search_title_index (
|
||||
id integer NOT NULL,
|
||||
sha1 character varying(40) NOT NULL,
|
||||
title text NOT NULL,
|
||||
rev_id bigint NOT NULL,
|
||||
embedding public.vector(1536) NOT NULL,
|
||||
page_id bigint NOT NULL,
|
||||
parent_page_id bigint
|
||||
);
|
||||
|
||||
|
||||
ALTER TABLE public.embedding_search_title_index OWNER TO hyperzlib;
|
||||
|
||||
--
|
||||
-- Name: embedding_search_title_index_id_seq; Type: SEQUENCE; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE SEQUENCE public.embedding_search_title_index_id_seq
|
||||
AS integer
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
|
||||
ALTER TABLE public.embedding_search_title_index_id_seq OWNER TO hyperzlib;
|
||||
|
||||
--
|
||||
-- Name: embedding_search_title_index_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER SEQUENCE public.embedding_search_title_index_id_seq OWNED BY public.embedding_search_title_index.id;
|
||||
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation id; Type: DEFAULT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.chat_complete_conversation ALTER COLUMN id SET DEFAULT nextval('public.chat_complete_conversation_id_seq'::regclass);
|
||||
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_chunk id; Type: DEFAULT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.chat_complete_conversation_chunk ALTER COLUMN id SET DEFAULT nextval('public.chat_complete_conversation_chunk_id_seq'::regclass);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_page_index id; Type: DEFAULT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.embedding_search_page_index ALTER COLUMN id SET DEFAULT nextval('public.embedding_search_page_index_id_seq'::regclass);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_temp_doc_session id; Type: DEFAULT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.embedding_search_temp_doc_session ALTER COLUMN id SET DEFAULT nextval('public.embedding_search_temp_doc_session_id_seq'::regclass);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_title_index id; Type: DEFAULT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.embedding_search_title_index ALTER COLUMN id SET DEFAULT nextval('public.embedding_search_title_index_id_seq'::regclass);
|
||||
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_chunk chat_complete_conversation_chunk_pk; Type: CONSTRAINT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.chat_complete_conversation_chunk
|
||||
ADD CONSTRAINT chat_complete_conversation_chunk_pk PRIMARY KEY (id);
|
||||
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation chat_complete_conversation_pkey; Type: CONSTRAINT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.chat_complete_conversation
|
||||
ADD CONSTRAINT chat_complete_conversation_pkey PRIMARY KEY (id);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_page_index embedding_search_page_index_pkey; Type: CONSTRAINT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.embedding_search_page_index
|
||||
ADD CONSTRAINT embedding_search_page_index_pkey PRIMARY KEY (id);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_temp_doc_session embedding_search_temp_doc_session_pkey; Type: CONSTRAINT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.embedding_search_temp_doc_session
|
||||
ADD CONSTRAINT embedding_search_temp_doc_session_pkey PRIMARY KEY (id);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_title_index embedding_search_title_index_pkey; Type: CONSTRAINT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.embedding_search_title_index
|
||||
ADD CONSTRAINT embedding_search_title_index_pkey PRIMARY KEY (id);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_title_index embedding_search_title_index_sha1_key; Type: CONSTRAINT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.embedding_search_title_index
|
||||
ADD CONSTRAINT embedding_search_title_index_sha1_key UNIQUE (sha1);
|
||||
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_chunk_updated_at_idx; Type: INDEX; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE INDEX chat_complete_conversation_chunk_updated_at_idx ON public.chat_complete_conversation_chunk USING btree (updated_at);
|
||||
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_pinned_idx; Type: INDEX; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE INDEX chat_complete_conversation_pinned_idx ON public.chat_complete_conversation USING btree (pinned);
|
||||
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_updated_at_idx; Type: INDEX; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE INDEX chat_complete_conversation_updated_at_idx ON public.chat_complete_conversation USING btree (updated_at);
|
||||
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_user_id_idx; Type: INDEX; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE INDEX chat_complete_conversation_user_id_idx ON public.chat_complete_conversation USING btree (user_id);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_page_index_embedding_idx; Type: INDEX; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE INDEX embedding_search_page_index_embedding_idx ON public.embedding_search_page_index USING ivfflat (embedding public.vector_cosine_ops);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_page_index_temp_doc_session_id_idx; Type: INDEX; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE INDEX embedding_search_page_index_temp_doc_session_id_idx ON public.embedding_search_page_index USING btree (temp_doc_session_id);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_title_index_embedding_idx; Type: INDEX; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE INDEX embedding_search_title_index_embedding_idx ON public.embedding_search_title_index USING ivfflat (embedding public.vector_cosine_ops);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_title_index_page_id_idx; Type: INDEX; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE INDEX embedding_search_title_index_page_id_idx ON public.embedding_search_title_index USING btree (page_id);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_title_index_parent_page_id_idx; Type: INDEX; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
CREATE INDEX embedding_search_title_index_parent_page_id_idx ON public.embedding_search_title_index USING btree (parent_page_id);
|
||||
|
||||
|
||||
--
|
||||
-- Name: chat_complete_conversation_chunk chat_complete_conversation_chunk_fk; Type: FK CONSTRAINT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.chat_complete_conversation_chunk
|
||||
ADD CONSTRAINT chat_complete_conversation_chunk_fk FOREIGN KEY (conversation_id) REFERENCES public.chat_complete_conversation(id);
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_page_index embedding_search_page_index_fk; Type: FK CONSTRAINT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.embedding_search_page_index
|
||||
ADD CONSTRAINT embedding_search_page_index_fk FOREIGN KEY (page_id) REFERENCES public.embedding_search_page_index(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
--
|
||||
-- Name: embedding_search_page_index embedding_search_page_index_fk_1; Type: FK CONSTRAINT; Schema: public; Owner: hyperzlib
|
||||
--
|
||||
|
||||
ALTER TABLE ONLY public.embedding_search_page_index
|
||||
ADD CONSTRAINT embedding_search_page_index_fk_1 FOREIGN KEY (temp_doc_session_id) REFERENCES public.embedding_search_temp_doc_session(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
--
|
||||
-- PostgreSQL database dump complete
|
||||
--
|
||||
|
@ -0,0 +1,17 @@
|
||||
let params = {
|
||||
title: '灵能世界',
|
||||
question: '写一段关于方清辉的介绍',
|
||||
token: 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImlzZWthaXdpa2kifQ.eyJpc3MiOiJtd2NoYXRjb21wbGV0ZSIsInN1YiI6MSwibmFtZSI6Ikh5cGVyemxpYiIsImlhdCI6MTY4MTQ1Mjk2NiwiZXhwIjoxNjgxNTM5MzY2fQ.U0yBb8Qw9WXAe2PzfRbgWdQPH62xLqbwet7Jev0VcZ4'
|
||||
}
|
||||
|
||||
let ws = new WebSocket('ws://localhost:8144/chatcomplete/message?' + new URLSearchParams(params));
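// Streams chatcomplete output over WebSocket; "output" events carry the generated text.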
|
||||
|
||||
ws.addEventListener('message', (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data?.event === 'output') {
|
||||
console.log(data.text);
|
||||
} else {
|
||||
console.log(data)
|
||||
}
|
||||
});
|
||||
ws.addEventListener('error', console.log);
|
@ -0,0 +1 @@
|
||||
__all__ = ["config", "text", "web", "wiki"]
|
@ -0,0 +1,72 @@
|
||||
import re
|
||||
|
||||
def isAscii(inputStr):
|
||||
return bool(re.match(r"^[\x00-\xff]+$", inputStr))
|
||||
|
||||
|
||||
def isAsciiPunc(inputStr):
|
||||
return bool(re.match(r"^[\x20-\x2f\x3a-\x40\x5b-\x60]+$", inputStr))
|
||||
|
||||
|
||||
def isAsciiChar(char):
|
||||
return ord(char) <= 255
|
||||
|
||||
|
||||
def isAsciiPuncChar(char):
|
||||
charCode = ord(char)
|
||||
if 0x20 <= charCode <= 0x2f or 0x3a <= charCode <= 0x40 or 0x5b <= charCode <= 0x60:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class CHARTYPE:
|
||||
ASCII = 0
|
||||
ASCII_PUNC = 1
|
||||
UNICODE = 2
|
||||
|
||||
|
||||
def getCharType(char):
|
||||
if isAsciiChar(char):
|
||||
if isAsciiPuncChar(char):
|
||||
return CHARTYPE.ASCII_PUNC
|
||||
else:
|
||||
return CHARTYPE.ASCII
|
||||
else:
|
||||
return CHARTYPE.UNICODE
|
||||
|
||||
|
||||
def replaceCJKPunc(string):
|
||||
table = {ord(f): ord(t) for f, t in zip(
|
||||
u',。!?【】()《》%#@&·1234567890',
|
||||
u',.!?[]() %#@& 1234567890')}
|
||||
return string.translate(table)
|
||||
|
||||
|
||||
def splitAscii(string):
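    # Split a mixed CJK/ASCII string into segments, breaking on spaces and on
    # transitions between character types after normalizing full-width punctuation.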
|
||||
if len(string) == 0:
|
||||
return string
|
||||
string = replaceCJKPunc(string)
|
||||
|
||||
lastCharType = getCharType(string[0])
|
||||
segList = []
|
||||
startPos = 0
|
||||
endPos = 0
|
||||
buffer = []
|
||||
for char in string:
|
||||
if char == " ":
|
||||
if endPos > startPos:
|
||||
segList.append(string[startPos:endPos])
|
||||
startPos = endPos + 1
|
||||
else:
|
||||
currentCharType = getCharType(char)
|
||||
if lastCharType != currentCharType:
|
||||
if endPos > startPos:
|
||||
segList.append(string[startPos:endPos])
|
||||
startPos = endPos
|
||||
lastCharType = currentCharType
|
||||
endPos += 1
|
||||
|
||||
if endPos > startPos:
|
||||
segList.append(string[startPos:])
|
||||
return segList
|
@ -0,0 +1,177 @@
|
||||
from __future__ import annotations
|
||||
from functools import wraps
|
||||
from typing import Any, Optional, Dict
|
||||
from aiohttp import web
|
||||
import jwt
|
||||
import config
|
||||
|
||||
ParamRule = Dict[str, Any]
|
||||
|
||||
class ParamInvalidException(Exception):
|
||||
def __init__(self, param_list: list[str], rules: dict[str, ParamRule]):
|
||||
self.code = "param_invalid"
|
||||
self.param_list = param_list
|
||||
self.rules = rules
|
||||
param_list_str = "'" + ("', '".join(param_list)) + "'"
|
||||
super().__init__(f"Param invalid: {param_list_str}")
|
||||
|
||||
async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None):
|
||||
params: dict[str, Any] = {}
|
||||
for key, value in request.query.items():
|
||||
params[key] = value
|
||||
if request.method == 'POST':
|
||||
if request.headers.get('content-type') == 'application/json':
|
||||
data = await request.json()
|
||||
            if data is not None and isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
params[key] = value
|
||||
else:
|
||||
data = await request.post()
|
||||
for key, value in data.items():
|
||||
params[key] = value
|
||||
|
||||
if rules is not None:
|
||||
invalid_params: list[str] = []
|
||||
for key, rule in rules.items():
|
||||
if "required" in rule and rule["required"] and params[key] is None:
|
||||
invalid_params.append(key)
|
||||
continue
|
||||
|
||||
if key in params:
|
||||
if "type" in rule:
|
||||
if rule["type"] is dict:
|
||||
if params[key] not in rule["type"]:
|
||||
invalid_params.append(key)
|
||||
continue
|
||||
try:
|
||||
if rule["type"] == int:
|
||||
params[key] = int(params[key])
|
||||
elif rule["type"] == float:
|
||||
params[key] = float(params[key])
|
||||
elif rule["type"] == bool:
|
||||
val = params[key].lower()
|
||||
if val == "false" or val == "0":
|
||||
params[key] = False
|
||||
else:
|
||||
params[key] = True
|
||||
except ValueError:
|
||||
invalid_params.append(key)
|
||||
continue
|
||||
else:
|
||||
if "default" in rule:
|
||||
params[key] = rule["default"]
|
||||
else:
|
||||
params[key] = None
|
||||
|
||||
if len(invalid_params) > 0:
|
||||
raise ParamInvalidException(invalid_params, rules)
|
||||
|
||||
return params
|
||||
|
||||
async def api_response(status, data=None, error=None, warning=None, http_status=200, request: Optional[web.Request] = None):
|
||||
ret = { "status": status }
|
||||
if data:
|
||||
ret["data"] = data
|
||||
if error:
|
||||
ret["error"] = error
|
||||
if warning:
|
||||
ret["warning"] = warning
|
||||
if request and is_websocket(request):
|
||||
ret["event"] = "response"
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
await ws.send_json(ret)
|
||||
await ws.close()
|
||||
else:
|
||||
return web.json_response(ret, status=http_status)
|
||||
|
||||
def is_websocket(request: web.Request):
|
||||
return request.headers.get('Upgrade', '').lower() == 'websocket'
|
||||
|
||||
# Auth decorator
|
||||
def token_auth(f):
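    # Accepts either a pre-shared "sk_" token (checked against config.AUTH_TOKENS) or a
    # signed JWT; on success the user id and caller type are attached to the request.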
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
request: web.Request = args[0]
|
||||
|
||||
jwt_token = None
|
||||
sk_token = None
|
||||
params = await get_param(request)
|
||||
token = params.get("token")
|
||||
if token:
|
||||
jwt_token = token
|
||||
else:
|
||||
token: str = request.headers.get('Authorization')
|
||||
if token is None:
|
||||
return await api_response(status=-1, error={
|
||||
"code": "missing-token",
|
||||
"message": "Missing token."
|
||||
}, http_status=401, request=request)
|
||||
token = token.replace("Bearer ", "")
|
||||
if token.startswith("sk_"):
|
||||
sk_token = token
|
||||
else:
|
||||
jwt_token = token
|
||||
|
||||
if sk_token is not None:
|
||||
if token not in config.AUTH_TOKENS:
|
||||
return await api_response(status=-1, error={
|
||||
"code": "token-invalid",
|
||||
"target": "token_id",
|
||||
"message": "Token invalid."
|
||||
}, http_status=401, request=request)
|
||||
|
||||
if "user_id" in params:
|
||||
request["user"] = params["user_id"]
|
||||
else:
|
||||
request["user"] = 0
|
||||
|
||||
request["caller"] = "server"
|
||||
elif jwt_token is not None:
|
||||
# Get appid from jwt header
|
||||
try:
|
||||
jwt_header = jwt.get_unverified_header(jwt_token)
|
||||
key_id: str = jwt_header["kid"]
|
||||
except (KeyError, jwt.exceptions.DecodeError):
|
||||
return await api_response(status=-1, error={
|
||||
"code": "token-invalid",
|
||||
"target": "token_id",
|
||||
"message": "Token issuer not exists."
|
||||
}, http_status=401, request=request)
|
||||
|
||||
# Check jwt
|
||||
try:
|
||||
data = jwt.decode(jwt_token, config.AUTH_TOKENS[key_id], algorithms=['HS256', 'HS384', 'HS512'])
|
||||
if "sub" not in data:
|
||||
return await api_response(status=-1, error={
|
||||
"code": "token-invalid",
|
||||
"target": "subject",
|
||||
"message": "Token subject invalid."
|
||||
}, http_status=401, request=request)
|
||||
|
||||
request["user"] = data["sub"]
|
||||
request["caller"] = "user"
|
||||
except (jwt.exceptions.DecodeError, jwt.exceptions.InvalidSignatureError, jwt.exceptions.InvalidAlgorithmError):
|
||||
return await api_response(status=-1, error={
|
||||
"code": "token-invalid",
|
||||
"target": "signature",
|
||||
"message": "Invalid signature."
|
||||
}, http_status=401, request=request)
|
||||
except (jwt.exceptions.ExpiredSignatureError):
|
||||
return await api_response(status=-1, error={
|
||||
"code": "token-invalid",
|
||||
"target": "expire",
|
||||
"message": "Token expired."
|
||||
}, http_status=401, request=request)
|
||||
except Exception as e:
|
||||
return await api_response(status=-1, error=str(e), http_status=500, request=request)
|
||||
else:
|
||||
return await api_response(status=-1, error={
|
||||
"code": "missing-token",
|
||||
"message": "Missing token."
|
||||
}, http_status=401, request=request)
|
||||
|
||||
return await f(*args, **kwargs)
|
||||
return async_wrapper(*args, **kwargs)
|
||||
return decorated_function
|