You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
|
|
|
|
import asyncio
|
|
|
|
|
import config
|
|
|
|
|
import asyncpg
|
|
|
|
|
from service.database import DatabaseService
|
|
|
|
|
|
|
|
|
|
from service.embedding_search import EmbeddingSearchService
|
|
|
|
|
|
|
|
|
|
async def main(
    page_title: str = "代号:曙光的世界/黄昏的阿瓦隆",
    limit: int = 5,
) -> None:
    """Refresh the embedding index for one page, then run an interactive search loop.

    Args:
        page_title: Page identifier passed to ``EmbeddingSearchService``.
            Defaults to the page this script originally hard-coded.
        limit: Number of results requested from ``embedding_search.search``
            per query (originally hard-coded to 5).

    Each iteration reads a question from stdin and prints every hit's
    ``markdown`` with its ``distance``. It then prints the summed markdown
    length of the hits. Typing ``.exit`` or sending end-of-file ends the
    session.
    """
    # NOTE(review): ``dbs`` is never explicitly released here; if
    # DatabaseService exposes a close/shutdown coroutine it should be awaited
    # on exit -- confirm against its implementation.
    dbs = await DatabaseService.create()

    async with EmbeddingSearchService(dbs, page_title) as embedding_search:

        async def on_index_progress(current, length):
            # Report indexing progress as a percentage. A zero-length index
            # would otherwise raise ZeroDivisionError; report it as complete.
            percent = current / length * 100 if length else 100.0
            print("索引进度:%.1f%%" % percent)

        await embedding_search.update_page_index(on_index_progress)

        while True:
            # input() is a blocking call. This script starts no other tasks of
            # its own, so blocking here is acceptable and keeps Ctrl-C simple.
            try:
                query = input("请输入要搜索的问题 (.exit 退出):").strip()
            except EOFError:
                # Ctrl-D / closed stdin: end the session instead of crashing.
                break

            if query == ".exit":
                break
            if not query:
                # Skip blank lines rather than issuing an empty search.
                continue

            res = await embedding_search.search(query, limit)

            total_length = 0
            if res:
                for one in res:
                    total_length += len(one["markdown"])
                    print("%s, distance=%.4f" % (one["markdown"], one["distance"]))
            else:
                print("未搜索到相关内容")

            print("总长度:%d" % total_length)
if __name__ == "__main__":
    # Script entry point: run the async search session on a fresh event loop.
    asyncio.run(main())