增加noawait,支持Azure API

master
落雨楓 2 years ago
parent e21a28a85f
commit 2f68357c1d

@ -0,0 +1,4 @@
#!/bin/sh
DIRNAME=`dirname $0`
cd $DIRNAME
./.venv/bin/activate

@ -2,80 +2,29 @@ import asyncio
import json import json
import time import time
import traceback import traceback
from local import noawait
from typing import Optional
from aiohttp import WSMsgType, web from aiohttp import WSMsgType, web
from sqlalchemy import select from sqlalchemy import select
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool
from service.chat_complete import ChatCompleteService from service.chat_complete import ChatCompleteService
from service.database import DatabaseService from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi from service.mediawiki_api import MediaWikiApi
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
import utils.web import utils.web
class ChatCompleteTaskList:
def __init__(self, dbs: DatabaseService):
self.on_message = None
self.chunks: list[str] = []
class ChatCompleteWebSocketController: async def run():
def __init__(self, request: web.Request): pass
self.request = request
self.ws = None
self.db = None
self.chat_complete = None
self.closed = False
self.refreshed_time = 0
async def run(self):
self.ws = web.WebSocketResponse()
await self.ws.prepare(self.request)
self.refreshed_time = time.time()
self.db = await DatabaseService.create(self.request.app)
self.query = self.request.query
if self.request.get("caller") == "user":
user_id = self.request.get("user")
else:
user_id = self.query.get("user_id")
title = self.query.get("title")
# create heartbeat task
asyncio.ensure_future(self._timeout_task())
async for msg in self.ws:
if msg.type == WSMsgType.TEXT:
try:
data = json.loads(msg.data)
event = data.get('event')
self.refreshed_time = time.time()
if event == 'chatcomplete':
asyncio.ensure_future(self._chatcomplete(data))
if event == 'ping':
await self.ws.send_json({
'event': 'pong'
})
except Exception as e:
print(e)
traceback.print_exc()
await self.ws.send_json({
'event': 'error',
'error': str(e)
})
elif msg.type == WSMsgType.ERROR:
print('ws connection closed with exception %s' %
self.ws.exception())
async def _timeout_task(self):
while not self.closed:
if time.time() - self.refreshed_time > 30:
self.closed = True
await self.ws.close()
return
await asyncio.sleep(1)
async def _chatcomplete(self, params: dict):
question = params.get("question")
conversation_id = params.get("conversation_id")
@noawait.wrap
async def start(self):
await self.run()
class ChatComplete: class ChatComplete:
@staticmethod @staticmethod
@ -244,8 +193,11 @@ class ChatComplete:
tokens = await tiktoken.get_tokens(question) tokens = await tiktoken.get_tokens(question)
transatcion_id = None transatcion_id = None
point_cost = 0
if request.get("caller") == "user": if request.get("caller") == "user":
transatcion_id = await mwapi.chat_complete_start_transaction(user_id, "chatcomplete", tokens, extra_limit) usage_res = await mwapi.chat_complete_start_transaction(user_id, "chatcomplete", tokens, extra_limit)
transatcion_id = usage_res.get("transaction_id")
point_cost = usage_res.get("point_cost")
async def on_message(text: str): async def on_message(text: str):
# Send message to client, start with "+" to indicate it's a message # Send message to client, start with "+" to indicate it's a message
@ -261,8 +213,7 @@ class ChatComplete:
try: try:
chat_res = await chat_complete_service \ chat_res = await chat_complete_service \
.chat_complete(question, on_message, on_extracted_doc, .prepare_chat_complete(question, conversation_id=conversation_id, user_id=user_id, embedding_search={
conversation_id=conversation_id, user_id=user_id, embedding_search={
"limit": extra_limit, "limit": extra_limit,
"in_collection": in_collection, "in_collection": in_collection,
}) })
@ -272,6 +223,8 @@ class ChatComplete:
**chat_res, **chat_res,
}) })
await chat_complete_service.set_latest_point_cost(point_cost)
if transatcion_id: if transatcion_id:
result = await mwapi.chat_complete_end_transaction(transatcion_id, chat_res["total_tokens"]) result = await mwapi.chat_complete_end_transaction(transatcion_id, chat_res["total_tokens"])
except Exception as e: except Exception as e:

@ -31,7 +31,8 @@ class EmbeddingSearch:
if await embedding_search.should_update_page_index(): if await embedding_search.should_update_page_index():
if request.get("caller") == "user": if request.get("caller") == "user":
user_id = request.get("user") user_id = request.get("user")
transatcion_id = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage") usage_res = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage")
transatcion_id = usage_res.get("transaction_id")
await embedding_search.prepare_update_index() await embedding_search.prepare_update_index()
@ -107,7 +108,8 @@ class EmbeddingSearch:
if await embedding_search.should_update_page_index(): if await embedding_search.should_update_page_index():
if request.get("caller") == "user": if request.get("caller") == "user":
user_id = request.get("user") user_id = request.get("user")
transatcion_id = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage") usage_res = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage")
transatcion_id = usage_res.get("transaction_id")
await embedding_search.prepare_update_index() await embedding_search.prepare_update_index()

@ -1,36 +1,36 @@
from aiohttp import web from aiohttp import web
import utils.web import utils.web
import utils.text import utils.text
from extend.hangul_romanize import Transliter from extend.hangul_romanize import Transliter
from extend.hangul_romanize.rule import academic from extend.hangul_romanize.rule import academic
class Hanja: class Hanja:
@staticmethod @staticmethod
def convertToRomaja(self, hanja: str): def convertToRomaja(self, hanja: str):
transliter = Transliter(academic) transliter = Transliter(academic)
segList = utils.text.splitAscii(hanja) segList = utils.text.splitAscii(hanja)
sentenceList = [] sentenceList = []
for seg in segList: for seg in segList:
if seg == " ": if seg == " ":
sentenceList.append("-") sentenceList.append("-")
elif utils.text.isAscii(seg): elif utils.text.isAscii(seg):
if utils.text.isAsciiPunc(seg): if utils.text.isAsciiPunc(seg):
sentenceList.append(seg) sentenceList.append(seg)
else: else:
sentenceList.append([seg]) sentenceList.append([seg])
else: else:
roma = transliter.translit(seg) roma = transliter.translit(seg)
sentenceList.append(roma.split(" ")) sentenceList.append(roma.split(" "))
return sentenceList return sentenceList
@staticmethod @staticmethod
async def hanja2roma(request: web.Request): async def hanja2roma(request: web.Request):
params = await utils.web.get_param(request, { params = await utils.web.get_param(request, {
"sentence": { "sentence": {
"required": True, "required": True,
}, },
}) })
sentence = params.get('sentence') sentence = params.get('sentence')
data = Hanja.convertToRomaja(sentence) data = Hanja.convertToRomaja(sentence)
return await utils.web.api_response(1, data, request=request) return await utils.web.api_response(1, data, request=request)

@ -1,81 +1,81 @@
from __future__ import annotations from __future__ import annotations
from aiohttp import web from aiohttp import web
import os.path as path import os.path as path
import jieba import jieba
import jieba.posseg as pseg import jieba.posseg as pseg
from pypinyin import pinyin, Style from pypinyin import pinyin, Style
import utils.text import utils.text
import utils.web import utils.web
jieba.initialize() jieba.initialize()
userDict = path.dirname(path.dirname(path.dirname(__file__))) + "/data/userDict.txt" userDict = path.dirname(path.dirname(path.dirname(__file__))) + "/data/userDict.txt"
if path.exists(userDict): if path.exists(userDict):
jieba.load_userdict(userDict) jieba.load_userdict(userDict)
class Hanzi: class Hanzi:
@staticmethod @staticmethod
def filterJiebaTag(segList: list[str]): def filterJiebaTag(segList: list[str]):
ret = [] ret = []
for word, flag in segList: for word, flag in segList:
if flag[0] == "u" and (word == "" or word == ""): if flag[0] == "u" and (word == "" or word == ""):
ret.append("") ret.append("")
else: else:
ret.append(word) ret.append(word)
return ret return ret
@staticmethod @staticmethod
def convertToPinyin(sentence: str): def convertToPinyin(sentence: str):
sentence = utils.text.replaceCJKPunc(sentence).replace(' ', '-') sentence = utils.text.replaceCJKPunc(sentence).replace(' ', '-')
segList = Hanzi.filterJiebaTag(pseg.cut(sentence)) segList = Hanzi.filterJiebaTag(pseg.cut(sentence))
sentenceList = [] sentenceList = []
pinyinGroup = [] pinyinGroup = []
for seg in segList: for seg in segList:
if utils.text.isAscii(seg): if utils.text.isAscii(seg):
if utils.text.isAsciiPunc(seg): if utils.text.isAsciiPunc(seg):
if len(pinyinGroup) > 0: if len(pinyinGroup) > 0:
sentenceList.append(pinyinGroup) sentenceList.append(pinyinGroup)
pinyinGroup = [] pinyinGroup = []
sentenceList.append(seg) sentenceList.append(seg)
else: else:
if len(pinyinGroup) > 0: if len(pinyinGroup) > 0:
sentenceList.append(pinyinGroup) sentenceList.append(pinyinGroup)
pinyinGroup = [] pinyinGroup = []
sentenceList.append([seg]) sentenceList.append([seg])
else: else:
sentencePinyin = [] sentencePinyin = []
for one in pinyin(seg, style=Style.NORMAL): for one in pinyin(seg, style=Style.NORMAL):
sentencePinyin.append(one[0]) sentencePinyin.append(one[0])
pinyinGroup.append(sentencePinyin) pinyinGroup.append(sentencePinyin)
if len(pinyinGroup) > 0: if len(pinyinGroup) > 0:
sentenceList.append(pinyinGroup) sentenceList.append(pinyinGroup)
return sentenceList return sentenceList
@staticmethod @staticmethod
async def hanziToPinyin(request: web.Request): async def hanziToPinyin(request: web.Request):
params = await utils.web.get_param(request, { params = await utils.web.get_param(request, {
"sentence": { "sentence": {
"required": True, "required": True,
}, },
}) })
sentence = params.get('sentence') sentence = params.get('sentence')
data = Hanzi.convertToPinyin(sentence) data = Hanzi.convertToPinyin(sentence)
return await utils.web.api_response(1, data, request=request) return await utils.web.api_response(1, data, request=request)
@staticmethod @staticmethod
async def splitHanzi(request: web.Request): async def splitHanzi(request: web.Request):
params = await utils.web.get_param(request, { params = await utils.web.get_param(request, {
"sentence": { "sentence": {
"required": True, "required": True,
}, },
}) })
sentence = params.get("sentence") sentence = params.get("sentence")
segList = list(pseg.cut(sentence)) segList = list(pseg.cut(sentence))
data = [] data = []
for word, flag in segList: for word, flag in segList:
data.append({"word": word, "flag": flag}) data.append({"word": word, "flag": flag})
return await utils.web.api_response(1, data) return await utils.web.api_response(1, data)

@ -1,32 +1,32 @@
from aiohttp import web from aiohttp import web
import utils.web import utils.web
import utils.text import utils.text
from extend.kanji_to_romaji import kanji_to_romaji from extend.kanji_to_romaji import kanji_to_romaji
class Kanji: class Kanji:
@staticmethod @staticmethod
def convertToRomaji(self, kanji: str): def convertToRomaji(self, kanji: str):
segList = utils.text.splitAscii(kanji) segList = utils.text.splitAscii(kanji)
sentenceList = [] sentenceList = []
for seg in segList: for seg in segList:
if utils.text.isAscii(seg): if utils.text.isAscii(seg):
if utils.text.isAsciiPunc(seg): if utils.text.isAsciiPunc(seg):
sentenceList.append(seg) sentenceList.append(seg)
else: else:
sentenceList.append([seg]) sentenceList.append([seg])
else: else:
romaji = kanji_to_romaji(seg) romaji = kanji_to_romaji(seg)
sentenceList.append(romaji.split(" ")) sentenceList.append(romaji.split(" "))
return sentenceList return sentenceList
@staticmethod @staticmethod
async def kanji2romaji(request: web.Request): async def kanji2romaji(request: web.Request):
params = await utils.web.get_param(request, { params = await utils.web.get_param(request, {
"sentence": { "sentence": {
"required": True, "required": True,
}, },
}) })
sentence = params.get('sentence') sentence = params.get('sentence')
data = Kanji.convertToRomaji(sentence) data = Kanji.convertToRomaji(sentence)
return await utils.web.api_response(1, data, request=request) return await utils.web.api_response(1, data, request=request)

@ -13,7 +13,7 @@ class ConversationChunkModel(BaseModel):
__tablename__ = "chat_complete_conversation_chunk" __tablename__ = "chat_complete_conversation_chunk"
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
conversation_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("chat_complete_conversation.id"), index=True) conversation_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(ConversationModel.id), index=True)
message_data: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True) message_data: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True)
tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0) tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0)
updated_at: Mapped[int] = mapped_column(sqlalchemy.TIMESTAMP, index=True) updated_at: Mapped[int] = mapped_column(sqlalchemy.TIMESTAMP, index=True)

@ -1,12 +1,13 @@
from __future__ import annotations
import hashlib import hashlib
from typing import Optional from typing import Optional, Type
import asyncpg import asyncpg
from api.model.base import BaseModel from api.model.base import BaseModel
import config import config
import numpy as np import numpy as np
import sqlalchemy import sqlalchemy
from sqlalchemy import select, update, delete from sqlalchemy import Index, select, update, delete, Select
from sqlalchemy.orm import mapped_column, Mapped from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from pgvector.asyncpg import register_vector from pgvector.asyncpg import register_vector
@ -14,19 +15,38 @@ from pgvector.sqlalchemy import Vector
from service.database import DatabaseService from service.database import DatabaseService
class PageIndexModel(BaseModel): page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}
class AbstractPageIndexModel(BaseModel):
__abstract__ = True __abstract__ = True
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True) sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
text: Mapped[str] = mapped_column(sqlalchemy.Text) text: Mapped[str] = mapped_column(sqlalchemy.Text)
text_len: Mapped[int] = mapped_column(sqlalchemy.Integer) text_len: Mapped[int] = mapped_column(sqlalchemy.Integer)
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
markdown: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=True) markdown: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=True)
markdown_len: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True) markdown_len: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
temp_doc_session_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True) temp_doc_session_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
def create_page_index_model(collection_id: int):
if collection_id in page_index_model_list:
return page_index_model_list[collection_id]
else:
class PageIndexModel(AbstractPageIndexModel):
__tablename__ = "embedding_search_page_index_%s" % str(collection_id)
embedding_index = sqlalchemy.Index(__tablename__ + "_embedding_idx", AbstractPageIndexModel.embedding,
postgresql_using='ivfflat',
postgresql_ops={'embedding': 'vector_cosine_ops'})
page_index_model_list[collection_id] = PageIndexModel
return PageIndexModel
class PageIndexHelper: class PageIndexHelper:
columns = [ columns = [
"id", "id",
@ -60,12 +80,19 @@ class PageIndexHelper:
await register_vector(self.dbi) await register_vector(self.dbi)
self.create_session = self.dbs.create_session
self.session = self.dbs.create_session()
await self.session.__aenter__()
self.orm = create_page_index_model(self.collection_id)
self.initialized = True self.initialized = True
return self return self
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
await self.dbpool.__aexit__(exc_type, exc, tb) await self.dbpool.__aexit__(exc_type, exc, tb)
await self.session.__aexit__(exc_type, exc, tb)
async def table_exists(self): async def table_exists(self):
exists = await self.dbi.fetchval("""SELECT EXISTS ( exists = await self.dbi.fetchval("""SELECT EXISTS (
@ -83,27 +110,31 @@ class PageIndexHelper:
# create table if not exists # create table if not exists
if not await self.table_exists(): if not await self.table_exists():
await self.dbi.execute(("""CREATE TABLE IF NOT EXISTS /*_*/ ( async with self.dbs.engine.begin() as conn:
id SERIAL PRIMARY KEY, await conn.run_sync(self.orm.__table__.create)
page_id INTEGER NOT NULL,
sha1 VARCHAR(40) NOT NULL, # await self.dbi.execute(("""CREATE TABLE IF NOT EXISTS /*_*/ (
text TEXT NOT NULL, # id SERIAL PRIMARY KEY,
text_len INTEGER NOT NULL, # page_id INTEGER NOT NULL,
embedding VECTOR(%d) NOT NULL, # sha1 VARCHAR(40) NOT NULL,
markdown TEXT NULL, # text TEXT NOT NULL,
markdown_len INTEGER NULL, # text_len INTEGER NOT NULL,
temp_doc_session_id INTEGER NULL # embedding VECTOR(%d) NOT NULL,
); # markdown TEXT NULL,
CREATE INDEX /*_*/_page_id_idx ON /*_*/ (page_id); # markdown_len INTEGER NULL,
CREATE INDEX /*_*/_sha1_idx ON /*_*/ (sha1); # temp_doc_session_id INTEGER NULL
CREATE INDEX /*_*/_temp_doc_session_id_idx ON /*_*/ (temp_doc_session_id); # );
""" % config.EMBEDDING_VECTOR_SIZE).replace("/*_*/", self.table_name)) # CREATE INDEX /*_*/_page_id_idx ON /*_*/ (page_id);
# CREATE INDEX /*_*/_sha1_idx ON /*_*/ (sha1);
self.table_initialized = False # CREATE INDEX /*_*/_temp_doc_session_id_idx ON /*_*/ (temp_doc_session_id);
# """ % config.EMBEDDING_VECTOR_SIZE).replace("/*_*/", self.table_name))
self.table_initialized = True
async def create_embedding_index(self): async def create_embedding_index(self):
await self.dbi.execute("CREATE INDEX IF NOT EXISTS /*_*/_embedding_idx ON /*_*/ USING ivfflat (embedding vector_cosine_ops);" pass
.replace("/*_*/", self.table_name)) # await self.dbi.execute("CREATE INDEX IF NOT EXISTS /*_*/_embedding_idx ON /*_*/ USING ivfflat (embedding vector_cosine_ops);"
# .replace("/*_*/", self.table_name))
def sha1_doc(self, doc: list): def sha1_doc(self, doc: list):
for item in doc: for item in doc:
@ -113,25 +144,20 @@ class PageIndexHelper:
async def get_indexed_sha1(self, with_temporary: bool = True, in_collection: bool = False): async def get_indexed_sha1(self, with_temporary: bool = True, in_collection: bool = False):
indexed_sha1_list = [] indexed_sha1_list = []
sql = "SELECT sha1 FROM %s" % (self.table_name)
where = [] stmt = select(self.orm).column(self.orm.sha1)
params = []
if not with_temporary: if not with_temporary:
where.append("temp_doc_session_id IS NULL") stmt = stmt.where(self.orm.temp_doc_session_id == None)
if not in_collection: if not in_collection:
params.append(self.page_id) stmt = stmt.where(self.orm.page_id == self.page_id)
where.append("page_id = $%d" % len(params))
if len(where) > 0: ret: list[AbstractPageIndexModel] = await self.session.scalars(stmt)
sql += " WHERE " + (" AND ".join(where))
ret = await self.dbi.fetch(sql, *params)
for row in ret: for row in ret:
indexed_sha1_list.append(row[0]) indexed_sha1_list.append(row.sha1)
return indexed_sha1_list return indexed_sha1_list
async def get_unindexed_doc(self, doc: list, with_temporary: bool = True): async def get_unindexed_doc(self, doc: list, with_temporary: bool = True):
@ -202,11 +228,11 @@ class PageIndexHelper:
if len(should_index) > 0: if len(should_index) > 0:
await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id) await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);""" % (self.table_name), VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);""" % (self.table_name),
[(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"]) for item in should_index]) [(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"]) for item in should_index])
if len(should_persist) > 0: if len(should_persist) > 0:
await self.dbi.executemany("UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2" % (self.table_name), await self.dbi.executemany("UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2" % (self.table_name),
[(self.page_id, sha1) for sha1 in should_persist]) [(self.page_id, sha1) for sha1 in should_persist])
if need_create_index: if need_create_index:
await self.create_embedding_index() await self.create_embedding_index()

@ -11,7 +11,7 @@ class TitleCollectionModel(BaseModel):
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True) title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True, nullable=True)
class TitleCollectionHelper: class TitleCollectionHelper:
def __init__(self, dbs: DatabaseService): def __init__(self, dbs: DatabaseService):
@ -29,7 +29,6 @@ class TitleCollectionHelper:
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
await self.session.__aexit__(exc_type, exc, tb) await self.session.__aexit__(exc_type, exc, tb)
pass
async def add(self, title: str, page_id: Optional[int] = None) -> Union[int, bool]: async def add(self, title: str, page_id: Optional[int] = None) -> Union[int, bool]:
stmt = select(TitleCollectionModel.id).where(TitleCollectionModel.title == title) stmt = select(TitleCollectionModel.id).where(TitleCollectionModel.title == title)

@ -20,7 +20,11 @@ class TitleIndexModel(BaseModel):
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE), index=True) embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding,
postgresql_using='ivfflat',
postgresql_ops={'embedding': 'vector_cosine_ops'})
class TitleIndexHelper: class TitleIndexHelper:
__tablename__ = "embedding_search_title_index" __tablename__ = "embedding_search_title_index"

@ -1,66 +1,73 @@
PORT = 8144 PORT = 8144
HOST = "www.isekai.cn" HOST = "www.isekai.cn"
MW_API = "http://dev.isekai.cn/api.php" MW_API = "http://dev.isekai.cn/api.php"
DEBUG = True DEBUG = True
DATABASE = { DATABASE = {
"host": "127.0.0.1", "host": "127.0.0.1",
"database": "isekai_toolkit", "database": "isekai_toolkit",
"user": "", "user": "",
"password": "", "password": "",
"port": "5432", "port": "5432",
} }
EMBEDDING_VECTOR_SIZE = 1536 EMBEDDING_VECTOR_SIZE = 1536
OPENAI_API = "https://api.openai.com" OPENAI_API_TYPE = "openai" # openai or azure
OPENAI_TOKEN = "sk-" OPENAI_API = "https://api.openai.com"
OPENAI_TOKEN = "sk-"
CHATCOMPLETE_MAX_MEMORY_TOKENS = 1024 OPENAI_API = None
CHATCOMPLETE_MAX_INPUT_TOKENS = 768 OPENAI_TOKEN = ""
AZURE_OPENAI_ENDPOINT = "https://your-instance.openai.azure.com"
CHATCOMPLETE_OUTPUT_REPLACE = { AZURE_OPENAI_KEY = ""
"OpenAI": "オーペンエーアイ", AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME = ""
"ChatGPT": "チャットジーピーティー", AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME = ""
"GPT": "ジーピーティー",
"上下文": "消息", CHATCOMPLETE_MAX_MEMORY_TOKENS = 1024
"AI": "虛擬人物程序", CHATCOMPLETE_MAX_INPUT_TOKENS = 768
"语言模型": "虛擬人物程序",
"人工智能程序": "虛擬人物程序", CHATCOMPLETE_OUTPUT_REPLACE = {
"語言模型": "虛擬人物程序", "OpenAI": "オーペンエーアイ",
"人工智能程式": "虛擬人物程序", "ChatGPT": "チャットジーピーティー",
} "GPT": "ジーピーティー",
"上下文": "消息",
CHATCOMPLETE_DEFAULT_CONVERSATION_TITLE = "无标题" "AI": "虛擬人物程序",
"语言模型": "虛擬人物程序",
CHATCOMPLETE_BOT_NAME = "寫作助手" "人工智能程序": "虛擬人物程序",
"語言模型": "虛擬人物程序",
PROMPTS = { "人工智能程式": "虛擬人物程序",
"chat": { }
"system_prompt": "You are a writer. You are the writing assistant of the '異世界百科'. Your name is '{bot_name}'. You need to help users complete the characters and settings in their novel.",
}, CHATCOMPLETE_DEFAULT_CONVERSATION_TITLE = "无标题"
"title": {
"system_prompt": "You are a writing assistant, you only need to assist in writing, do not express your opinion.", CHATCOMPLETE_BOT_NAME = "寫作助手"
"prompt": "Write a short title in Chinese for the following conversation, don't use quotes:\n\n{content}"
}, PROMPTS = {
"suggestions": { "chat": {
"prompt": "根據下面的對話,提出幾個問題:\n\n{content}" "system_prompt": "You are a writer. You are the writing assistant of the '異世界百科'. Your name is '{bot_name}'. You need to help users complete the characters and settings in their novel.",
}, },
"summary": { "title": {
"system_prompt": "You are a writing assistant, you only need to assist in writing, do not express your opinion. Output in Chinese.", "system_prompt": "You are a writing assistant, you only need to assist in writing, do not express your opinion.",
"prompt": "為“{bot_name}”概括下面的聊天記錄排除不重要的對話不要表明自己的意見儘量簡潔。使用中文輸出“User”是同一個人。\n\n{content}" "prompt": "Write a short title in Chinese for the following conversation, don't use quotes:\n\n{content}"
}, },
"extracted_doc": { "suggestions": {
"prompt": "Here are some relevant informations:\n\n{content}" "prompt": "根據下面的對話,提出幾個問題:\n\n{content}"
} },
} "summary": {
"system_prompt": "You are a writing assistant, you only need to assist in writing, do not express your opinion. Output in Chinese.",
REQUEST_PROXY = "http://127.0.0.1:7890" "prompt": "為“{bot_name}”概括下面的聊天記錄排除不重要的對話不要表明自己的意見儘量簡潔。使用中文輸出“User”是同一個人。\n\n{content}"
},
AUTH_TOKENS = { "extracted_doc": {
"isekaiwiki": "sk-123456" "prompt": "Here are some relevant informations:\n\n{content}"
} }
}
MW_BOT_LOGIN_USERNAME = "Hyperzlib@ChatComplete"
REQUEST_PROXY = "http://127.0.0.1:7890"
AUTH_TOKENS = {
"isekaiwiki": "sk-123456"
}
MW_BOT_LOGIN_USERNAME = "Hyperzlib@ChatComplete"
MW_BOT_LOGIN_PASSWORD = "" MW_BOT_LOGIN_PASSWORD = ""

@ -1,2 +1,2 @@
from .core import Transliter # noqa from .core import Transliter # noqa

@ -1,89 +1,89 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
try: try:
unicode(0) unicode(0)
except NameError: except NameError:
# py3 # py3
unicode = str unicode = str
unichr = chr unichr = chr
class Syllable(object): class Syllable(object):
"""Hangul syllable interface""" """Hangul syllable interface"""
MIN = ord('') MIN = ord('')
MAX = ord('') MAX = ord('')
def __init__(self, char=None, code=None): def __init__(self, char=None, code=None):
if char is None and code is None: if char is None and code is None:
raise TypeError('__init__ takes char or code as a keyword argument (not given)') raise TypeError('__init__ takes char or code as a keyword argument (not given)')
if char is not None and code is not None: if char is not None and code is not None:
raise TypeError('__init__ takes char or code as a keyword argument (both given)') raise TypeError('__init__ takes char or code as a keyword argument (both given)')
if char: if char:
code = ord(char) code = ord(char)
if not self.MIN <= code <= self.MAX: if not self.MIN <= code <= self.MAX:
raise TypeError('__init__ expected Hangul syllable but {0} not in [{1}..{2}]'.format(code, self.MIN, self.MAX)) raise TypeError('__init__ expected Hangul syllable but {0} not in [{1}..{2}]'.format(code, self.MIN, self.MAX))
self.code = code self.code = code
@property @property
def index(self): def index(self):
return self.code - self.MIN return self.code - self.MIN
@property @property
def initial(self): def initial(self):
return self.index // 588 return self.index // 588
@property @property
def vowel(self): def vowel(self):
return (self.index // 28) % 21 return (self.index // 28) % 21
@property @property
def final(self): def final(self):
return self.index % 28 return self.index % 28
@property @property
def char(self): def char(self):
return unichr(self.code) return unichr(self.code)
def __unicode__(self): def __unicode__(self):
return self.char return self.char
def __repr__(self): def __repr__(self):
return '''<Syllable({}({}),{}({}),{}({}),{}({}))>'''.format( return '''<Syllable({}({}),{}({}),{}({}),{}({}))>'''.format(
self.code, self.char, self.initial, '', self.vowel, '', self.final, '') self.code, self.char, self.initial, '', self.vowel, '', self.final, '')
class Transliter(object): class Transliter(object):
"""General transliting interface""" """General transliting interface"""
def __init__(self, rule): def __init__(self, rule):
self.rule = rule self.rule = rule
def translit(self, text): def translit(self, text):
"""Translit text to romanized text """Translit text to romanized text
:param text: Unicode string or unicode character iterator :param text: Unicode string or unicode character iterator
""" """
result = [] result = []
pre = None, None pre = None, None
now = None, None now = None, None
for c in text: for c in text:
try: try:
post = c, Syllable(c) post = c, Syllable(c)
except TypeError: except TypeError:
post = c, None post = c, None
if now[0] is not None: if now[0] is not None:
out = self.rule(now, pre=pre, post=post) out = self.rule(now, pre=pre, post=post)
if out is not None: if out is not None:
result.append(out) result.append(out)
pre = now pre = now
now = post now = post
if now is not None: if now is not None:
out = self.rule(now, pre=pre, post=(None, None)) out = self.rule(now, pre=pre, post=(None, None))
if out is not None: if out is not None:
result.append(out) result.append(out)
return ''.join(result) return ''.join(result)

@ -1,47 +1,47 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
REVISED_INITIALS = 'g', 'kk', 'n', 'd', 'tt', 'l', 'm', 'b', 'pp', 's', 'ss', '', 'j', 'jj', 'ch', 'k', 't', 'p', 'h' REVISED_INITIALS = 'g', 'kk', 'n', 'd', 'tt', 'l', 'm', 'b', 'pp', 's', 'ss', '', 'j', 'jj', 'ch', 'k', 't', 'p', 'h'
REVISED_VOWELS = 'a', 'ae', 'ya', 'yae', 'eo', 'e', 'yeo', 'ye', 'o', 'wa', 'wae', 'oe', 'yo', 'u', 'wo', 'we', 'wi', 'yu', 'eu', 'ui', 'i' REVISED_VOWELS = 'a', 'ae', 'ya', 'yae', 'eo', 'e', 'yeo', 'ye', 'o', 'wa', 'wae', 'oe', 'yo', 'u', 'wo', 'we', 'wi', 'yu', 'eu', 'ui', 'i'
REVISED_FINALS = '', 'g', 'kk', 'gs', 'n', 'nj', 'nh', 'd', 'l', 'lg', 'lm', 'lb', 'ls', 'lt', 'lp', 'lh', 'm', 'b', 'bs', 's', 'ss', 'ng', 'j', 'ch', 'k', 't', 'p', 'h' REVISED_FINALS = '', 'g', 'kk', 'gs', 'n', 'nj', 'nh', 'd', 'l', 'lg', 'lm', 'lb', 'ls', 'lt', 'lp', 'lh', 'm', 'b', 'bs', 's', 'ss', 'ng', 'j', 'ch', 'k', 't', 'p', 'h'
def academic_ambiguous_patterns(): def academic_ambiguous_patterns():
import itertools import itertools
result = set() result = set()
for final, initial in itertools.product(REVISED_FINALS, REVISED_INITIALS): for final, initial in itertools.product(REVISED_FINALS, REVISED_INITIALS):
check = False check = False
combined = final + initial combined = final + initial
for i in range(len(combined)): for i in range(len(combined)):
head, tail = combined[:i], combined[i:] head, tail = combined[:i], combined[i:]
if head in REVISED_FINALS and tail in REVISED_INITIALS: if head in REVISED_FINALS and tail in REVISED_INITIALS:
if not check: if not check:
check = True check = True
else: else:
result.add(combined) result.add(combined)
break break
return result return result
ACADEMIC_AMBIGUOUS_PATTERNS = academic_ambiguous_patterns() ACADEMIC_AMBIGUOUS_PATTERNS = academic_ambiguous_patterns()
def academic(now, pre, **options): def academic(now, pre, **options):
"""Rule for academic translition.""" """Rule for academic translition."""
c, s = now c, s = now
if not s: if not s:
return c return c
ps = pre[1] if pre else None ps = pre[1] if pre else None
marker = False marker = False
if ps: if ps:
if s.initial == 11: if s.initial == 11:
marker = True marker = True
elif ps and (REVISED_FINALS[ps.final] + REVISED_INITIALS[s.initial]) in ACADEMIC_AMBIGUOUS_PATTERNS: elif ps and (REVISED_FINALS[ps.final] + REVISED_INITIALS[s.initial]) in ACADEMIC_AMBIGUOUS_PATTERNS:
marker = True marker = True
r = u'' r = u''
if marker: if marker:
r += '-' r += '-'
r += REVISED_INITIALS[s.initial] + REVISED_VOWELS[s.vowel] + REVISED_FINALS[s.final] r += REVISED_INITIALS[s.initial] + REVISED_VOWELS[s.vowel] + REVISED_FINALS[s.final]
return r return r

@ -1,5 +1,5 @@
from .kanji_to_romaji_module import convert_hiragana_to_katakana, translate_to_romaji, translate_soukon, \ from .kanji_to_romaji_module import convert_hiragana_to_katakana, translate_to_romaji, translate_soukon, \
translate_long_vowel, translate_soukon_ch, kanji_to_romaji translate_long_vowel, translate_soukon_ch, kanji_to_romaji
__all__ = ["load_mappings_dict", "convert_hiragana_to_katakana", "convert_katakana_to_hiragana", __all__ = ["load_mappings_dict", "convert_hiragana_to_katakana", "convert_katakana_to_hiragana",
"translate_to_romaji", "translate_soukon", "translate_to_romaji", "translate_soukon",
"translate_long_vowel", "translate_soukon_ch", "kanji_to_romaji"] "translate_long_vowel", "translate_soukon_ch", "kanji_to_romaji"]

File diff suppressed because it is too large Load Diff

@ -1,154 +1,154 @@
{ {
"": "!", "": "!",
"": "\"", "": "\"",
"": "#", "": "#",
"": "$", "": "$",
"": "%", "": "%",
"": "&", "": "&",
"": "'", "": "'",
"": "*", "": "*",
"": "+", "": "+",
"": ",", "": ",",
"": "-", "": "-",
"": ".", "": ".",
"": "\/", "": "\/",
"": "0", "": "0",
"": "1", "": "1",
"": "2", "": "2",
"": "3", "": "3",
"": "4", "": "4",
"": "5", "": "5",
"": "6", "": "6",
"": "7", "": "7",
"": "8", "": "8",
"": "9", "": "9",
"": ":", "": ":",
"": ";", "": ";",
"": "<", "": "<",
"": "=", "": "=",
"": ">", "": ">",
"": "?", "": "?",
"": "@", "": "@",
"": "A", "": "A",
"": "B", "": "B",
"": "C", "": "C",
"": "D", "": "D",
"": "E", "": "E",
"": "F", "": "F",
"": "G", "": "G",
"": "H", "": "H",
"": "I", "": "I",
"": "J", "": "J",
"": "K", "": "K",
"": "L", "": "L",
"": "M", "": "M",
"": "N", "": "N",
"И": "N", "И": "N",
"": "O", "": "O",
"": "P", "": "P",
"": "Q", "": "Q",
"": "R", "": "R",
"": "S", "": "S",
"": "T", "": "T",
"": "U", "": "U",
"": "V", "": "V",
"": "W", "": "W",
"": "X", "": "X",
"": "Y", "": "Y",
"": "Z", "": "Z",
"": "[", "": "[",
"": "\\", "": "\\",
"": "]", "": "]",
"": "^", "": "^",
"_": "_", "_": "_",
"": "'", "": "'",
"": "a", "": "a",
"": "b", "": "b",
"": "c", "": "c",
"": "d", "": "d",
"": "e", "": "e",
"": "f", "": "f",
"": "g", "": "g",
"": "h", "": "h",
"": "i", "": "i",
"": "j", "": "j",
"": "k", "": "k",
"": "l", "": "l",
"": "m", "": "m",
"": "n", "": "n",
"": "o", "": "o",
"": "p", "": "p",
"": "q", "": "q",
"": "r", "": "r",
"": "s", "": "s",
"": "t", "": "t",
"": "u", "": "u",
"": "v", "": "v",
"": "w", "": "w",
"": "x", "": "x",
"": "y", "": "y",
"": "z", "": "z",
"": "{", "": "{",
"": "|", "": "|",
"": "}", "": "}",
"": "~", "": "~",
"Ā": "A", "Ā": "A",
"Ă": "A", "Ă": "A",
"Ą": "A", "Ą": "A",
"â": "a", "â": "a",
"ā": "a", "ā": "a",
"ă": "a", "ă": "a",
"ą": "a", "ą": "a",
"Ē": "E", "Ē": "E",
"Ĕ": "E", "Ĕ": "E",
"Ė": "E", "Ė": "E",
"Ę": "E", "Ę": "E",
"Ě": "E", "Ě": "E",
"ē": "e", "ē": "e",
"ĕ": "e", "ĕ": "e",
"ė": "e", "ė": "e",
"ę": "e", "ę": "e",
"ě": "e", "ě": "e",
"Ī": "I", "Ī": "I",
"Ĭ": "I", "Ĭ": "I",
"Į": "I", "Į": "I",
"ī": "i", "ī": "i",
"ĭ": "i", "ĭ": "i",
"į": "i", "į": "i",
"Ō": "O", "Ō": "O",
"Ŏ": "O", "Ŏ": "O",
"Ő": "O", "Ő": "O",
"ō": "o", "ō": "o",
"ŏ": "o", "ŏ": "o",
"ő": "o", "ő": "o",
"Ũ": "U", "Ũ": "U",
"Ū": "U", "Ū": "U",
"Ŭ": "U", "Ŭ": "U",
"Ů": "U", "Ů": "U",
"Ű": "U", "Ű": "U",
"Ų": "U", "Ų": "U",
"ũ": "u", "ũ": "u",
"ū": "u", "ū": "u",
"ŭ": "u", "ŭ": "u",
"ů": "u", "ů": "u",
"ű": "u", "ű": "u",
"ų": "u", "ų": "u",
"Ӓ": "A", "Ӓ": "A",
"Ӑ": "A", "Ӑ": "A",
"Ѐ": "E", "Ѐ": "E",
"Ё": "E", "Ё": "E",
"Ӧ": "O", "Ӧ": "O",
"ӓ": "a", "ӓ": "a",
"ӑ": "a", "ӑ": "a",
"ѐ": "e", "ѐ": "e",
"ё": "e", "ё": "e",
"ӧ": "o", "ӧ": "o",
"ω": "w", "ω": "w",
"€": "E", "€": "E",
"∃": "E", "∃": "E",
"ϛ": "c" "ϛ": "c"
} }

@ -1,120 +1,120 @@
{ {
"ぁ": "a", "ぁ": "a",
"あ": "a", "あ": "a",
"ぃ": "i", "ぃ": "i",
"い": "i", "い": "i",
"ぅ": "u", "ぅ": "u",
"う": "u", "う": "u",
"ぇ": "e", "ぇ": "e",
"え": "e", "え": "e",
"ぉ": "o", "ぉ": "o",
"お": "o", "お": "o",
"か": "ka", "か": "ka",
"が": "ga", "が": "ga",
"き": "ki", "き": "ki",
"きゃ": "kya", "きゃ": "kya",
"きゅ": "kyu", "きゅ": "kyu",
"きょ": "kyo", "きょ": "kyo",
"ぎ": "gi", "ぎ": "gi",
"ぎゃ": "gya", "ぎゃ": "gya",
"ぎゅ": "gyu", "ぎゅ": "gyu",
"ぎょ": "gyo", "ぎょ": "gyo",
"く": "ku", "く": "ku",
"ぐ": "gu", "ぐ": "gu",
"け": "ke", "け": "ke",
"げ": "ge", "げ": "ge",
"こ": "ko", "こ": "ko",
"ご": "go", "ご": "go",
"さ": "sa", "さ": "sa",
"ざ": "za", "ざ": "za",
"し": "shi", "し": "shi",
"しゃ": "sha", "しゃ": "sha",
"しゅ": "shu", "しゅ": "shu",
"しょ": "sho", "しょ": "sho",
"じ": "ji", "じ": "ji",
"じゃ": "ja", "じゃ": "ja",
"じゅ": "ju", "じゅ": "ju",
"じょ": "jo", "じょ": "jo",
"す": "su", "す": "su",
"ず": "zu", "ず": "zu",
"せ": "se", "せ": "se",
"ぜ": "ze", "ぜ": "ze",
"そ": "so", "そ": "so",
"ぞ": "zo", "ぞ": "zo",
"た": "ta", "た": "ta",
"だ": "da", "だ": "da",
"ち": "chi", "ち": "chi",
"ちゃ": "cha", "ちゃ": "cha",
"ちゅ": "chu", "ちゅ": "chu",
"ちょ": "cho", "ちょ": "cho",
"ぢ": "ji", "ぢ": "ji",
"つ": "tsu", "つ": "tsu",
"づ": "zu", "づ": "zu",
"て": "te", "て": "te",
"で": "de", "で": "de",
"と": "to", "と": "to",
"ど": "do", "ど": "do",
"な": "na", "な": "na",
"に": "ni", "に": "ni",
"にゃ": "nya", "にゃ": "nya",
"にゅ": "nyu", "にゅ": "nyu",
"にょ": "nyo", "にょ": "nyo",
"ぬ": "nu", "ぬ": "nu",
"ね": "ne", "ね": "ne",
"の": "no", "の": "no",
"は": "ha", "は": "ha",
"ば": "ba", "ば": "ba",
"ぱ": "pa", "ぱ": "pa",
"ひ": "hi", "ひ": "hi",
"ひゃ": "hya", "ひゃ": "hya",
"ひゅ": "hyu", "ひゅ": "hyu",
"ひょ": "hyo", "ひょ": "hyo",
"び": "bi", "び": "bi",
"びゃ": "bya", "びゃ": "bya",
"びゅ": "byu", "びゅ": "byu",
"びょ": "byo", "びょ": "byo",
"ぴ": "pi", "ぴ": "pi",
"ぴゃ": "pya", "ぴゃ": "pya",
"ぴゅ": "pyu", "ぴゅ": "pyu",
"ぴょ": "pyo", "ぴょ": "pyo",
"ふ": "fu", "ふ": "fu",
"ぶ": "bu", "ぶ": "bu",
"ぷ": "pu", "ぷ": "pu",
"へ": "he", "へ": "he",
"べ": "be", "べ": "be",
"ぺ": "pe", "ぺ": "pe",
"ほ": "ho", "ほ": "ho",
"ぼ": "bo", "ぼ": "bo",
"ぽ": "po", "ぽ": "po",
"ま": "ma", "ま": "ma",
"み": "mi", "み": "mi",
"みゃ": "mya", "みゃ": "mya",
"みゅ": "myu", "みゅ": "myu",
"みょ": "myo", "みょ": "myo",
"む": "mu", "む": "mu",
"め": "me", "め": "me",
"も": "mo", "も": "mo",
"や": "ya", "や": "ya",
"ゆ": "yu", "ゆ": "yu",
"よ": "yo", "よ": "yo",
"ら": "ra", "ら": "ra",
"り": "ri", "り": "ri",
"りゃ": "rya", "りゃ": "rya",
"りゅ": "ryu", "りゅ": "ryu",
"りょ": "ryo", "りょ": "ryo",
"る": "ru", "る": "ru",
"れ": "re", "れ": "re",
"ろ": "ro", "ろ": "ro",
"ゎ": "wa", "ゎ": "wa",
"わ": "wa", "わ": "wa",
"ゐ": "wi", "ゐ": "wi",
"ゑ": "we", "ゑ": "we",
"を": " wo ", "を": " wo ",
"ん": "n", "ん": "n",
"ゔ": "vu", "ゔ": "vu",
"ゕ": "ka", "ゕ": "ka",
"ゖ": "ke", "ゖ": "ke",
"ゝ": "iteration_mark", "ゝ": "iteration_mark",
"ゞ": "voiced_iteration_mark", "ゞ": "voiced_iteration_mark",
"ゟ": "yori" "ゟ": "yori"
} }

File diff suppressed because it is too large Load Diff

@ -1,18 +1,18 @@
{ {
"今日": { "今日": {
"w_type": "noun", "w_type": "noun",
"romaji": "kyou" "romaji": "kyou"
}, },
"明日": { "明日": {
"w_type": "noun", "w_type": "noun",
"romaji": "ashita" "romaji": "ashita"
}, },
"本": { "本": {
"w_type": "noun", "w_type": "noun",
"romaji": "hon" "romaji": "hon"
}, },
"中": { "中": {
"w_type": "noun", "w_type": "noun",
"romaji": "naka" "romaji": "naka"
} }
} }

@ -1,78 +1,78 @@
{ {
"朝日奈丸佳": { "朝日奈丸佳": {
"w_type": "noun", "w_type": "noun",
"romaji": "Asahina Madoka" "romaji": "Asahina Madoka"
}, },
"高海千歌": { "高海千歌": {
"w_type": "noun", "w_type": "noun",
"romaji": "Takami Chika" "romaji": "Takami Chika"
}, },
"鏡音レン": { "鏡音レン": {
"w_type": "noun", "w_type": "noun",
"romaji": "Kagamine Len" "romaji": "Kagamine Len"
}, },
"鏡音リン": { "鏡音リン": {
"w_type": "noun", "w_type": "noun",
"romaji": "Kagamine Rin" "romaji": "Kagamine Rin"
}, },
"逢坂大河": { "逢坂大河": {
"w_type": "noun", "w_type": "noun",
"romaji": "Aisaka Taiga" "romaji": "Aisaka Taiga"
}, },
"水樹奈々": { "水樹奈々": {
"w_type": "noun", "w_type": "noun",
"romaji": "Mizuki Nana" "romaji": "Mizuki Nana"
}, },
"桜内梨子": { "桜内梨子": {
"w_type": "noun", "w_type": "noun",
"romaji": "Sakurauchi Riko" "romaji": "Sakurauchi Riko"
}, },
"山吹沙綾": { "山吹沙綾": {
"w_type": "noun", "w_type": "noun",
"romaji": "Yamabuki Saaya" "romaji": "Yamabuki Saaya"
}, },
"初音ミク": { "初音ミク": {
"w_type": "noun", "w_type": "noun",
"romaji": "Hatsune Miku" "romaji": "Hatsune Miku"
}, },
"渡辺曜": { "渡辺曜": {
"w_type": "noun", "w_type": "noun",
"romaji": "Watanabe You" "romaji": "Watanabe You"
}, },
"原由実": { "原由実": {
"w_type": "noun", "w_type": "noun",
"romaji": "Hara Yumi" "romaji": "Hara Yumi"
}, },
"北宇治": { "北宇治": {
"w_type": "noun", "w_type": "noun",
"romaji": "Kita Uji" "romaji": "Kita Uji"
}, },
"六本木": { "六本木": {
"w_type": "noun", "w_type": "noun",
"romaji": "Roppongi" "romaji": "Roppongi"
}, },
"久美子": { "久美子": {
"w_type": "noun", "w_type": "noun",
"romaji": "Kumiko" "romaji": "Kumiko"
}, },
"政宗": { "政宗": {
"w_type": "noun", "w_type": "noun",
"romaji": "Masamune" "romaji": "Masamune"
}, },
"小林": { "小林": {
"w_type": "noun", "w_type": "noun",
"romaji": "Kobayashi" "romaji": "Kobayashi"
}, },
"奥寺": { "奥寺": {
"w_type": "noun", "w_type": "noun",
"romaji": "Okudera" "romaji": "Okudera"
}, },
"佐藤": { "佐藤": {
"w_type": "noun", "w_type": "noun",
"romaji": "Satou" "romaji": "Satou"
}, },
"玲子": { "玲子": {
"w_type": "noun", "w_type": "noun",
"romaji": "Reiko" "romaji": "Reiko"
} }
} }

@ -1,159 +1,159 @@
{ {
"ァ": "a", "ァ": "a",
"ア": "a", "ア": "a",
"ィ": "i", "ィ": "i",
"イ": "i", "イ": "i",
"イィ": "yi", "イィ": "yi",
"イェ": "ye", "イェ": "ye",
"ゥ": "u", "ゥ": "u",
"ウ": "u", "ウ": "u",
"ウィ": "wi", "ウィ": "wi",
"ウェ": "we", "ウェ": "we",
"ウォ": "wo", "ウォ": "wo",
"ェ": "e", "ェ": "e",
"エ": "e", "エ": "e",
"ォ": "o", "ォ": "o",
"オ": "o", "オ": "o",
"カ": "ka", "カ": "ka",
"ガ": "ga", "ガ": "ga",
"キ": "ki", "キ": "ki",
"キェ": "kye", "キェ": "kye",
"キャ": "kya", "キャ": "kya",
"キュ": "kyu", "キュ": "kyu",
"キョ": "kyo", "キョ": "kyo",
"ギ": "gi", "ギ": "gi",
"ギェ": "gye", "ギェ": "gye",
"ギャ": "gya", "ギャ": "gya",
"ギュ": "gyu", "ギュ": "gyu",
"ギョ": "gyo", "ギョ": "gyo",
"ク": "ku", "ク": "ku",
"クァ": "kwa", "クァ": "kwa",
"クィ": "kwi", "クィ": "kwi",
"クェ": "kwe", "クェ": "kwe",
"クォ": "kwo", "クォ": "kwo",
"グ": "gu", "グ": "gu",
"グァ": "gwa", "グァ": "gwa",
"グィ": "gwi", "グィ": "gwi",
"グェ": "gwe", "グェ": "gwe",
"グォ": "gwo", "グォ": "gwo",
"ケ": "ke", "ケ": "ke",
"ゲ": "ge", "ゲ": "ge",
"コ": "ko", "コ": "ko",
"ゴ": "go", "ゴ": "go",
"サ": "sa", "サ": "sa",
"ザ": "za", "ザ": "za",
"シ": "shi", "シ": "shi",
"シェ": "she", "シェ": "she",
"シャ": "sha", "シャ": "sha",
"シュ": "shu", "シュ": "shu",
"ショ": "sho", "ショ": "sho",
"ジ": "ji", "ジ": "ji",
"ジェ": "je", "ジェ": "je",
"ジャ": "ja", "ジャ": "ja",
"ジュ": "ju", "ジュ": "ju",
"ジョ": "jo", "ジョ": "jo",
"ス": "su", "ス": "su",
"スィ": "si", "スィ": "si",
"ズ": "zu", "ズ": "zu",
"ズィ": "zi", "ズィ": "zi",
"セ": "se", "セ": "se",
"ゼ": "ze", "ゼ": "ze",
"ソ": "so", "ソ": "so",
"ゾ": "zo", "ゾ": "zo",
"タ": "ta", "タ": "ta",
"ダ": "da", "ダ": "da",
"チ": "chi", "チ": "chi",
"チェ": "che", "チェ": "che",
"チャ": "cha", "チャ": "cha",
"チュ": "chu", "チュ": "chu",
"チョ": "cho", "チョ": "cho",
"ヂ": "ji", "ヂ": "ji",
"ツ": "tsu", "ツ": "tsu",
"ツァ": "tsa", "ツァ": "tsa",
"ツィ": "tsi", "ツィ": "tsi",
"ツェ": "tse", "ツェ": "tse",
"ツォ": "tso", "ツォ": "tso",
"ヅ": "zu", "ヅ": "zu",
"テ": "te", "テ": "te",
"ティ": "ti", "ティ": "ti",
"デ": "de", "デ": "de",
"ディ": "di", "ディ": "di",
"ト": "to", "ト": "to",
"トゥ": "tu", "トゥ": "tu",
"ド": "do", "ド": "do",
"ドゥ": "du", "ドゥ": "du",
"ナ": "na", "ナ": "na",
"ニ": "ni", "ニ": "ni",
"ニャ": "nya", "ニャ": "nya",
"ニュ": "nyu", "ニュ": "nyu",
"ニョ": "nyo", "ニョ": "nyo",
"ヌ": "nu", "ヌ": "nu",
"ネ": "ne", "ネ": "ne",
"": "no", "": "no",
"ハ": "ha", "ハ": "ha",
"バ": "ba", "バ": "ba",
"パ": "pa", "パ": "pa",
"ヒ": "hi", "ヒ": "hi",
"ヒャ": "hya", "ヒャ": "hya",
"ヒュ": "hyu", "ヒュ": "hyu",
"ヒョ": "hyo", "ヒョ": "hyo",
"ビ": "bi", "ビ": "bi",
"ビャ": "bya", "ビャ": "bya",
"ビュ": "byu", "ビュ": "byu",
"ビョ": "byo", "ビョ": "byo",
"ピ": "pi", "ピ": "pi",
"ピャ": "pya", "ピャ": "pya",
"ピュ": "pyu", "ピュ": "pyu",
"ピョ": "pyo", "ピョ": "pyo",
"フ": "fu", "フ": "fu",
"ファ": "fa", "ファ": "fa",
"フィ": "fi", "フィ": "fi",
"フェ": "fe", "フェ": "fe",
"フォ": "fo", "フォ": "fo",
"ブ": "bu", "ブ": "bu",
"プ": "pu", "プ": "pu",
"ヘ": "he", "ヘ": "he",
"ベ": "be", "ベ": "be",
"ペ": "pe", "ペ": "pe",
"ホ": "ho", "ホ": "ho",
"ホゥ": "hu", "ホゥ": "hu",
"ボ": "bo", "ボ": "bo",
"ポ": "po", "ポ": "po",
"マ": "ma", "マ": "ma",
"ミ": "mi", "ミ": "mi",
"ミャ": "mya", "ミャ": "mya",
"ミュ": "myu", "ミュ": "myu",
"ミョ": "myo", "ミョ": "myo",
"ム": "mu", "ム": "mu",
"メ": "me", "メ": "me",
"モ": "mo", "モ": "mo",
"ヤ": "ya", "ヤ": "ya",
"ユ": "yu", "ユ": "yu",
"ヨ": "yo", "ヨ": "yo",
"ラ": "ra", "ラ": "ra",
"リ": "ri", "リ": "ri",
"リャ": "rya", "リャ": "rya",
"リュ": "ryu", "リュ": "ryu",
"リョ": "ryo", "リョ": "ryo",
"ル": "ru", "ル": "ru",
"レ": "re", "レ": "re",
"ロ": "ro", "ロ": "ro",
"ヮ": "wa", "ヮ": "wa",
"ワ": "wa", "ワ": "wa",
"ヰ": "wi", "ヰ": "wi",
"ヱ": "we", "ヱ": "we",
"ヲ": "wo", "ヲ": "wo",
"ン": "n", "ン": "n",
"ヴ": "vu", "ヴ": "vu",
"ヴァ": "va", "ヴァ": "va",
"ヴィ": "vi", "ヴィ": "vi",
"ヴェ": "ve", "ヴェ": "ve",
"ヴォ": "vo", "ヴォ": "vo",
"ヵ": "ka", "ヵ": "ka",
"ヶ": "ke", "ヶ": "ke",
"ヺ": "vo", "ヺ": "vo",
"・": " ", "・": " ",
"ヽ": "iteration_mark", "ヽ": "iteration_mark",
"ヾ": "voiced_iteration_mark", "ヾ": "voiced_iteration_mark",
"ヿ": "koto" "ヿ": "koto"
} }

File diff suppressed because it is too large Load Diff

@ -1,103 +1,103 @@
{ {
"\u200b": "", "\u200b": "",
"「": "[", "「": "[",
"」": "]", "」": "]",
"『": "[", "『": "[",
"』": "]", "』": "]",
"": "(", "": "(",
"": ")", "": ")",
"": "[", "": "[",
"": "]", "": "]",
"": "{", "": "{",
"": "}", "": "}",
"〈": "(", "〈": "(",
"〉": ")", "〉": ")",
"【": "[", "【": "[",
"】": "]", "】": "]",
"": "[", "": "[",
"": "]", "": "]",
"〖": "[", "〖": "[",
"〗": "]", "〗": "]",
"〘": "[", "〘": "[",
"〙": "]", "〙": "]",
"〚": "[", "〚": "[",
"〛": "]", "〛": "]",
"": "--", "": "--",
"〓": "-", "〓": "-",
"": "=", "": "=",
"〜": "~", "〜": "~",
"…": "_", "…": "_",
"※": "", "※": "",
"♪": "", "♪": "",
"♫": "", "♫": "",
"♬": "", "♬": "",
"♩": "", "♩": "",
"": "!", "": "!",
"": "?", "": "?",
"、": ",", "、": ",",
"♥": " ", "♥": " ",
"«": "(", "«": "(",
"»": ")", "»": ")",
"≪": "(", "≪": "(",
"≫": ")", "≫": ")",
"": "-", "": "-",
"”": "", "”": "",
"“": "", "“": "",
"゙": "", "゙": "",
"": "'", "": "'",
"": "", "": "",
"→": "", "→": "",
"⇒": "", "⇒": "",
"∞": " ", "∞": " ",
"☆": " ", "☆": " ",
"♠": " ", "♠": " ",
"ᷨ": " ", "ᷨ": " ",
"ꯑ": " ", "ꯑ": " ",
"ᤙ": " ", "ᤙ": " ",
"": " ", "": " ",
"△": "" , "△": "" ,
"★": " ", "★": " ",
"♡": " ", "♡": " ",
"。": "", "。": "",
"゚": "", "゚": "",
"(": "(", "(": "(",
")": ")", ")": ")",
"∀": "a", "∀": "a",
"ά": "a", "ά": "a",
"ɪ": "I", "ɪ": "I",
"˥": "l", "˥": "l",
"゚": "", "゚": "",
"—": "-", "—": "-",
"Я": "", "Я": "",
"Ψ": "", "Ψ": "",
"┐": "", "┐": "",
"ə": "", "ə": "",
"ˈ": "", "ˈ": "",
"×": " x ", "×": " x ",
"†": "", "†": "",
"≡": " ", "≡": " ",
"": "", "": "",
"": "-", "": "-",
"⇔": " ", "⇔": " ",
"≒": " ", "≒": " ",
"〆": "shime", "〆": "shime",
"\u3000": " " "\u3000": " "
} }

File diff suppressed because it is too large Load Diff

@ -1,29 +1,29 @@
class KanjiBlock(str): class KanjiBlock(str):
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
obj = str.__new__(cls, "@") obj = str.__new__(cls, "@")
kanji = args[0] kanji = args[0]
kanji_dict = args[1] kanji_dict = args[1]
obj.kanji = kanji obj.kanji = kanji
if len(kanji) == 1: if len(kanji) == 1:
obj.romaji = " " + kanji_dict["romaji"] obj.romaji = " " + kanji_dict["romaji"]
else: else:
if "verb stem" in kanji_dict["w_type"]: if "verb stem" in kanji_dict["w_type"]:
obj.romaji = " " + kanji_dict["romaji"] obj.romaji = " " + kanji_dict["romaji"]
else: else:
obj.romaji = " " + kanji_dict["romaji"] + " " obj.romaji = " " + kanji_dict["romaji"] + " "
if "other_readings" in kanji_dict: if "other_readings" in kanji_dict:
obj.w_type = [kanji_dict["w_type"]] obj.w_type = [kanji_dict["w_type"]]
obj.w_type.extend( obj.w_type.extend(
[k for k in kanji_dict["other_readings"].keys()] [k for k in kanji_dict["other_readings"].keys()]
) )
else: else:
obj.w_type = kanji_dict["w_type"] obj.w_type = kanji_dict["w_type"]
return obj return obj
def __repr__(self): def __repr__(self):
return self.kanji.encode("unicode_escape") return self.kanji.encode("unicode_escape")
def __str__(self): def __str__(self):
return self.romaji.encode("utf-8") return self.romaji.encode("utf-8")

@ -1,6 +1,6 @@
class Particle(str): class Particle(str):
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
particle_str = args[0] particle_str = args[0]
obj = str.__new__(cls, " " + particle_str + " ") obj = str.__new__(cls, " " + particle_str + " ")
obj.pname = particle_str obj.pname = particle_str
return obj return obj

@ -1,4 +1,4 @@
# noinspection PyClassHasNoInit # noinspection PyClassHasNoInit
class UnicodeRomajiMapping: # caching class UnicodeRomajiMapping: # caching
kana_mapping = {} kana_mapping = {}
kanji_mapping = {} kanji_mapping = {}

@ -1,5 +1,5 @@
from .UnicodeRomajiMapping import UnicodeRomajiMapping from .UnicodeRomajiMapping import UnicodeRomajiMapping
from .KanjiBlock import KanjiBlock from .KanjiBlock import KanjiBlock
from .Particle import Particle from .Particle import Particle
__all__ = ["UnicodeRomajiMapping", "KanjiBlock", "Particle"] __all__ = ["UnicodeRomajiMapping", "KanjiBlock", "Particle"]

@ -0,0 +1,5 @@
import asyncio
from noawait import NoAwaitPool
loop = asyncio.new_event_loop()
noawait = NoAwaitPool(loop)

@ -1,51 +1,65 @@
import asyncio from local import loop, noawait
from typing import TypedDict
from aiohttp import web from aiohttp import web
import asyncpg import config
import config import api.route
import api.route import utils.web
import utils.web from service.database import DatabaseService
from service.database import DatabaseService from service.mediawiki_api import MediaWikiApi
from service.mediawiki_api import MediaWikiApi
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker # Auto create Table
from api.model.base import BaseModel
from service.tiktoken import TikTokenService from api.model.toolkit_ui.conversation import ConversationModel as _
from api.model.chat_complete.conversation import ConversationChunkModel as _
async def index(request: web.Request): from api.model.embedding_search.title_collection import TitleCollectionModel as _
return utils.web.api_response(1, data={"message": "Isekai toolkit API"}, request=request) from api.model.embedding_search.title_index import TitleIndexModel as _
async def init_mw_api(app: web.Application): from service.tiktoken import TikTokenService
mw_api = MediaWikiApi.create()
if config.MW_BOT_LOGIN_USERNAME and config.MW_BOT_LOGIN_PASSWORD: async def index(request: web.Request):
await mw_api.robot_login(config.MW_BOT_LOGIN_USERNAME, config.MW_BOT_LOGIN_PASSWORD) return utils.web.api_response(1, data={"message": "Isekai toolkit API"}, request=request)
site_meta = await mw_api.get_site_meta() async def init_mw_api(app: web.Application):
mw_api = MediaWikiApi.create()
print("Connected to Wiki %s, Robot username: %s" % (site_meta["sitename"], site_meta["user"])) if config.MW_BOT_LOGIN_USERNAME and config.MW_BOT_LOGIN_PASSWORD:
try:
async def init_database(app: web.Application): await mw_api.robot_login(config.MW_BOT_LOGIN_USERNAME, config.MW_BOT_LOGIN_PASSWORD)
dbs = await DatabaseService.create(app) except Exception as e:
print("Database connected.") print("Cannot login to Robot account, please check config.")
async def init_tiktoken(app: web.Application): site_meta = await mw_api.get_site_meta()
await TikTokenService.create()
print("Tiktoken model loaded.") print("Connected to Wiki %s, Robot username: %s" % (site_meta["sitename"], site_meta["user"]))
if __name__ == '__main__': async def init_database(app: web.Application):
loop = asyncio.get_event_loop() dbs = await DatabaseService.create(app)
print("Database connected.")
app = web.Application()
async with dbs.engine.begin() as conn:
if config.DATABASE: await conn.run_sync(BaseModel.metadata.create_all)
app.on_startup.append(init_database)
async def init_tiktoken(app: web.Application):
if config.MW_API: await TikTokenService.create()
app.on_startup.append(init_mw_api) print("Tiktoken model loaded.")
if config.OPENAI_TOKEN: async def stop_noawait_pool(app: web.Application):
app.on_startup.append(init_tiktoken) await noawait.end()
app.router.add_route('*', '/', index) if __name__ == '__main__':
api.route.init(app) app = web.Application()
web.run_app(app, host='0.0.0.0', port=config.PORT, loop=loop)
if config.DATABASE:
app.on_startup.append(init_database)
if config.MW_API:
app.on_startup.append(init_mw_api)
if config.OPENAI_TOKEN:
app.on_startup.append(init_tiktoken)
app.on_shutdown.append(stop_noawait_pool)
app.router.add_route('*', '/', index)
api.route.init(app)
web.run_app(app, host='0.0.0.0', port=config.PORT, loop=loop)

@ -0,0 +1,72 @@
from __future__ import annotations
from asyncio import AbstractEventLoop, Task
import asyncio
from functools import wraps
import sys
import traceback
from typing import Callable, Coroutine
class NoAwaitPool:
def __init__(self, loop: AbstractEventLoop):
self.task_list: list[Task] = []
self.loop = loop
self.running = True
self.on_error: list[Callable] = []
self.gc_task = loop.create_task(self._run_gc())
async def end(self):
print("Stopping NoAwait Tasks...")
self.running = False
for task in self.task_list:
await self._finish_task(task)
await self.gc_task
def add_task(self, coroutine: Coroutine):
task = self.loop.create_task(coroutine)
self.task_list.append(task)
def wrap(self, f):
@wraps(f)
def decorated_function(*args, **kwargs):
coroutine = f(*args, **kwargs)
self.add_task(coroutine)
return decorated_function
async def _finish_task(self, task: Task):
try:
if not task.done():
task.cancel()
await task
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
if handler_ret is Coroutine:
await handler_ret
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
async def _run_gc(self):
while self.running:
should_remove = []
for task in self.task_list:
if task.done():
await self._finish_task(task)
should_remove.append(task)
for task in should_remove:
self.task_list.remove(task)
await asyncio.sleep(0.1)

@ -1,16 +1,17 @@
aiohttp==3.8.4 aiohttp==3.8.4
jieba==0.42.1 jieba==0.42.1
pypinyin==0.37.0 pypinyin==0.37.0
simplejson==3.17.0 simplejson==3.17.0
beautifulsoup4==4.11.2 beautifulsoup4==4.11.2
markdownify==0.11.6 markdownify==0.11.6
asyncpg==0.27.0 asyncpg==0.27.0
aiofiles==23.1.0 aiofiles==23.1.0
pgvector==0.1.6 pgvector==0.1.6
websockets==11.0 websockets==11.0
PyJWT==2.6.0 PyJWT==2.6.0
asyncpg-stubs==0.27.0 asyncpg-stubs==0.27.0
sqlalchemy==2.0.9 sqlalchemy==2.0.9
aiohttp-sse-client2==0.3.0 aiohttp-sse-client2==0.3.0
OpenCC==1.1.6 OpenCC==1.1.6
event-emitter-asyncio==1.0.4 event-emitter-asyncio==1.0.4
tiktoken-async==0.3.2

@ -19,6 +19,11 @@ from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
class ChatCompleteServicePrepareResponse(TypedDict):
extract_doc: list
question_tokens: int
class ChatCompleteServiceResponse(TypedDict): class ChatCompleteServiceResponse(TypedDict):
message: str message: str
message_tokens: int message_tokens: int
@ -44,9 +49,18 @@ class ChatCompleteService:
self.tiktoken: TikTokenService = None self.tiktoken: TikTokenService = None
self.extract_doc: list = None
self.mwapi = MediaWikiApi.create() self.mwapi = MediaWikiApi.create()
self.openai_api = OpenAIApi.create() self.openai_api = OpenAIApi.create()
self.user_id = 0
self.question = ""
self.question_tokens: Optional[int] = None
self.conversation_id: Optional[int] = None
self.delta_data = {}
async def __aenter__(self): async def __aenter__(self):
self.tiktoken = await TikTokenService.create() self.tiktoken = await TikTokenService.create()
@ -67,26 +81,55 @@ class ChatCompleteService:
async def get_question_tokens(self, question: str): async def get_question_tokens(self, question: str):
return await self.tiktoken.get_tokens(question) return await self.tiktoken.get_tokens(question)
async def chat_complete(self, question: str, on_message: Optional[callable] = None, on_extracted_doc: Optional[callable] = None, async def prepare_chat_complete(self, question: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None,
conversation_id: Optional[str] = None, user_id: Optional[int] = None, question_tokens: Optional[int] = None, question_tokens: Optional[int] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServiceResponse: embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServiceResponse:
if user_id is not None: if user_id is not None:
user_id = int(user_id) user_id = int(user_id)
self.user_id = user_id
self.question = question
self.conversation_info = None self.conversation_info = None
if conversation_id is not None: if conversation_id is not None:
conversation_id = int(conversation_id) self.conversation_id = int(conversation_id)
self.conversation_info = await self.conversation_helper.get_conversation(conversation_id) self.conversation_info = await self.conversation_helper.find_by_id(self.conversation_id)
else:
self.conversation_id = None
if self.conversation_info is not None:
if self.conversation_info.user_id != user_id:
raise web.HTTPUnauthorized()
if question_tokens is None:
self.question_tokens = await self.get_question_tokens(question)
else:
self.question_tokens = question_tokens
if (len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS and
self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS):
# If the question is too long, we need to truncate it
raise web.HTTPRequestEntityTooLarge()
# Extract document from wiki page index
self.extract_doc = None
if embedding_search is not None:
self.extract_doc, token_usage = await self.embedding_search.search(question, **embedding_search)
if self.extract_doc is not None:
self.question_tokens += token_usage
return ChatCompleteServicePrepareResponse(
extract_doc=self.extract_doc,
question_tokens=self.question_tokens
)
async def finish_chat_complete(self, on_message: Optional[callable] = None) -> ChatCompleteServiceResponse:
delta_data = {} delta_data = {}
self.conversation_chunk = None self.conversation_chunk = None
message_log = [] message_log = []
if self.conversation_info is not None: if self.conversation_info is not None:
if self.conversation_info.user_id != user_id: self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(self.conversation_id)
raise web.HTTPUnauthorized()
self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(conversation_id)
# If the conversation is too long, we need to make a summary # If the conversation is too long, we need to make a summary
if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS: if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS:
@ -95,9 +138,9 @@ class ChatCompleteService:
{"role": "summary", "content": summary, "tokens": tokens} {"role": "summary", "content": summary, "tokens": tokens}
] ]
self.conversation_chunk = await self.conversation_chunk_helper.add(conversation_id, new_message_log, tokens) self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_id, new_message_log, tokens)
delta_data["conversation_chunk_id"] = self.conversation_chunk.id self.delta_data["conversation_chunk_id"] = self.conversation_chunk.id
message_log = [] message_log = []
for message in self.conversation_chunk.message_data: for message in self.conversation_chunk.message_data:
@ -106,40 +149,26 @@ class ChatCompleteService:
"content": message["content"], "content": message["content"],
}) })
if question_tokens is None: if self.extract_doc is not None:
question_tokens = await self.get_question_tokens(question) doc_prompt_content = "\n".join(["%d. %s" % (
if (len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS and i + 1, doc["markdown"] or doc["text"]) for i, doc in enumerate(self.extract_doc)])
question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS):
# If the question is too long, we need to truncate it
raise web.HTTPRequestEntityTooLarge()
extract_doc = None
if embedding_search is not None:
extract_doc, token_usage = await self.embedding_search.search(question, **embedding_search)
if extract_doc is not None:
if on_extracted_doc is not None:
await on_extracted_doc(extract_doc)
question_tokens = token_usage
doc_prompt_content = "\n".join(["%d. %s" % (
i + 1, doc["markdown"] or doc["text"]) for i, doc in enumerate(extract_doc)])
doc_prompt = utils.config.get_prompt("extracted_doc", "prompt", { doc_prompt = utils.config.get_prompt("extracted_doc", "prompt", {
"content": doc_prompt_content}) "content": doc_prompt_content})
message_log.append({"role": "user", "content": doc_prompt}) message_log.append({"role": "user", "content": doc_prompt})
system_prompt = utils.config.get_prompt("chat", "system_prompt") system_prompt = utils.config.get_prompt("chat", "system_prompt")
# Start chat complete # Start chat complete
if on_message is not None: if on_message is not None:
response = await self.openai_api.chat_complete_stream(question, system_prompt, message_log, on_message) response = await self.openai_api.chat_complete_stream(self.question, system_prompt, message_log, on_message)
else: else:
response = await self.openai_api.chat_complete(question, system_prompt, message_log) response = await self.openai_api.chat_complete(self.question, system_prompt, message_log)
if self.conversation_info is None: if self.conversation_info is None:
# Create a new conversation # Create a new conversation
message_log_list = [ message_log_list = [
{"role": "user", "content": question, "tokens": question_tokens}, {"role": "user", "content": self.question, "tokens": self.question_tokens},
{"role": "assistant", {"role": "assistant",
"content": response["message"], "tokens": response["message_tokens"]}, "content": response["message"], "tokens": response["message_tokens"]},
] ]
@ -152,21 +181,21 @@ class ChatCompleteService:
print(str(e), file=sys.stderr) print(str(e), file=sys.stderr)
traceback.print_exc(file=sys.stderr) traceback.print_exc(file=sys.stderr)
total_token_usage = question_tokens + response["message_tokens"] total_token_usage = self.question_tokens + response["message_tokens"]
title_info = self.embedding_search.title_info title_info = self.embedding_search.title_info
self.conversation_info = await self.conversation_helper.add(user_id, "chatcomplete", page_id=title_info["page_id"], rev_id=title_info["rev_id"], title=title) self.conversation_info = await self.conversation_helper.add(self.user_id, "chatcomplete", page_id=title_info["page_id"], rev_id=title_info["rev_id"], title=title)
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_info.id, message_log_list, total_token_usage) self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_info.id, message_log_list, total_token_usage)
else: else:
# Update the conversation chunk # Update the conversation chunk
await self.conversation_helper.refresh_updated_at(conversation_id) await self.conversation_helper.refresh_updated_at(self.conversation_id)
self.conversation_chunk.message_data.append( self.conversation_chunk.message_data.append(
{"role": "user", "content": question, "tokens": question_tokens}) {"role": "user", "content": self.question, "tokens": self.question_tokens})
self.conversation_chunk.message_data.append( self.conversation_chunk.message_data.append(
{"role": "assistant", "content": response["message"], "tokens": response["message_tokens"]}) {"role": "assistant", "content": response["message"], "tokens": response["message_tokens"]})
flag_modified(self.conversation_chunk, "message_data") flag_modified(self.conversation_chunk, "message_data")
self.conversation_chunk.tokens += question_tokens + \ self.conversation_chunk.tokens += self.question_tokens + \
response["message_tokens"] response["message_tokens"]
await self.conversation_chunk_helper.update(self.conversation_chunk) await self.conversation_chunk_helper.update(self.conversation_chunk)

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
import asyncio import local
from urllib.parse import quote_plus from urllib.parse import quote_plus
from aiohttp import web from aiohttp import web
import asyncpg import asyncpg
@ -38,7 +38,7 @@ class DatabaseService:
self.create_session: async_sessionmaker[AsyncSession] = None self.create_session: async_sessionmaker[AsyncSession] = None
async def init(self): async def init(self):
loop = asyncio.get_event_loop() loop = local.loop
self.pool = asyncpg.create_pool(**config.DATABASE, loop=loop) self.pool = asyncpg.create_pool(**config.DATABASE, loop=loop)
await self.pool.__aenter__() await self.pool.__aenter__()

@ -129,10 +129,24 @@ class EmbeddingSearchService:
if self.unindexed_docs is None: if self.unindexed_docs is None:
return False return False
chunk_limit = 500
chunk_len = 0
doc_chunk = []
total_token_usage = 0 total_token_usage = 0
processed_len = 0
async def on_embedding_progress(current, length):
nonlocal processed_len
indexed_docs = processed_len + current
if on_progress is not None:
await on_progress(indexed_docs, len(self.unindexed_docs))
async def embedding_doc(doc_chunk): async def embedding_doc(doc_chunk):
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk) (doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk, on_embedding_progress)
await self.page_index.index_doc(doc_chunk) await self.page_index.index_doc(doc_chunk)
return token_usage return token_usage
@ -140,12 +154,7 @@ class EmbeddingSearchService:
if len(self.unindexed_docs) > 0: if len(self.unindexed_docs) > 0:
if on_progress is not None: if on_progress is not None:
await on_progress(0, len(self.unindexed_docs)) await on_progress(0, len(self.unindexed_docs))
chunk_limit = 500
chunk_len = 0
processed_len = 0
doc_chunk = []
for doc in self.unindexed_docs: for doc in self.unindexed_docs:
chunk_len += len(doc) chunk_len += len(doc)

@ -1,7 +1,7 @@
import json import json
import sys import sys
import time import time
from typing import Optional from typing import Optional, TypedDict
import aiohttp import aiohttp
import config import config
@ -18,6 +18,10 @@ class MediaWikiApiException(Exception):
class MediaWikiPageNotFoundException(MediaWikiApiException): class MediaWikiPageNotFoundException(MediaWikiApiException):
pass pass
class ChatCompleteReportUsageResponse(TypedDict):
point_cost: int
transaction_id: str
class MediaWikiApi: class MediaWikiApi:
cookie_jar = aiohttp.CookieJar(unsafe=True) cookie_jar = aiohttp.CookieJar(unsafe=True)
@ -27,7 +31,7 @@ class MediaWikiApi:
def __init__(self, api_url: str): def __init__(self, api_url: str):
self.api_url = api_url self.api_url = api_url
self.login_time = 0 self.login_time = 0.0
self.login_identity = None self.login_identity = None
async def get_page_info(self, title: str): async def get_page_info(self, title: str):
@ -142,7 +146,7 @@ class MediaWikiApi:
async def refresh_login(self): async def refresh_login(self):
if self.login_identity is None: if self.login_identity is None:
return False return False
if time.time() - self.login_time > 10: if time.time() - self.login_time > 30:
return await self.robot_login(self.login_identity["username"], self.login_identity["password"]) return await self.robot_login(self.login_identity["username"], self.login_identity["password"])
async def chat_complete_user_info(self, user_id: int): async def chat_complete_user_info(self, user_id: int):
@ -166,7 +170,7 @@ class MediaWikiApi:
return data["chatcompletebot"]["userinfo"] return data["chatcompletebot"]["userinfo"]
async def chat_complete_start_transaction(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> str: async def chat_complete_start_transaction(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteReportUsageResponse:
await self.refresh_login() await self.refresh_login()
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
@ -189,7 +193,8 @@ class MediaWikiApi:
print(data) print(data)
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
return data["chatcompletebot"]["reportusage"]["transactionid"] return ChatCompleteReportUsageResponse(point_cost=int(data["chatcompletebot"]["reportusage"]["pointcost"]),
transaction_id=data["chatcompletebot"]["reportusage"]["transactionid"])
async def chat_complete_end_transaction(self, transaction_id: str, tokens: Optional[int] = None): async def chat_complete_end_transaction(self, transaction_id: str, tokens: Optional[int] = None):
await self.refresh_login() await self.refresh_login()

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
import json import json
from typing import TypedDict from typing import Callable, Optional, TypedDict
import aiohttp import aiohttp
import config import config
@ -23,38 +23,117 @@ class ChatCompleteResponse(TypedDict):
class OpenAIApi: class OpenAIApi:
@staticmethod @staticmethod
def create(): def create():
return OpenAIApi(config.OPENAI_API or "https://api.openai.com", config.OPENAI_TOKEN) return OpenAIApi()
def __init__(self, api_url: str, token: str): def __init__(self):
self.api_url = api_url if config.OPENAI_API_TYPE == "azure":
self.token = token self.api_url = config.AZURE_OPENAI_ENDPOINT
self.api_key = config.AZURE_OPENAI_KEY
else:
self.api_url = config.OPENAI_API or "https://api.openai.com"
self.api_key = config.OPENAI_TOKEN
def build_header(self):
if config.OPENAI_API_TYPE == "azure":
return {
"Content-Type": "application/json",
"Accept": "application/json",
"api-key": self.api_key
}
else:
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
}
def get_url(self, method: str):
if config.OPENAI_API_TYPE == "azure":
if method == "completions":
return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME + "/" + method
elif method == "embeddings":
return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME + "/" + method
else:
return self.api_url + "/v1/" + method
async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
text_list = []
regex = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]"
for doc in doc_list:
text: str = doc["text"]
text = text.replace("\r\n", "\n").replace("\r", "\n")
if "\n" in text:
lines = text.split("\n")
new_lines = []
for line in lines:
line = line.strip()
# Add a dot at the end of the line if it doesn't end with a punctuation mark
if regex.find(line[-1]) == -1:
line += "."
new_lines.append(line)
text = " ".join(new_lines)
text_list.append(text)
async def get_embeddings(self, doc_list: list):
token_usage = 0 token_usage = 0
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
text_list = [doc["text"] for doc in doc_list] url = self.get_url("embeddings")
params = { params = {}
"model": "text-embedding-ada-002", post_data = {
"input": text_list, "input": text_list,
} }
async with session.post(self.api_url + "/v1/embeddings",
headers={"Authorization": f"Bearer {self.token}"},
json=params,
timeout=30,
proxy=config.REQUEST_PROXY) as resp:
data = await resp.json()
for one_data in data["data"]: if config.OPENAI_API_TYPE == "azure":
embedding = one_data["embedding"] params["api-version"] = "2023-05-15"
index = one_data["index"] else:
post_data["model"] = "text-embedding-ada-002"
if index < len(doc_list):
if embedding is not None:
embedding = np.array(embedding) if config.OPENAI_API_TYPE == "azure":
doc_list[index]["embedding"] = embedding # Azure api does not support batch
for index, text in enumerate(text_list):
token_usage = int(data["usage"]["total_tokens"]) async with session.post(url,
headers=self.build_header(),
params=params,
json={"input": text},
timeout=30,
proxy=config.REQUEST_PROXY) as resp:
data = await resp.json()
one_data = data["data"]
if len(one_data) > 0:
embedding = one_data[0]["embedding"]
if embedding is not None:
embedding = np.array(embedding)
doc_list[index]["embedding"] = embedding
token_usage += int(data["usage"]["total_tokens"])
if on_index_progress is not None:
await on_index_progress(index, len(text_list))
else:
async with session.post(url,
headers=self.build_header(),
params=params,
json=post_data,
timeout=30,
proxy=config.REQUEST_PROXY) as resp:
data = await resp.json()
for one_data in data["data"]:
embedding = one_data["embedding"]
index = one_data["index"]
if index < len(doc_list):
if embedding is not None:
embedding = np.array(embedding)
doc_list[index]["embedding"] = embedding
token_usage = int(data["usage"]["total_tokens"])
await on_index_progress(index, len(text_list))
return (doc_list, token_usage) return (doc_list, token_usage)
@ -79,17 +158,26 @@ class OpenAIApi:
async def chat_complete(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], user = None): async def chat_complete(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], user = None):
messageList = await self.make_message_list(question, system_prompt, conversation) messageList = await self.make_message_list(question, system_prompt, conversation)
params = { url = self.get_url("completions")
"model": "gpt-3.5-turbo",
params = {}
post_data = {
"messages": messageList, "messages": messageList,
"user": user, "user": user,
} }
params = {k: v for k, v in params.items() if v is not None}
if config.OPENAI_API_TYPE == "azure":
params["api-version"] = "2023-05-15"
else:
post_data["model"] = "gpt-3.5-turbo"
post_data = {k: v for k, v in post_data.items() if v is not None}
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(self.api_url + "/v1/chat/completions", async with session.post(url,
headers={"Authorization": f"Bearer {self.token}"}, headers=self.build_header,
json=params, params=params,
json=post_data,
timeout=30, timeout=30,
proxy=config.REQUEST_PROXY) as resp: proxy=config.REQUEST_PROXY) as resp:
@ -138,7 +226,7 @@ class OpenAIApi:
option={ option={
"method": "POST" "method": "POST"
}, },
headers={"Authorization": f"Bearer {self.token}"}, headers={"Authorization": f"Bearer {self.api_key}"},
json=params, json=params,
proxy=config.REQUEST_PROXY proxy=config.REQUEST_PROXY
) as session: ) as session:

@ -1,33 +1,37 @@
import asyncio import local
import config from service.database import DatabaseService
import asyncpg
from service.database import DatabaseService from service.embedding_search import EmbeddingSearchService
from service.embedding_search import EmbeddingSearchService async def main():
dbs = await DatabaseService.create()
async def main():
dbs = await DatabaseService.create() async with EmbeddingSearchService(dbs, "代号:曙光的世界/黄昏的阿瓦隆") as embedding_search:
await embedding_search.prepare_update_index()
async with EmbeddingSearchService(dbs, "代号:曙光的世界/黄昏的阿瓦隆") as embedding_search:
async def on_index_progress(current, length): async def on_index_progress(current, length):
print("索引进度:%.1f%%" % (current / length * 100)) print("\r索引进度:%.1f%%" % (current / length * 100), end="", flush=True)
await embedding_search.update_page_index(on_index_progress) print("")
await embedding_search.update_page_index(on_index_progress)
while True: print("")
query = input("请输入要搜索的问题 (.exit 退出)")
if query == ".exit": while True:
break query = input("请输入要搜索的问题 (.exit 退出)")
res = await embedding_search.search(query, 5) if query == ".exit":
total_length = 0 break
if res: res, token_usage = await embedding_search.search(query, 5)
for one in res: total_length = 0
total_length += len(one["markdown"]) if res:
print("%s, distance=%.4f" % (one["markdown"], one["distance"])) for one in res:
else: total_length += len(one["markdown"])
print("未搜索到相关内容") print("%s, distance=%.4f" % (one["markdown"], one["distance"]))
else:
print("总长度:%d" % total_length) print("未搜索到相关内容")
if __name__ == '__main__': print("总长度:%d" % total_length)
asyncio.run(main())
await local.noawait.end()
if __name__ == '__main__':
local.loop.run_until_complete(main())
Loading…
Cancel
Save