Modify the ChatComplete event implementation

master
落雨楓 2 weeks ago
parent 83bb6ad213
commit cc283b1618

@@ -0,0 +1,21 @@
from __future__ import annotations

from typing import Any
from agentkit.context import ConversationContext


class BaseAgent:
    def __init__(self, config: dict[str, Any]):
        self.config = config

    async def __aenter__(self) -> BaseAgent:
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        return None

    async def count_tokens(self, text: str) -> int:
        return len(text.split())

    async def chat_complete(self, context: ConversationContext, stream: bool = False):
        pass
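
As a quick illustration of the intended interface, a concrete agent only needs to fill in chat_complete. The EchoAgent below is a hypothetical sketch, not part of this commit; it relies on the history and completion attributes that ConversationContext exposes elsewhere in this changeset.

class EchoAgent(BaseAgent):  # hypothetical example, not in this commit
    async def chat_complete(self, context: ConversationContext, stream: bool = False):
        # Echo the latest user message back as the completion.
        user_messages = [msg for msg in context.history if msg["role"] == "user"]
        context.completion = user_messages[-1]["content"] if user_messages else ""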

@@ -0,0 +1,31 @@
from __future__ import annotations

from typing import Any
from agentkit.base.agent import BaseAgent


class BaseAgentFactory:
    def __init__(self, props: dict):
        pass

    async def __aenter__(self) -> BaseAgentFactory:
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        return None

    async def new_from_id(self, agent_id: str) -> BaseAgent | None:
        """
        Create agent from agent id.
        """
        return None

    async def new_from_config(self, agent_type: str, config: dict) -> BaseAgent | None:
        """
        Create agent from agent type and config.
        """
        return None
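
A concrete factory would typically dispatch on agent_type; a minimal sketch (the AGENT_TYPES registry is illustrative, not part of this commit):

AGENT_TYPES: dict[str, type[BaseAgent]] = {}  # hypothetical registry

class SimpleAgentFactory(BaseAgentFactory):
    async def new_from_config(self, agent_type: str, config: dict) -> BaseAgent | None:
        agent_cls = AGENT_TYPES.get(agent_type)
        return agent_cls(config) if agent_cls else None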

@@ -7,6 +7,7 @@ from agentkit.types import PluginConfigParam

 class AgentKitModelApi:
     CONFIG: list[PluginConfigParam] = []
+    INPUT: list[PluginConfigParam] = []

     def __init__(self, props: dict, **kwargs):
         pass

@@ -5,6 +5,7 @@ from agentkit.types import PluginConfigParam, OnProgressCallback

 class AgentKitPostprocessor:
     CONFIG: list[PluginConfigParam] = []
+    INPUT: list[PluginConfigParam] = []
     OUTPUT: list[PluginConfigParam] = []

     def __init__(self, props: dict, **kwargs):
@@ -22,13 +23,13 @@ class AgentKitPostprocessor:
     async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
         return None

-    async def on_after_completion(self, context: ConversationContext) -> Optional[dict]:
+    async def on_after_completion(self, context: ConversationContext, **kwargs) -> Optional[dict]:
         """
         Called after the conversation, before the completion is returned to the user.
         """
         return None

-    async def on_after_completion_background(self, context: ConversationContext) -> Optional[dict]:
+    async def on_after_completion_background(self, context: ConversationContext, **kwargs) -> Optional[dict]:
         """
         Called after the conversation, after the completion is returned to the user.
         """

@@ -5,6 +5,7 @@ from agentkit.types import PluginConfigParam, OnProgressCallback

 class AgentKitPreprocessor:
     CONFIG: list[PluginConfigParam] = []
+    INPUT: list[PluginConfigParam] = []
     OUTPUT: list[PluginConfigParam] = []

     def __init__(self, props: dict, **kwargs):
@@ -22,5 +23,5 @@ class AgentKitPreprocessor:
     async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
         return None

-    async def on_before_completion(self, context: ConversationContext) -> Optional[dict]:
+    async def on_before_completion(self, context: ConversationContext, **kwargs) -> Optional[dict]:
         return None

@@ -38,7 +38,7 @@ class ConversationContext(ConversationData):
         self.system_prompt = ""
         """System prompt"""

-        self.session_data: dict = {}
+        self.state: dict = {}
         """Temporary data for plugins"""

         self.completion: Optional[str] = None
@@ -52,39 +52,39 @@ class ConversationContext(ConversationData):
         return [msg for msg in self.history if msg["role"] == role]

-    def create_sub_context(self, copy_store = False, copy_session_data = False) -> ConversationContext:
+    def create_sub_context(self, copy_store = False, copy_state = False) -> ConversationContext:
         sub_context = ConversationContext()
         sub_context.user_id = self.user_id

         if copy_store:
             sub_context.store = self.store.copy()

-        if copy_session_data:
-            sub_context.session_data = self.session_data.copy()
+        if copy_state:
+            sub_context.state = self.state.copy()

         return sub_context

     def __getitem__(self, key):
-        return self.session_data[key]
+        return self.state[key]

     def __setitem__(self, key, value):
-        self.session_data[key] = value
+        self.state[key] = value
         return value

     def __delitem__(self, key):
-        del self.session_data[key]
+        del self.state[key]

     def __contains__(self, key):
-        return key in self.session_data
+        return key in self.state

     def __iter__(self):
-        return iter(self.session_data)
+        return iter(self.state)

     def __len__(self):
-        return len(self.session_data)
+        return len(self.state)
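
Because the dunder methods now proxy to state, existing dict-style plugin code keeps working after the session_data rename; a small sketch:

ctx = ConversationContext()
ctx["retries"] = 0                             # __setitem__ writes to ctx.state
if "retries" in ctx:                           # __contains__ checks ctx.state
    ctx["retries"] += 1                        # __getitem__ reads ctx.state
sub = ctx.create_sub_context(copy_state=True)  # shallow-copies ctx.state
assert sub["retries"] == ctx["retries"]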

@@ -0,0 +1,126 @@
import logging

from agentkit.context import ConversationContext
from agentkit.types import AgentKitFlowStep


class BreakLoopInterrupt(Exception):
    pass


class AgentKitFlowExecutor:
    def __init__(self, flow_script: list[AgentKitFlowStep]):
        self.flow = flow_script

    def _execute_step(self, step: AgentKitFlowStep, context: ConversationContext):
        """
        Execute a single flow step.

        Args:
            step (dict): The step to execute.
        Returns:
            Any: The result of the step execution.
        """
        step_type = step["type"]
        if step_type == "call":
            return self._execute_call(step, context)
        elif step_type == "if_else":
            return self._execute_if_else(step, context)
        elif step_type == "loop":
            return self._execute_loop(step, context)
        elif step_type == "break_loop":
            return self._execute_break_loop()
        else:
            raise ValueError(f"Unsupported step type: {step_type}")

    def _execute_call(self, step: dict, context: ConversationContext):
        """
        Execute a 'call' step.

        Args:
            step (dict): The call step configuration.
        Returns:
            Any: The result of the call.
        """
        func_id = step.get("id")
        config = {k: self._resolve_value(v) for k, v in step.get("config", {}).items()}
        output_map = step.get("output_map", {})

        # Simulate calling a function (replace with actual function calls if needed)
        logging.info(f"Calling function {func_id} with config: {config}")
        result = {key: f"mocked_value_{key}" for key in output_map.keys()}  # Mocked output

        # Map outputs to context state variables
        for key, var_name in output_map.items():
            context.state[var_name] = result.get(key)

        return result

    def _execute_if_else(self, step: dict, context: ConversationContext):
        """
        Execute an 'if_else' step.

        Args:
            step (dict): The if_else step configuration.
        Returns:
            Any: The result of the executed branch.
        """
        condition = step.get("condition")
        condition_input = {k: self._resolve_value(v) for k, v in step.get("condition_input", {}).items()}

        # Evaluate the condition.
        # NOTE: eval() runs arbitrary code; condition strings must come from trusted flow scripts only.
        condition_result = eval(condition, {}, condition_input)

        # Execute the appropriate branch
        branch = step["true_branch"] if condition_result else step["false_branch"]
        for sub_step in branch:
            self._execute_step(sub_step, context)

    def _execute_loop(self, step: dict, context: ConversationContext):
        """
        Execute a 'loop' step.

        Args:
            step (dict): The loop step configuration.
        Returns:
            Any: The result of the loop execution.
        """
        loop_num = step.get("loop_num")
        index_var = step.get("index_var")
        loop_body = step.get("loop_body", [])

        try:
            for i in range(loop_num):
                if index_var:
                    context.state[index_var] = i
                for sub_step in loop_body:
                    self._execute_step(sub_step, context)
        except BreakLoopInterrupt:
            # A 'break_loop' step terminates only the enclosing loop.
            pass

    def _execute_break_loop(self):
        """
        Execute a 'break_loop' step by interrupting the enclosing loop.

        Returns:
            None
        """
        raise BreakLoopInterrupt()

    def execute(self, context: ConversationContext):
        """
        Execute the entire flow.

        Returns:
            None
        """
        try:
            for step in self.flow:
                self._execute_step(step, context)
        except BreakLoopInterrupt:
            pass
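
For orientation, a flow script this executor could run, sketched under assumptions: the "$i" reference syntax for _resolve_value is guessed (its definition is outside this hunk), and the 'call' step uses the mocked output above.

flow_script = [
    {"type": "loop", "loop_num": 3, "index_var": "i", "loop_body": [
        {"type": "call", "id": "greet", "comment": "",
         "config": {}, "output_map": {"text": "greeting"}},
        {"type": "if_else", "condition": "i >= 1",
         "condition_input": {"i": "$i"},   # assumed variable-reference syntax
         "true_branch": [{"type": "break_loop"}],
         "false_branch": []},
    ]},
]

context = ConversationContext()
AgentKitFlowExecutor(flow_script).execute(context)   # runs two iterations, then breaks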

@@ -49,7 +49,46 @@ class LLMFunctionResponse(TypedDict):
     """Directly return the response to the user, without further processing."""

-OnProgressCallback = Callable[[int, int], Any]
+class AgentKitFlowStep(TypedDict):
+    type: str
+    comment: str
+
+class AgentKitFlowCall(AgentKitFlowStep):
+    # type: "call"
+    id: str
+    comment: str
+    config: dict[str, Any]
+    output_map: dict[str, str]
+
+class AgentKitFlowIfElse(AgentKitFlowStep):
+    # type: "if_else"
+    condition: str
+    condition_input: dict[str, Any]
+    true_branch: list
+    false_branch: list
+
+class AgentKitLoop(AgentKitFlowStep):
+    # type: "loop"
+    loop_num: int
+    loop_body: list
+    index_var: Optional[str]
+
+class AgentKitBreakLoop(AgentKitFlowStep):
+    # type: "break_loop"
+    pass
+
+class AgentKitSetVar(AgentKitFlowStep):
+    # type: "set_var"
+    var_name: str
+    var_value_expr: str
+
+OnProgressCallback = Callable[[str, int, int], Any]
 OnTaskStartCallback = Callable[[str, str], Any]
 OnTaskProgressCallback = Callable[[int, Optional[int], Optional[str]], Any]

@@ -0,0 +1,31 @@
from __future__ import annotations

from typing import Callable

from type_defs.chat_complete_task import ChatCompleteServiceResponse
from utils.program import run_listeners


class ChatCompleteEvent:
    def __init__(self):
        self.on_tool_running: list[Callable[[str, str], None]] = []
        self.on_tool_output: list[Callable[[str], None]] = []
        self.on_message_output: list[Callable[[str], None]] = []

    async def emit_tool_running(self, tool_name: str, running_state: str = "") -> None:
        await run_listeners(self.on_tool_running, tool_name, running_state)

    async def emit_tool_output(self, output: str) -> None:
        await run_listeners(self.on_tool_output, output)

    async def emit_message_output(self, output: str) -> None:
        await run_listeners(self.on_message_output, output)


class ChatCompleteTaskEvent(ChatCompleteEvent):
    def __init__(self):
        super().__init__()
        self.on_finished: list[Callable[[ChatCompleteServiceResponse | None], None]] = []
        self.on_error: list[Callable[[Exception], None]] = []

    async def emit_finished(self, result: ChatCompleteServiceResponse | None) -> None:
        await run_listeners(self.on_finished, result)

    async def emit_error(self, ex: Exception) -> None:
        await run_listeners(self.on_error, ex)
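
A short usage sketch for the new event container; send_to_client stands in for the websocket push done in the controller:

events = ChatCompleteTaskEvent()

async def send_to_client(delta: str):          # hypothetical sink
    print(delta, end="")

events.on_message_output.append(send_to_client)
events.on_finished.append(lambda result: print("\ndone:", result))

# Inside the task:
#     await events.emit_message_output("Hello")
#     await events.emit_finished(None)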

@@ -1,3 +0,0 @@
-text2vec>=1.2.9
---index-url https://download.pytorch.org/whl/cpu
-torch

@@ -11,3 +11,4 @@ aiohttp-sse-client2==0.3.0
 OpenCC==1.1.1
 event-emitter-asyncio==1.0.4
 toml==0.10.2
+ollama==0.4.4

@@ -7,10 +7,11 @@ from server.controller.task.ChatCompleteTask import ChatCompleteTask
 from server.model.base import clone_model
 from server.model.chat_complete.bot_persona import BotPersonaHelper
 from server.model.toolbox_ui.conversation import ConversationHelper
+from type_defs.chat_complete_task import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse
 from utils.local import noawait
 from aiohttp import web
 from server.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
-from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse, calculate_point_usage
+from service.chat_complete import calculate_point_usage
 from service.database import DatabaseService
 from service.mediawiki_api import MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
 from service.tiktoken import TikTokenService
@@ -248,6 +249,8 @@ class ChatComplete:
     @staticmethod
     @utils.web.token_auth
     async def get_point_usage(request: web.Request):
+        """Calculate the points required for ChatComplete"""
+
         params = await utils.web.get_param(request, {
             "question": {
                 "type": str,
@@ -299,6 +302,8 @@ class ChatComplete:
     @staticmethod
     @utils.web.token_auth
     async def get_persona_list(request: web.Request):
+        """Get the bot persona list"""
+
         params = await utils.web.get_param(request, {
             "page": {
                 "type": int,
@@ -336,6 +341,8 @@ class ChatComplete:
     @staticmethod
     @utils.web.token_auth
     async def get_persona_info(request: web.Request):
+        """Get bot persona info"""
+
         params = await utils.web.get_param(request, {
             "id": {
                 "type": int,
@@ -371,6 +378,8 @@ class ChatComplete:
     @staticmethod
     @utils.web.token_auth
     async def start_chat_complete(request: web.Request):
+        """Start a ChatComplete task"""
+
         params = await utils.web.get_param(request, {
             "title": {
                 "type": str,
@@ -471,6 +480,7 @@ class ChatComplete:
     @staticmethod
     @utils.web.token_auth
     async def chat_complete_stream(request: web.Request):
+        """Stream the ChatComplete result"""
         if not utils.web.is_websocket(request):
             return await utils.web.api_response(-1, error={
                 "code": "websocket-required",
@@ -542,9 +552,28 @@ class ChatComplete:
                 await ws.close()
             else:
                 async def on_closed():
-                    task.on_message.remove(on_message)
-                    task.on_finished.remove(on_finished)
-                    task.on_error.remove(on_error)
+                    task.events.on_tool_running.remove(on_tool_running)
+                    task.events.on_tool_output.remove(on_tool_output)
+                    task.events.on_message_output.remove(on_message)
+                    task.events.on_finished.remove(on_finished)
+                    task.events.on_error.remove(on_error)
+
+                async def on_tool_running(tool_name: str, running_state: str):
+                    try:
+                        await ws.send_json({
+                            'event': 'tool_run',
+                            'status': 1,
+                            'tool_name': tool_name,
+                            'running_state': running_state,
+                        })
+                    except ConnectionResetError:
+                        await on_closed()
+
+                async def on_tool_output(output: str):
+                    try:
+                        await ws.send_str(">" + output)
+                    except ConnectionResetError:
+                        await on_closed()

                 async def on_message(delta_message: str):
                     try:
@@ -552,15 +581,18 @@ class ChatComplete:
                     except ConnectionResetError:
                         await on_closed()

-                async def on_finished(result: ChatCompleteServiceResponse):
+                async def on_finished(result: ChatCompleteServiceResponse | None):
                     try:
                         ignored_keys = ["message"]
                         response_result = {
                             "point_usage": task.point_usage,
                         }
-                        for k, v in result.items():
-                            if k not in ignored_keys:
-                                response_result[k] = v
+
+                        if result:
+                            for k, v in result.items():
+                                if k not in ignored_keys:
+                                    response_result[k] = v

                         await ws.send_json({
                             'event': 'finished',
                             'status': 1,
@@ -587,9 +619,11 @@ class ChatComplete:
                     except ConnectionResetError:
                         await on_closed()

-                task.on_message.append(on_message)
-                task.on_finished.append(on_finished)
-                task.on_error.append(on_error)
+                task.events.on_tool_running.append(on_tool_running)
+                task.events.on_tool_output.append(on_tool_output)
+                task.events.on_message_output.append(on_message)
+                task.events.on_finished.append(on_finished)
+                task.events.on_error.append(on_error)

                 # Send received message
                 await ws.send_json({
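
From the handlers above, a client now sees three kinds of frames: JSON events ('tool_run', 'finished'), tool output prefixed with ">", and raw message deltas. A small sketch of how a client might classify them (the framing rules are taken from this diff; the function itself is illustrative):

import json

def classify_frame(data: str) -> str:
    # Assumes the framing above: JSON events, '>'-prefixed tool output,
    # and raw text frames for assistant message deltas.
    if data.startswith(">"):
        return "tool_output"
    try:
        return json.loads(data).get("event") or "unknown_json"   # "tool_run" / "finished"
    except (ValueError, AttributeError):
        return "message_delta"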

@@ -2,14 +2,14 @@ from __future__ import annotations
 import sys
 import time
 import traceback
+from events.chat_complete_event import ChatCompleteEvent, ChatCompleteTaskEvent
 from libs.config import Config
 from server.model.chat_complete.bot_persona import BotPersonaHelper
+from type_defs.chat_complete_task import ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse
 from utils.local import noawait
 from typing import Optional, Callable, Union
 from service.chat_complete import (
     ChatCompleteService,
-    ChatCompleteServicePrepareResponse,
-    ChatCompleteServiceResponse,
     calculate_point_usage,
 )
 from service.database import DatabaseService
@@ -30,9 +30,7 @@ class ChatCompleteTask:
         self, dbs: DatabaseService, user_id: int, page_title: str, is_system=False
     ):
         self.task_id = utils.web.generate_uuid()
-        self.on_message: list[Callable] = []
-        self.on_finished: list[Callable] = []
-        self.on_error: list[Callable] = []
+        self.events: ChatCompleteTaskEvent = ChatCompleteTaskEvent()

         self.chunks: list[str] = []
         self.chat_complete_service: ChatCompleteService
@@ -111,45 +109,10 @@ class ChatCompleteTask:
             await self.end()
             raise e

-    async def _on_message(self, delta_message: str):
-        self.chunks.append(delta_message)
-        for callback in self.on_message:
-            try:
-                await callback(delta_message)
-            except Exception as e:
-                print(
-                    "Error while processing on_message callback: %s" % e,
-                    file=sys.stderr,
-                )
-                traceback.print_exc()
-
-    async def _on_finished(self):
-        for callback in self.on_finished:
-            try:
-                await callback(self.result)
-            except Exception as e:
-                print(
-                    "Error while processing on_finished callback: %s" % e,
-                    file=sys.stderr,
-                )
-                traceback.print_exc()
-
-    async def _on_error(self, err: Exception):
-        self.error = err
-        for callback in self.on_error:
-            try:
-                await callback(err)
-            except Exception as e:
-                print(
-                    "Error while processing on_error callback: %s" % e, file=sys.stderr
-                )
-                traceback.print_exc()
-
     async def run(self) -> ChatCompleteServiceResponse:
         chat_complete_tasks[self.task_id] = self
         try:
-            chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
+            chat_res = await self.chat_complete.finish_chat_complete(event_container=self.events)

             await self.chat_complete.set_latest_point_usage(self.point_usage)
@@ -160,7 +123,7 @@ class ChatCompleteTask:
                 self.transatcion_id, self.point_usage  # TODO: deduct points based on actual token usage
             )

-            await self._on_finished()
+            await self.events.emit_finished(chat_res)
         except Exception as e:
             err_msg = f"Error while processing chat complete request: {e}"
@@ -172,7 +135,7 @@ class ChatCompleteTask:
                 self.transatcion_id, error=err_msg
             )

-            await self._on_error(e)
+            await self.events.emit_error(e)
         finally:
             await self.end()

@@ -4,6 +4,7 @@ import time
 import traceback
 from typing import Optional, Tuple, TypedDict

+from events.chat_complete_event import ChatCompleteEvent
 from server.model.chat_complete.bot_persona import BotPersonaHelper
 from server.model.chat_complete.conversation import (
     ConversationChunkHelper,
@@ -13,6 +14,7 @@ import sys
 from server.model.toolbox_ui.conversation import ConversationHelper, ConversationModel
 from libs.config import Config
+from type_defs.chat_complete_task import ChatCompleteQuestionTooLongException, ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse
 import utils.config, utils.web
 from aiohttp import web
@@ -24,27 +26,6 @@ from service.mediawiki_api import MediaWikiApi
 from service.openai_api import OpenAIApi, OpenAIApiTypeInvalidException
 from service.tiktoken import TikTokenService

-class ChatCompleteQuestionTooLongException(Exception):
-    def __init__(self, tokens_limit: int, tokens_current: int):
-        super().__init__(f"Question too long: {tokens_current} > {tokens_limit}")
-        self.tokens_limit = tokens_limit
-        self.tokens_current = tokens_current
-
-class ChatCompleteServicePrepareResponse(TypedDict):
-    extract_doc: list
-    question_tokens: int
-    conversation_id: int
-    chunk_id: int
-    api_id: str
-
-class ChatCompleteServiceResponse(TypedDict):
-    message: str
-    message_tokens: int
-    total_tokens: int
-    finish_reason: str
-    question_message_id: str
-    response_message_id: str
-    delta_data: dict
-
 class ChatCompleteService:
     def __init__(self, dbs: DatabaseService, title: str):
@@ -285,7 +266,7 @@ class ChatCompleteService:
         )

     async def finish_chat_complete(
-        self, on_message: Optional[callable] = None
+        self, event_container: ChatCompleteEvent = None
     ) -> ChatCompleteServiceResponse:
         delta_data = {}
@@ -311,7 +292,6 @@ class ChatCompleteService:
             doc_prompt = utils.config.get_prompt(
                 "extracted_doc", "prompt", {"content": doc_prompt_content}
             )
-            message_log.append({"role": "user", "content": doc_prompt})

         system_prompt = self.chat_system_prompt
         if system_prompt is None:
@@ -323,13 +303,15 @@ class ChatCompleteService:
         system_prompt = utils.config.format_prompt(system_prompt)

         # Start chat complete
-        if on_message is not None:
+        if event_container is not None:
             response = await self.openai_api.chat_complete_stream(
-                self.question, system_prompt, self.model, message_log, on_message
+                self.question, system_prompt, self.model, message_log,
+                doc_prompt=doc_prompt, event_container=event_container
             )
         else:
             response = await self.openai_api.chat_complete(
-                self.question, system_prompt, self.model, message_log
+                self.question, system_prompt, self.model, message_log,
+                doc_prompt=doc_prompt
             )

         description = response["message"][0:150]

@@ -298,15 +298,13 @@ class EmbeddingSearchService:
         query_doc, token_usage = await self.text_embedding.get_embeddings(query_doc)
         query_embedding = query_doc[0]["embedding"]
-        print(query_embedding)

         if query_embedding is None:
             return [], token_usage

         res = await self.page_index.search_text_embedding(
             query_embedding, in_collection, limit, self.page_id
         )
-        print(res)
         if res:
             filtered = []
             for one in res:

@@ -9,6 +9,7 @@ import numpy as np
 from aiohttp_sse_client2 import client as sse_client

 from service.tiktoken import TikTokenService
+from events.chat_complete_event import ChatCompleteEvent

 AZURE_CHATCOMPLETE_API_VERSION = "2023-07-01-preview"
 AZURE_EMBEDDING_API_VERSION = "2023-05-15"
@@ -164,7 +165,7 @@ class OpenAIApi:
         return (doc_list, token_usage)

-    async def make_message_list(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = []) -> list[ChatCompleteMessageLog]:
+    async def make_message_list(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], doc_prompt: str = None) -> list[ChatCompleteMessageLog]:
         summaryContent = None

         messageList: list[ChatCompleteMessageLog] = []
@@ -177,13 +178,17 @@ class OpenAIApi:
         if summaryContent:
             system_prompt += "\n\n" + summaryContent

-        messageList.insert(0, ChatCompleteMessageLog(role="assistant", content=system_prompt))
+        messageList.insert(0, ChatCompleteMessageLog(role="system", content=system_prompt))
+
+        if doc_prompt:
+            question = doc_prompt + "\n\n" + question

         messageList.append(ChatCompleteMessageLog(role="user", content=question))

         return messageList

-    async def chat_complete(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], user = None):
+    async def chat_complete(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], doc_prompt: str = None, user = None):
-        messageList = await self.make_message_list(question, system_prompt, conversation)
+        messageList = await self.make_message_list(question, system_prompt, conversation, doc_prompt)

         url = self.get_url("chat/completions")
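
Net effect of this change: the extracted-doc prompt is no longer appended as its own user message (see the removal in finish_chat_complete above) but is prepended to the question. A minimal sketch of the merge rule:

def merge_doc_prompt(question: str, doc_prompt: str | None) -> str:
    # Mirrors the new make_message_list behavior.
    return doc_prompt + "\n\n" + question if doc_prompt else question

assert merge_doc_prompt("Q", None) == "Q"
assert merge_doc_prompt("Q", "DOC") == "DOC\n\nQ"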
@@ -231,10 +236,11 @@ class OpenAIApi:
         return None

-    async def chat_complete_stream(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], on_message = None, user = None):
+    async def chat_complete_stream(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], doc_prompt: str = None,
+                                   event_container: ChatCompleteEvent = None, user = None):
         tiktoken = await TikTokenService.create()

-        messageList = await self.make_message_list(question, system_prompt, conversation)
+        messageList = await self.make_message_list(question, system_prompt, conversation, doc_prompt)

         prompt_tokens = 0
         for message in messageList:
@@ -264,6 +270,9 @@ class OpenAIApi:
         res_message: list[str] = []
         finish_reason = None

+        # Whether the model is currently in deep reasoning
+        in_reasoning = False
+
         async with sse_client.EventSource(
             url,
             option={
@@ -296,7 +305,7 @@ class OpenAIApi:
                     finish_reason = choice["finish_reason"]

                 delta_content = choice["delta"]
-                if "content" in delta_content:
+                if "content" in delta_content and delta_content["content"] is not None:
                     delta_message: str = delta_content["content"]

                     # Skip empty lines before content
@@ -311,8 +320,23 @@ class OpenAIApi:
                     # if config.DEBUG:
                     #     print(delta_message, end="", flush=True)

-                    if on_message is not None:
-                        await on_message(delta_message)
+                    await event_container.emit_message_output(delta_message)
+                elif "reasoning_content" in delta_content and delta_content["reasoning_content"] is not None:
+                    # Handle DeepSeek-style deep reasoning content
+                    delta_message: str = delta_content["reasoning_content"]
+
+                    if not in_reasoning:
+                        await event_container.emit_tool_running("深度思考")
+                        in_reasoning = True
+
+                    # Skip empty lines before content
+                    if not content_started:
+                        if delta_message.replace("\n", "") == "":
+                            continue
+                        else:
+                            content_started = True
+
+                    await event_container.emit_tool_output(delta_message)

                 if finish_reason is not None:
                     break
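
The new branch keys off streamed delta payloads shaped roughly as below; the exact field names follow DeepSeek-style OpenAI-compatible APIs and are an assumption here:

# Assumed SSE delta chunks:
reasoning_delta = {"delta": {"reasoning_content": "Let me think..."}, "finish_reason": None}
content_delta = {"delta": {"content": "The answer is"}, "finish_reason": None}
# The first reasoning_content chunk triggers emit_tool_running("深度思考");
# later ones stream via emit_tool_output, while normal content keeps
# flowing through emit_message_output.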

@@ -6,26 +6,24 @@ import asyncio
 import random
 import threading
 from typing import Callable, Optional, TypedDict
-import torch
-from text2vec import SentenceModel
+import ollama

 from utils.local import loop
 from service.openai_api import OpenAIApi
 from service.tiktoken import TikTokenService

-BERT_EMBEDDING_QUEUE_TIMEOUT = 1
+LOCAL_EMBEDDING_QUEUE_TIMEOUT = 1

-class Text2VecEmbeddingQueueTaskInfo(TypedDict):
+class OllamaEmbeddingQueueTaskInfo(TypedDict):
     task_id: int
     text: str
-    embedding: torch.Tensor
+    embedding: list

-class Text2VecEmbeddingQueue:
+class OllamaEmbeddingQueue:
     def __init__(self, model: str) -> None:
         self.model_name = model
-        self.embedding_model: SentenceModel = None
-        self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {}
-        self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = []
+        self.task_map: dict[int, OllamaEmbeddingQueueTaskInfo] = {}
+        self.task_list: list[OllamaEmbeddingQueueTaskInfo] = []

         self.lock = threading.Lock()
         self.thread: Optional[threading.Thread] = None
@@ -33,7 +31,7 @@ class Text2VecEmbeddingQueue:
     def post_init(self):
-        self.embedding_model = SentenceModel(self.model_name)
+        pass

     async def get_embeddings(self, text: str):
@@ -81,13 +79,13 @@
                 task = self.task_list.pop(0)

             if task is not None:
-                embeddings = self.embedding_model.encode([task["text"]])
+                res = ollama.embeddings(model=self.model_name, prompt=task["text"])
                 with self.lock:
-                    task["embedding"] = embeddings[0]
+                    task["embedding"] = res.embedding
                 last_task_time = time.time()
-            elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
+            elif last_task_time is not None and current_time > last_task_time + LOCAL_EMBEDDING_QUEUE_TIMEOUT:
                 self.thread = None
                 self.running = False
                 running = False
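
For reference, the blocking call the worker thread now makes, as a standalone sketch (requires a local Ollama server with the model pulled; the model name mirrors the default configured below):

import ollama

res = ollama.embeddings(model="shaw/dmeta-embedding-zh", prompt="你好，世界")
vector = res.embedding   # plain list of floats in ollama 0.4.x
print(len(vector))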
@@ -107,7 +105,7 @@ class TextEmbeddingService:
     def __init__(self):
         self.tiktoken: TikTokenService = None
-        self.text2vec_queue: Text2VecEmbeddingQueue = None
+        self.ollama_queue: OllamaEmbeddingQueue = None
         self.openai_api: OpenAIApi = None

     @staticmethod
@@ -120,22 +118,22 @@ class TextEmbeddingService:
     async def init(self):
         self.tiktoken = await TikTokenService.create()

-        self.embedding_type = Config.get("embedding.type", "text2vec")
-        if self.embedding_type == "text2vec":
-            embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese")
-            self.text2vec_queue = Text2VecEmbeddingQueue(model=embedding_model)
+        self.embedding_type = Config.get("embedding.type", "ollama")
+        if self.embedding_type == "ollama":
+            embedding_model = Config.get("embedding.embedding_model", "shaw/dmeta-embedding-zh")
+            self.ollama_queue = OllamaEmbeddingQueue(model=embedding_model)
+            await loop.run_in_executor(None, self.ollama_queue.post_init)
         elif self.embedding_type == "openai":
             api_id = Config.get("embedding.api_id")
             self.openai_api: OpenAIApi = await OpenAIApi.create(api_id)

-        await loop.run_in_executor(None, self.text2vec_queue.post_init)
-
-    async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
+    async def get_ollama_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
         total_token_usage = 0

         for index, doc in enumerate(doc_list):
             text = doc["text"]
-            embedding = await self.text2vec_queue.get_embeddings(text)
+            embedding = await self.ollama_queue.get_embeddings(text)
             doc["embedding"] = embedding

             total_token_usage += await self.tiktoken.get_tokens(text)
@@ -165,7 +163,7 @@ class TextEmbeddingService:
             text = " ".join(new_lines)
             doc["text"] = text

-        if self.embedding_type == "text2vec":
-            return await self.get_text2vec_embeddings(res_doc_list, on_index_progress)
+        if self.embedding_type == "ollama":
+            return await self.get_ollama_embeddings(res_doc_list, on_index_progress)
         elif self.embedding_type == "openai":
             return await self.openai_api.get_embeddings(res_doc_list, on_index_progress)

@@ -16,7 +16,7 @@ class TikTokenService:
         self.enc: tiktoken_async.Encoding = None

     async def init(self):
-        self.enc = await tiktoken_async.encoding_for_model("gpt-3.5-turbo")
+        self.enc = await tiktoken_async.encoding_for_model("gpt-4o")

     async def get_tokens(self, text: str):
         encoded = self.enc.encode(text)

@@ -2,12 +2,12 @@ import asyncio
 import time

 import base as _
 from utils.local import loop, noawait
-from service.text_embedding import Text2VecEmbeddingQueue
+from service.text_embedding import OllamaEmbeddingQueue

 async def main():
     embedding_list = []
     queue = []
-    text2vec_queue = Text2VecEmbeddingQueue("shibing624/text2vec-base-chinese")
+    text2vec_queue = OllamaEmbeddingQueue("shibing624/text2vec-base-chinese")

     start_time = time.time()

@@ -0,0 +1,25 @@
from __future__ import annotations

from typing_extensions import TypedDict


class ChatCompleteQuestionTooLongException(Exception):
    def __init__(self, tokens_limit: int, tokens_current: int):
        super().__init__(f"Question too long: {tokens_current} > {tokens_limit}")
        self.tokens_limit = tokens_limit
        self.tokens_current = tokens_current


class ChatCompleteServicePrepareResponse(TypedDict):
    extract_doc: list
    question_tokens: int
    conversation_id: int
    chunk_id: int
    api_id: str


class ChatCompleteServiceResponse(TypedDict):
    message: str
    message_tokens: int
    total_tokens: int
    finish_reason: str
    question_message_id: str
    response_message_id: str
    delta_data: dict

@@ -0,0 +1,19 @@
import asyncio
import sys
import traceback


async def run_listeners(listeners: list, *args, **kwargs) -> None:
    for listener in listeners:
        try:
            res = listener(*args, **kwargs)
            if asyncio.iscoroutine(res):
                await res
        except Exception as ex:
            print(
                "Error while processing callback: %s" % ex,
                file=sys.stderr,
            )
            traceback.print_exc()
    return None
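
Usage sketch: run_listeners accepts both sync and async listeners, and a failing listener is logged without stopping the rest:

import asyncio

def sync_listener(msg: str):
    print("sync:", msg)

def broken_listener(msg: str):
    raise RuntimeError("boom")   # printed to stderr; the loop continues

async def async_listener(msg: str):
    print("async:", msg)

asyncio.run(run_listeners([sync_listener, broken_listener, async_listener], "hello"))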