Modify the ChatComplete event implementation

master
落雨楓 2 weeks ago
parent 83bb6ad213
commit cc283b1618

@@ -0,0 +1,21 @@
from __future__ import annotations
from typing import Any
from agentkit.context import ConversationContext
class BaseAgent:
def __init__(self, config: dict[str, Any]):
self.config = config
async def __aenter__(self) -> BaseAgent:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
return None
async def count_tokens(self, text: str) -> int:
return len(text.split())
async def chat_complete(self, context: ConversationContext, stream: bool = False):
pass
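For reference, a concrete agent only needs to override `chat_complete`; a minimal sketch (the `EchoAgent` class is illustrative, not part of this commit, and it assumes `context.history` holds `{"role", "content"}` dicts, as the ConversationContext hunk further down suggests):

```python
from agentkit.base.agent import BaseAgent
from agentkit.context import ConversationContext


class EchoAgent(BaseAgent):
    """Toy agent that echoes the latest user message (illustrative only)."""

    async def chat_complete(self, context: ConversationContext, stream: bool = False):
        # Assumes context.history is a list of {"role", "content"} dicts.
        user_messages = [m for m in context.history if m["role"] == "user"]
        return user_messages[-1]["content"] if user_messages else ""
```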

@@ -0,0 +1,31 @@
from __future__ import annotations
from typing import Any
from agentkit.base.agent import BaseAgent
class BaseAgentFactory:
def __init__(self, props: dict):
pass
async def __aenter__(self) -> BaseAgentFactory:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
return None
async def new_from_id(self, agent_id: str) -> BaseAgent | None:
"""
Create agent from agent id.
"""
return None
async def new_from_config(self, agent_type: str, config: dict) -> BaseAgent | None:
"""
Create agent from agent type and config.
"""
return None
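A factory would typically dispatch `new_from_config` on `agent_type`; a minimal sketch assuming a registry dict (the module path `agentkit.base.agent_factory` and the registry are assumptions):

```python
from __future__ import annotations

from agentkit.base.agent import BaseAgent
from agentkit.base.agent_factory import BaseAgentFactory  # module path assumed

# Hypothetical registry mapping agent_type -> BaseAgent subclass.
AGENT_TYPES: dict[str, type[BaseAgent]] = {}


class RegistryAgentFactory(BaseAgentFactory):
    async def new_from_config(self, agent_type: str, config: dict) -> BaseAgent | None:
        agent_cls = AGENT_TYPES.get(agent_type)
        return agent_cls(config) if agent_cls else None


async def demo() -> None:
    # Both base classes are async context managers, hence `async with`.
    async with RegistryAgentFactory({}) as factory:
        agent = await factory.new_from_config("echo", {"temperature": 0.7})
        print(agent)  # None unless "echo" is registered in AGENT_TYPES
```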

@@ -7,6 +7,7 @@ from agentkit.types import PluginConfigParam
class AgentKitModelApi:
CONFIG: list[PluginConfigParam] = []
INPUT: list[PluginConfigParam] = []
def __init__(self, props: dict, **kwargs):
pass

@@ -5,6 +5,7 @@ from agentkit.types import PluginConfigParam, OnProgressCallback
class AgentKitPostprocessor:
CONFIG: list[PluginConfigParam] = []
INPUT: list[PluginConfigParam] = []
OUTPUT: list[PluginConfigParam] = []
def __init__(self, props: dict, **kwargs):
@@ -22,13 +23,13 @@ class AgentKitPostprocessor:
async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
return None
async def on_after_completion(self, context: ConversationContext) -> Optional[dict]:
async def on_after_completion(self, context: ConversationContext, **kwargs) -> Optional[dict]:
"""
Called after the conversation, before the completion is returned to the user.
"""
return None
async def on_after_completion_background(self, context: ConversationContext) -> Optional[dict]:
async def on_after_completion_background(self, context: ConversationContext, **kwargs) -> Optional[dict]:
"""
Called after the conversation, after the completion is returned to the user.
"""

@@ -5,6 +5,7 @@ from agentkit.types import PluginConfigParam, OnProgressCallback
class AgentKitPreprocessor:
CONFIG: list[PluginConfigParam] = []
INPUT: list[PluginConfigParam] = []
OUTPUT: list[PluginConfigParam] = []
def __init__(self, props: dict, **kwargs):
@@ -22,5 +23,5 @@ class AgentKitPreprocessor:
async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
return None
async def on_before_completion(self, context: ConversationContext) -> Optional[dict]:
async def on_before_completion(self, context: ConversationContext, **kwargs) -> Optional[dict]:
return None

@@ -38,7 +38,7 @@ class ConversationContext(ConversationData):
self.system_prompt = ""
"""System prompt"""
self.session_data: dict = {}
self.state: dict = {}
"""Temporary data for plugins"""
self.completion: Optional[str] = None
@@ -52,39 +52,39 @@
return [msg for msg in self.history if msg["role"] == role]
def create_sub_context(self, copy_store = False, copy_session_data = False) -> ConversationContext:
def create_sub_context(self, copy_store = False, copy_state = False) -> ConversationContext:
sub_context = ConversationContext()
sub_context.user_id = self.user_id
if copy_store:
sub_context.store = self.store.copy()
if copy_session_data:
sub_context.session_data = self.session_data.copy()
if copy_state:
sub_context.state = self.state.copy()
return sub_context
def __getitem__(self, key):
return self.session_data[key]
return self.state[key]
def __setitem__(self, key, value):
self.session_data[key] = value
self.state[key] = value
return value
def __delitem__(self, key):
del self.session_data[key]
del self.state[key]
def __contains__(self, key):
return key in self.session_data
return key in self.state
def __iter__(self):
return iter(self.session_data)
return iter(self.state)
def __len__(self):
return len(self.session_data)
return len(self.state)
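With the dunder methods above, `ConversationContext` now behaves as a dict over the renamed `state` store (formerly `session_data`); a quick illustration:

```python
from agentkit.context import ConversationContext

context = ConversationContext()
context["step"] = 1                  # writes context.state["step"]
assert "step" in context and len(context) == 1

# Sub-contexts copy `state` on request via the renamed flag.
sub = context.create_sub_context(copy_state=True)
sub["step"] = 2                      # the copy is independent
assert context["step"] == 1
```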

@@ -0,0 +1,126 @@
import logging
from agentkit.context import ConversationContext
from agentkit.types import AgentKitFlowStep
class BreakLoopInterrupt(Exception):
pass
class AgentKitFlowExecutor:
def __init__(self, flow_script: list[AgentKitFlowStep]):
self.flow = flow_script
# Outputs captured from 'call' steps via output_map
self.variables: dict = {}
def _execute_step(self, step: AgentKitFlowStep, context: ConversationContext):
"""
Execute a single flow step.
Args:
step (AgentKitFlowStep): The step to execute.
context (ConversationContext): The conversation context passed to the step handler.
Returns:
Any: The result of the step execution.
"""
step_type = step["type"]
if step_type == "call":
return self._execute_call(step, context)
elif step_type == "if_else":
return self._execute_if_else(step, context)
elif step_type == "loop":
return self._execute_loop(step, context)
elif step_type == "break_loop":
return self._execute_break_loop()
else:
raise ValueError(f"Unsupported step type: {step_type}")
def _resolve_value(self, value):
"""
Resolve a step input value. Minimal assumed semantics: a string naming
a previously captured variable is read from self.variables; any other
value is treated as a literal.
"""
if isinstance(value, str) and value in self.variables:
return self.variables[value]
return value
def _execute_call(self, step: dict, context: ConversationContext):
"""
Execute a 'call' step.
Args:
step (dict): The call step configuration.
Returns:
Any: The result of the call.
"""
func_id = step.get("id")
config = {k: self._resolve_value(v) for k, v in step.get("config", {}).items()}
output_map = step.get("output_map", {})
# Simulate calling a function (replace with actual function calls if needed)
logging.info(f"Calling function {func_id} with config: {config}")
result = {key: f"mocked_value_{key}" for key in output_map.keys()} # Mocked output
# Map outputs to variables
for key, var_name in output_map.items():
self.variables[var_name] = result.get(key)
return result
def _execute_if_else(self, step: dict, context: ConversationContext):
"""
Execute an 'if_else' step.
Args:
step (dict): The if_else step configuration.
Returns:
Any: The result of the executed branch.
"""
condition = step.get("condition")
condition_input = {k: self._resolve_value(v) for k, v in step.get("condition_input", {}).items()}
# Evaluate the condition
condition_result = eval(condition, {}, condition_input)
# Execute the appropriate branch
branch = step["true_branch"] if condition_result else step["false_branch"]
for sub_step in branch:
self._execute_step(sub_step, context)
def _execute_loop(self, step: dict, context: ConversationContext):
"""
Execute a 'loop' step.
Args:
step (dict): The loop step configuration.
Returns:
Any: The result of the loop execution.
"""
loop_num = step.get("loop_num")
index_var = step.get("index_var")
loop_body = step.get("loop_body", [])
try:
for i in range(loop_num):
if index_var:
context.state[index_var] = i
for sub_step in loop_body:
self._execute_step(sub_step, context)
except BreakLoopInterrupt:
# A 'break_loop' step terminates only this loop, not the whole flow.
pass
def _execute_break_loop(self):
"""
Execute a 'break_loop' step.
Args:
step (dict): The break_loop step configuration.
Returns:
None
"""
raise BreakLoopInterrupt()
def execute(self, context: ConversationContext):
"""
Execute the entire flow against the given context.
Args:
context (ConversationContext): The context shared by all steps.
Returns:
None
"""
try:
for step in self.flow:
self._execute_step(step, context)
except BreakLoopInterrupt:
pass
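End to end, a flow script is a list of typed dicts consumed by the executor. A sketch of driving it (the import path is assumed; the `call` step is mocked by `_execute_call` as shown above):

```python
from agentkit.context import ConversationContext
from agentkit.flow_executor import AgentKitFlowExecutor  # module path assumed

flow_script = [
    {
        "type": "loop",
        "loop_num": 3,
        "index_var": "i",
        "loop_body": [
            {
                "type": "call",
                "id": "fetch_page",                    # illustrative function id
                "config": {},
                "output_map": {"text": "page_text"},   # capture mocked output
            },
        ],
    },
]

context = ConversationContext()
executor = AgentKitFlowExecutor(flow_script)
executor.execute(context)
print(context.state["i"])       # -> 2, last index written via index_var
print(executor.variables)       # -> {"page_text": "mocked_value_text"}
```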

@@ -49,7 +49,46 @@ class LLMFunctionResponse(TypedDict):
"""Directly return the response to the user, without further processing."""
OnProgressCallback = Callable[[int, int], Any]
class AgentKitFlowStep(TypedDict):
type: str
comment: str
class AgentKitFlowCall(AgentKitFlowStep):
# type: "call"
id: str
comment: str
config: dict[str, Any]
output_map: dict[str, str]
class AgentKitFlowIfElse(AgentKitFlowStep):
# type: "if_else"
condition: str
condition_input: dict[str, Any]
true_branch: list
false_branch: list
class AgentKitLoop(AgentKitFlowStep):
# type: "loop"
loop_num: int
loop_body: list
index_var: Optional[str]
class AgentKitBreakLoop(AgentKitFlowStep):
# type: "break_loop"
pass
class AgentKitSetVar(AgentKitFlowStep):
# type: "set_var"
var_name: str
var_value_expr: str
OnProgressCallback = Callable[[str, int, int], Any]
OnTaskStartCallback = Callable[[str, str], Any]
OnTaskProgressCallback = Callable[[int, Optional[int], Optional[str]], Any]
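The progress callback gains a leading stage/identifier argument here. Listeners matching the three aliases would look like:

```python
from typing import Optional

def on_progress(stage: str, current: int, total: int) -> None:
    print(f"[{stage}] {current}/{total}")

def on_task_start(task_id: str, task_type: str) -> None:
    print(f"task {task_id} started ({task_type})")

def on_task_progress(current: int, total: Optional[int], message: Optional[str]) -> None:
    suffix = f"/{total}" if total is not None else ""
    print(f"{current}{suffix} {message or ''}".strip())
```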

@@ -0,0 +1,31 @@
from __future__ import annotations
from typing import Callable
from type_defs.chat_complete_task import ChatCompleteServiceResponse
from utils.program import run_listeners
class ChatCompleteEvent:
def __init__(self):
self.on_tool_running: list[Callable[[str, str], None]] = []
self.on_tool_output: list[Callable[[str], None]] = []
self.on_message_output: list[Callable[[str], None]] = []
async def emit_tool_running(self, tool_name: str, running_state: str = "") -> None:
await run_listeners(self.on_tool_running, tool_name, running_state)
async def emit_tool_output(self, output: str) -> None:
await run_listeners(self.on_tool_output, output)
async def emit_message_output(self, output: str) -> None:
await run_listeners(self.on_message_output, output)
class ChatCompleteTaskEvent(ChatCompleteEvent):
def __init__(self):
super().__init__()
self.on_finished: list[Callable[[ChatCompleteServiceResponse | None], None]] = []
self.on_error: list[Callable[[Exception], None]] = []
async def emit_finished(self, result: ChatCompleteServiceResponse | None) -> None:
await run_listeners(self.on_finished, result)
async def emit_error(self, ex: Exception) -> None:
await run_listeners(self.on_error, ex)
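Listeners are plain lists of callables, and the emitters fan out through `run_listeners` (defined at the end of this diff), which accepts both sync and async listeners. A usage sketch:

```python
import asyncio

from events.chat_complete_event import ChatCompleteTaskEvent


async def main() -> None:
    events = ChatCompleteTaskEvent()

    async def on_message(delta: str) -> None:
        print(delta, end="", flush=True)

    def on_finished(result) -> None:  # sync listeners work too
        print("\nfinished:", result)

    events.on_message_output.append(on_message)
    events.on_finished.append(on_finished)

    await events.emit_message_output("Hello")
    await events.emit_finished(None)


asyncio.run(main())
```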

@@ -1,3 +0,0 @@
text2vec>=1.2.9
--index-url https://download.pytorch.org/whl/cpu
torch

@@ -10,4 +10,5 @@ PyJWT==2.6.0
aiohttp-sse-client2==0.3.0
OpenCC==1.1.1
event-emitter-asyncio==1.0.4
toml==0.10.2
toml==0.10.2
ollama==0.4.4

@@ -7,10 +7,11 @@ from server.controller.task.ChatCompleteTask import ChatCompleteTask
from server.model.base import clone_model
from server.model.chat_complete.bot_persona import BotPersonaHelper
from server.model.toolbox_ui.conversation import ConversationHelper
from type_defs.chat_complete_task import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse
from utils.local import noawait
from aiohttp import web
from server.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse, calculate_point_usage
from service.chat_complete import calculate_point_usage
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
from service.tiktoken import TikTokenService
@@ -248,6 +249,8 @@ class ChatComplete:
@staticmethod
@utils.web.token_auth
async def get_point_usage(request: web.Request):
"""计算ChatComplete所需积分"""
params = await utils.web.get_param(request, {
"question": {
"type": str,
@@ -299,6 +302,8 @@ class ChatComplete:
@staticmethod
@utils.web.token_auth
async def get_persona_list(request: web.Request):
"""获取机器人列表"""
params = await utils.web.get_param(request, {
"page": {
"type": int,
@@ -336,6 +341,8 @@ class ChatComplete:
@staticmethod
@utils.web.token_auth
async def get_persona_info(request: web.Request):
"""获取机器人信息"""
params = await utils.web.get_param(request, {
"id": {
"type": int,
@@ -371,6 +378,8 @@ class ChatComplete:
@staticmethod
@utils.web.token_auth
async def start_chat_complete(request: web.Request):
"""开始ChatComplete"""
params = await utils.web.get_param(request, {
"title": {
"type": str,
@@ -471,6 +480,7 @@ class ChatComplete:
@staticmethod
@utils.web.token_auth
async def chat_complete_stream(request: web.Request):
"""流式输出ChatComplete结果"""
if not utils.web.is_websocket(request):
return await utils.web.api_response(-1, error={
"code": "websocket-required",
@@ -542,9 +552,28 @@ class ChatComplete:
await ws.close()
else:
async def on_closed():
task.on_message.remove(on_message)
task.on_finished.remove(on_finished)
task.on_error.remove(on_error)
task.events.on_tool_running.remove(on_tool_running)
task.events.on_tool_output.remove(on_tool_output)
task.events.on_message_output.remove(on_message)
task.events.on_finished.remove(on_finished)
task.events.on_error.remove(on_error)
async def on_tool_running(tool_name: str, running_state: str):
try:
await ws.send_json({
'event': 'tool_run',
'status': 1,
'tool_name': tool_name,
'running_state': running_state,
})
except ConnectionResetError:
await on_closed()
async def on_tool_output(output: str):
try:
await ws.send_str(">" + output)
except ConnectionResetError:
await on_closed()
async def on_message(delta_message: str):
try:
@@ -552,15 +581,18 @@ class ChatComplete:
except ConnectionResetError:
await on_closed()
async def on_finished(result: ChatCompleteServiceResponse):
async def on_finished(result: ChatCompleteServiceResponse | None):
try:
ignored_keys = ["message"]
response_result = {
"point_usage": task.point_usage,
}
for k, v in result.items():
if k not in ignored_keys:
response_result[k] = v
if result:
for k, v in result.items():
if k not in ignored_keys:
response_result[k] = v
await ws.send_json({
'event': 'finished',
'status': 1,
@@ -587,9 +619,11 @@ class ChatComplete:
except ConnectionResetError:
await on_closed()
task.on_message.append(on_message)
task.on_finished.append(on_finished)
task.on_error.append(on_error)
task.events.on_tool_running.append(on_tool_running)
task.events.on_tool_output.append(on_tool_output)
task.events.on_message_output.append(on_message)
task.events.on_finished.append(on_finished)
task.events.on_error.append(on_error)
# Send received message
await ws.send_json({

@@ -2,14 +2,14 @@ from __future__ import annotations
import sys
import time
import traceback
from events.chat_complete_event import ChatCompleteEvent, ChatCompleteTaskEvent
from libs.config import Config
from server.model.chat_complete.bot_persona import BotPersonaHelper
from type_defs.chat_complete_task import ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse
from utils.local import noawait
from typing import Optional, Callable, Union
from service.chat_complete import (
ChatCompleteService,
ChatCompleteServicePrepareResponse,
ChatCompleteServiceResponse,
calculate_point_usage,
)
from service.database import DatabaseService
@@ -30,9 +30,7 @@ class ChatCompleteTask:
self, dbs: DatabaseService, user_id: int, page_title: str, is_system=False
):
self.task_id = utils.web.generate_uuid()
self.on_message: list[Callable] = []
self.on_finished: list[Callable] = []
self.on_error: list[Callable] = []
self.events: ChatCompleteTaskEvent = ChatCompleteTaskEvent()
self.chunks: list[str] = []
self.chat_complete_service: ChatCompleteService
@@ -111,45 +109,10 @@ class ChatCompleteTask:
await self.end()
raise e
async def _on_message(self, delta_message: str):
self.chunks.append(delta_message)
for callback in self.on_message:
try:
await callback(delta_message)
except Exception as e:
print(
"Error while processing on_message callback: %s" % e,
file=sys.stderr,
)
traceback.print_exc()
async def _on_finished(self):
for callback in self.on_finished:
try:
await callback(self.result)
except Exception as e:
print(
"Error while processing on_finished callback: %s" % e,
file=sys.stderr,
)
traceback.print_exc()
async def _on_error(self, err: Exception):
self.error = err
for callback in self.on_error:
try:
await callback(err)
except Exception as e:
print(
"Error while processing on_error callback: %s" % e, file=sys.stderr
)
traceback.print_exc()
async def run(self) -> ChatCompleteServiceResponse:
chat_complete_tasks[self.task_id] = self
try:
chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
chat_res = await self.chat_complete.finish_chat_complete(event_container=self.events)
await self.chat_complete.set_latest_point_usage(self.point_usage)
@@ -160,7 +123,7 @@
self.transatcion_id, self.point_usage # TODO: deduct points based on actual token usage
)
await self._on_finished()
await self.events.emit_finished(chat_res)
except Exception as e:
err_msg = f"Error while processing chat complete request: {e}"
@@ -172,7 +135,7 @@
self.transatcion_id, error=err_msg
)
await self._on_error(e)
await self.events.emit_error(e)
finally:
await self.end()

@@ -4,6 +4,7 @@ import time
import traceback
from typing import Optional, Tuple, TypedDict
from events.chat_complete_event import ChatCompleteEvent
from server.model.chat_complete.bot_persona import BotPersonaHelper
from server.model.chat_complete.conversation import (
ConversationChunkHelper,
@@ -13,6 +14,7 @@ import sys
from server.model.toolbox_ui.conversation import ConversationHelper, ConversationModel
from libs.config import Config
from type_defs.chat_complete_task import ChatCompleteQuestionTooLongException, ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse
import utils.config, utils.web
from aiohttp import web
@@ -24,27 +26,6 @@ from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi, OpenAIApiTypeInvalidException
from service.tiktoken import TikTokenService
class ChatCompleteQuestionTooLongException(Exception):
def __init__(self, tokens_limit: int, tokens_current: int):
super().__init__(f"Question too long: {tokens_current} > {tokens_limit}")
self.tokens_limit = tokens_limit
self.tokens_current = tokens_current
class ChatCompleteServicePrepareResponse(TypedDict):
extract_doc: list
question_tokens: int
conversation_id: int
chunk_id: int
api_id: str
class ChatCompleteServiceResponse(TypedDict):
message: str
message_tokens: int
total_tokens: int
finish_reason: str
question_message_id: str
response_message_id: str
delta_data: dict
class ChatCompleteService:
def __init__(self, dbs: DatabaseService, title: str):
@@ -285,7 +266,7 @@ class ChatCompleteService:
)
async def finish_chat_complete(
self, on_message: Optional[callable] = None
self, event_container: ChatCompleteEvent = None
) -> ChatCompleteServiceResponse:
delta_data = {}
@@ -311,7 +292,6 @@
doc_prompt = utils.config.get_prompt(
"extracted_doc", "prompt", {"content": doc_prompt_content}
)
message_log.append({"role": "user", "content": doc_prompt})
system_prompt = self.chat_system_prompt
if system_prompt is None:
@@ -323,13 +303,15 @@
system_prompt = utils.config.format_prompt(system_prompt)
# Start chat complete
if on_message is not None:
if event_container is not None:
response = await self.openai_api.chat_complete_stream(
self.question, system_prompt, self.model, message_log, on_message
self.question, system_prompt, self.model, message_log,
doc_prompt=doc_prompt, event_container=event_container
)
else:
response = await self.openai_api.chat_complete(
self.question, system_prompt, self.model, message_log
self.question, system_prompt, self.model, message_log,
doc_prompt=doc_prompt
)
description = response["message"][0:150]
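The streaming/blocking split is now driven entirely by whether an event container is supplied; a caller-side sketch (`service` and `task` are assumed to be a prepared `ChatCompleteService` and a `ChatCompleteTask`):

```python
# Streaming: deltas are emitted through task.events as they arrive.
response = await service.finish_chat_complete(event_container=task.events)

# Blocking: no container, a single complete response is returned.
# response = await service.finish_chat_complete()
```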

@@ -298,15 +298,13 @@ class EmbeddingSearchService:
query_doc, token_usage = await self.text_embedding.get_embeddings(query_doc)
query_embedding = query_doc[0]["embedding"]
print(query_embedding)
if query_embedding is None:
return [], token_usage
res = await self.page_index.search_text_embedding(
query_embedding, in_collection, limit, self.page_id
)
print(res)
if res:
filtered = []
for one in res:

@@ -9,6 +9,7 @@ import numpy as np
from aiohttp_sse_client2 import client as sse_client
from service.tiktoken import TikTokenService
from events.chat_complete_event import ChatCompleteEvent
AZURE_CHATCOMPLETE_API_VERSION = "2023-07-01-preview"
AZURE_EMBEDDING_API_VERSION = "2023-05-15"
@@ -164,7 +165,7 @@ class OpenAIApi:
return (doc_list, token_usage)
async def make_message_list(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = []) -> list[ChatCompleteMessageLog]:
async def make_message_list(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], doc_prompt: str = None) -> list[ChatCompleteMessageLog]:
summaryContent = None
messageList: list[ChatCompleteMessageLog] = []
@@ -177,13 +178,17 @@
if summaryContent:
system_prompt += "\n\n" + summaryContent
messageList.insert(0, ChatCompleteMessageLog(role="assistant", content=system_prompt))
messageList.insert(0, ChatCompleteMessageLog(role="system", content=system_prompt))
if doc_prompt:
question = doc_prompt + "\n\n" + question
messageList.append(ChatCompleteMessageLog(role="user", content=question))
return messageList
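Two behavioral changes land here: the system prompt is now sent with role `system` (previously mis-labelled `assistant`), and `doc_prompt` is prepended to the question rather than appended as a separate user message earlier in the flow. The resulting list has roughly this shape (values abbreviated):

```python
[
    {"role": "system", "content": "<system_prompt [+ summary]>"},
    # ...prior conversation messages...
    {"role": "user", "content": "<doc_prompt>\n\n<question>"},
]
```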
async def chat_complete(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], user = None):
messageList = await self.make_message_list(question, system_prompt, conversation)
async def chat_complete(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], doc_prompt: str = None, user = None):
messageList = await self.make_message_list(question, system_prompt, conversation, doc_prompt)
url = self.get_url("chat/completions")
@@ -231,10 +236,11 @@
return None
async def chat_complete_stream(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], on_message = None, user = None):
async def chat_complete_stream(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], doc_prompt: str = None,
event_container: ChatCompleteEvent = None, user = None):
tiktoken = await TikTokenService.create()
messageList = await self.make_message_list(question, system_prompt, conversation)
messageList = await self.make_message_list(question, system_prompt, conversation, doc_prompt)
prompt_tokens = 0
for message in messageList:
@@ -264,6 +270,9 @@
res_message: list[str] = []
finish_reason = None
# Whether the model is currently in its deep-reasoning phase
in_reasoning = False
async with sse_client.EventSource(
url,
option={
@@ -296,9 +305,9 @@
finish_reason = choice["finish_reason"]
delta_content = choice["delta"]
if "content" in delta_content:
if "content" in delta_content and delta_content["content"] is not None:
delta_message: str = delta_content["content"]
# Skip empty lines before content
if not content_started:
if delta_message.replace("\n", "") == "":
@@ -311,9 +320,24 @@
# if config.DEBUG:
# print(delta_message, end="", flush=True)
if on_message is not None:
await on_message(delta_message)
await event_container.emit_message_output(delta_message)
elif "reasoning_content" in delta_content and delta_content["reasoning_content"] is not None:
# Handle DeepSeek's deep-reasoning content
delta_message: str = delta_content["reasoning_content"]
if not in_reasoning:
await event_container.emit_tool_running("深度思考")
in_reasoning = True
# Skip empty lines before content
if not content_started:
if delta_message.replace("\n", "") == "":
continue
else:
content_started = True
await event_container.emit_tool_output(delta_message)
if finish_reason is not None:
break
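The delta handler now distinguishes ordinary `content` chunks from DeepSeek-style `reasoning_content` chunks, routing the latter through the tool channels. Stripped of the SSE plumbing, the dispatch reduces to roughly this (a simplified sketch, not the full loop above):

```python
async def dispatch_delta(delta: dict, events, state: dict) -> None:
    """Route one streamed delta to the matching event channel (illustrative)."""
    if delta.get("content") is not None:
        await events.emit_message_output(delta["content"])
    elif delta.get("reasoning_content") is not None:
        if not state.get("in_reasoning"):
            # Announce the reasoning phase once, as the loop above does.
            await events.emit_tool_running("深度思考")
            state["in_reasoning"] = True
        await events.emit_tool_output(delta["reasoning_content"])
```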

@@ -6,26 +6,24 @@ import asyncio
import random
import threading
from typing import Callable, Optional, TypedDict
import torch
from text2vec import SentenceModel
import ollama
from utils.local import loop
from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService
BERT_EMBEDDING_QUEUE_TIMEOUT = 1
LOCAL_EMBEDDING_QUEUE_TIMEOUT = 1
class Text2VecEmbeddingQueueTaskInfo(TypedDict):
class OllamaEmbeddingQueueTaskInfo(TypedDict):
task_id: int
text: str
embedding: torch.Tensor
embedding: list
class Text2VecEmbeddingQueue:
class OllamaEmbeddingQueue:
def __init__(self, model: str) -> None:
self.model_name = model
self.embedding_model: SentenceModel = None
self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {}
self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = []
self.task_map: dict[int, OllamaEmbeddingQueueTaskInfo] = {}
self.task_list: list[OllamaEmbeddingQueueTaskInfo] = []
self.lock = threading.Lock()
self.thread: Optional[threading.Thread] = None
@@ -33,7 +31,7 @@ class Text2VecEmbeddingQueue:
def post_init(self):
self.embedding_model = SentenceModel(self.model_name)
pass
async def get_embeddings(self, text: str):
@@ -81,13 +79,13 @@
task = self.task_list.pop(0)
if task is not None:
embeddings = self.embedding_model.encode([task["text"]])
res = ollama.embeddings(model=self.model_name, prompt=task["text"])
with self.lock:
task["embedding"] = embeddings[0]
task["embedding"] = res.embedding
last_task_time = time.time()
elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
elif last_task_time is not None and current_time > last_task_time + LOCAL_EMBEDDING_QUEUE_TIMEOUT:
self.thread = None
self.running = False
running = False
@@ -107,7 +105,7 @@ class TextEmbeddingService:
def __init__(self):
self.tiktoken: TikTokenService = None
self.text2vec_queue: Text2VecEmbeddingQueue = None
self.ollama_queue: OllamaEmbeddingQueue = None
self.openai_api: OpenAIApi = None
@staticmethod
@@ -120,22 +118,22 @@ class TextEmbeddingService:
async def init(self):
self.tiktoken = await TikTokenService.create()
self.embedding_type = Config.get("embedding.type", "text2vec")
self.embedding_type = Config.get("embedding.type", "ollama")
if self.embedding_type == "text2vec":
embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese")
self.text2vec_queue = Text2VecEmbeddingQueue(model=embedding_model)
if self.embedding_type == "ollama":
embedding_model = Config.get("embedding.embedding_model", "shaw/dmeta-embedding-zh")
self.ollama_queue = OllamaEmbeddingQueue(model=embedding_model)
await loop.run_in_executor(None, self.ollama_queue.post_init)
elif self.embedding_type == "openai":
api_id = Config.get("embedding.api_id")
self.openai_api: OpenAIApi = await OpenAIApi.create(api_id)
await loop.run_in_executor(None, self.text2vec_queue.post_init)
async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
async def get_ollama_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
total_token_usage = 0
for index, doc in enumerate(doc_list):
text = doc["text"]
embedding = await self.text2vec_queue.get_embeddings(text)
embedding = await self.ollama_queue.get_embeddings(text)
doc["embedding"] = embedding
total_token_usage += await self.tiktoken.get_tokens(text)
@@ -165,7 +163,7 @@ class TextEmbeddingService:
text = " ".join(new_lines)
doc["text"] = text
if self.embedding_type == "text2vec":
return await self.get_text2vec_embeddings(res_doc_list, on_index_progress)
if self.embedding_type == "ollama":
return await self.get_ollama_embeddings(res_doc_list, on_index_progress)
elif self.embedding_type == "openai":
return await self.openai_api.get_embeddings(res_doc_list, on_index_progress)
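The queue thread now delegates to the Ollama client instead of an in-process SentenceModel. Isolated, the underlying call looks like this (requires a running Ollama server with the model pulled; `res.embedding` matches the attribute used above):

```python
import ollama

# e.g. `ollama pull shaw/dmeta-embedding-zh` beforehand
res = ollama.embeddings(model="shaw/dmeta-embedding-zh", prompt="你好，世界")
print(len(res.embedding))  # embedding dimensionality
```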

@@ -16,7 +16,7 @@ class TikTokenService:
self.enc: tiktoken_async.Encoding = None
async def init(self):
self.enc = await tiktoken_async.encoding_for_model("gpt-3.5-turbo")
self.enc = await tiktoken_async.encoding_for_model("gpt-4o")
async def get_tokens(self, text: str):
encoded = self.enc.encode(text)
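Switching the encoding from `gpt-3.5-turbo` to `gpt-4o` moves the tokenizer from `cl100k_base` to `o200k_base`, so token counts (and limits derived from them) shift slightly. A standalone check with the synchronous `tiktoken` package, which shares the same model mapping:

```python
import tiktoken

enc = tiktoken.encoding_for_model("gpt-4o")
print(enc.name)                         # o200k_base
print(len(enc.encode("Hello, world")))  # token count under the new encoding
```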

@@ -2,12 +2,12 @@ import asyncio
import time
import base as _
from utils.local import loop, noawait
from service.text_embedding import Text2VecEmbeddingQueue
from service.text_embedding import OllamaEmbeddingQueue
async def main():
embedding_list = []
queue = []
text2vec_queue = Text2VecEmbeddingQueue("shibing624/text2vec-base-chinese")
text2vec_queue = OllamaEmbeddingQueue("shibing624/text2vec-base-chinese")
start_time = time.time()

@@ -0,0 +1,25 @@
from __future__ import annotations
from typing_extensions import TypedDict
class ChatCompleteQuestionTooLongException(Exception):
def __init__(self, tokens_limit: int, tokens_current: int):
super().__init__(f"Question too long: {tokens_current} > {tokens_limit}")
self.tokens_limit = tokens_limit
self.tokens_current = tokens_current
class ChatCompleteServicePrepareResponse(TypedDict):
extract_doc: list
question_tokens: int
conversation_id: int
chunk_id: int
api_id: str
class ChatCompleteServiceResponse(TypedDict):
message: str
message_tokens: int
total_tokens: int
finish_reason: str
question_message_id: str
response_message_id: str
delta_data: dict

@@ -0,0 +1,19 @@
import asyncio
import sys
import traceback
async def run_listeners(listeners: list, *args, **kwargs) -> None:
for listener in listeners:
try:
res = listener(*args, **kwargs)
if asyncio.iscoroutine(res):
await res
except Exception as ex:
print(
"Error while processing callback: %s" % ex,
file=sys.stderr,
)
traceback.print_exc()
return None
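This helper is what lets the event classes accept both sync and async listeners while isolating failures: an exception from one listener is logged to stderr and the remaining listeners still run. A quick usage sketch:

```python
import asyncio

from utils.program import run_listeners


async def main() -> None:
    async def async_listener(value: str) -> None:
        print("async:", value)

    def failing_listener(value: str) -> None:
        raise RuntimeError("boom")  # logged, does not stop later listeners

    def sync_listener(value: str) -> None:
        print("sync:", value)

    await run_listeners([async_listener, failing_listener, sync_listener], "hello")


asyncio.run(main())
```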