diff --git a/agentkit/base/agent.py b/agentkit/base/agent.py
new file mode 100644
index 0000000..1618573
--- /dev/null
+++ b/agentkit/base/agent.py
@@ -0,0 +1,21 @@
+from __future__ import annotations
+from typing import Any
+
+from agentkit.context import ConversationContext
+
+
+class BaseAgent:
+    def __init__(self, config: dict[str, Any]):
+        self.config = config
+
+    async def __aenter__(self) -> BaseAgent:
+        return self
+
+    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        return None
+
+    async def count_tokens(self, text: str) -> int:
+        return len(text.split())
+
+    async def chat_complete(self, context: ConversationContext, stream: bool = False):
+        pass
\ No newline at end of file
diff --git a/agentkit/base/agent_factory.py b/agentkit/base/agent_factory.py
new file mode 100644
index 0000000..a57c238
--- /dev/null
+++ b/agentkit/base/agent_factory.py
@@ -0,0 +1,31 @@
+from __future__ import annotations
+from typing import Any
+
+from agentkit.base.agent import BaseAgent
+
+
+class BaseAgentFactory:
+    def __init__(self, props: dict):
+        pass
+
+
+    async def __aenter__(self) -> BaseAgentFactory:
+        return self
+
+
+    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        return None
+
+
+    async def new_from_id(self, agent_id: str) -> BaseAgent | None:
+        """
+        Create agent from agent id.
+        """
+        return None
+
+
+    async def new_from_config(self, agent_type: str, config: dict) -> BaseAgent | None:
+        """
+        Create agent from agent type and config.
+        """
+        return None
\ No newline at end of file
diff --git a/agentkit/base/modelapi.py b/agentkit/base/modelapi.py
index 6036a09..03619ab 100644
--- a/agentkit/base/modelapi.py
+++ b/agentkit/base/modelapi.py
@@ -7,6 +7,7 @@ from agentkit.types import PluginConfigParam
 
 class AgentKitModelApi:
     CONFIG: list[PluginConfigParam] = []
+    INPUT: list[PluginConfigParam] = []
 
     def __init__(self, props: dict, **kwargs):
         pass
diff --git a/agentkit/base/postprocessor.py b/agentkit/base/postprocessor.py
index 3faf75f..4f87806 100644
--- a/agentkit/base/postprocessor.py
+++ b/agentkit/base/postprocessor.py
@@ -5,6 +5,7 @@ from agentkit.types import PluginConfigParam, OnProgressCallback
 
 class AgentKitPostprocessor:
     CONFIG: list[PluginConfigParam] = []
+    INPUT: list[PluginConfigParam] = []
     OUTPUT: list[PluginConfigParam] = []
 
     def __init__(self, props: dict, **kwargs):
@@ -22,13 +23,13 @@ class AgentKitPostprocessor:
     async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
         return None
 
-    async def on_after_completion(self, context: ConversationContext) -> Optional[dict]:
+    async def on_after_completion(self, context: ConversationContext, **kwargs) -> Optional[dict]:
         """
         Called after the conversation, before the completion is returned to the user.
         """
         return None
 
-    async def on_after_completion_background(self, context: ConversationContext) -> Optional[dict]:
+    async def on_after_completion_background(self, context: ConversationContext, **kwargs) -> Optional[dict]:
         """
         Called after the conversation, after the completion is returned to the user.
""" diff --git a/agentkit/base/preprocessor.py b/agentkit/base/preprocessor.py index 571761c..eb0b5e8 100644 --- a/agentkit/base/preprocessor.py +++ b/agentkit/base/preprocessor.py @@ -5,6 +5,7 @@ from agentkit.types import PluginConfigParam, OnProgressCallback class AgentKitPreprocessor: CONFIG: list[PluginConfigParam] = [] + INPUT: list[PluginConfigParam] = [] OUTPUT: list[PluginConfigParam] = [] def __init__(self, props: dict, **kwargs): @@ -22,5 +23,5 @@ class AgentKitPreprocessor: async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None: return None - async def on_before_completion(self, context: ConversationContext) -> Optional[dict]: + async def on_before_completion(self, context: ConversationContext, **kwargs) -> Optional[dict]: return None \ No newline at end of file diff --git a/agentkit/context.py b/agentkit/context.py index 585dbaa..39e637f 100644 --- a/agentkit/context.py +++ b/agentkit/context.py @@ -38,7 +38,7 @@ class ConversationContext(ConversationData): self.system_prompt = "" """System prompt""" - self.session_data: dict = {} + self.state: dict = {} """Temporary data for plugins""" self.completion: Optional[str] = None @@ -52,39 +52,39 @@ class ConversationContext(ConversationData): return [msg for msg in self.history if msg["role"] == role] - def create_sub_context(self, copy_store = False, copy_session_data = False) -> ConversationContext: + def create_sub_context(self, copy_store = False, copy_state = False) -> ConversationContext: sub_context = ConversationContext() sub_context.user_id = self.user_id if copy_store: sub_context.store = self.store.copy() - if copy_session_data: - sub_context.session_data = self.session_data.copy() + if copy_state: + sub_context.state = self.state.copy() return sub_context def __getitem__(self, key): - return self.session_data[key] + return self.state[key] def __setitem__(self, key, value): - self.session_data[key] = value + self.state[key] = value return value def __delitem__(self, key): - del self.session_data[key] + del self.state[key] def __contains__(self, key): - return key in self.session_data + return key in self.state def __iter__(self): - return iter(self.session_data) + return iter(self.state) def __len__(self): - return len(self.session_data) \ No newline at end of file + return len(self.state) \ No newline at end of file diff --git a/agentkit/flow.py b/agentkit/flow.py new file mode 100644 index 0000000..4c5ea0a --- /dev/null +++ b/agentkit/flow.py @@ -0,0 +1,126 @@ +import logging +from agentkit.context import ConversationContext +from agentkit.types import AgentKitFlowStep + + +class BreakLoopInterrupt(Exception): + pass + + +class AgentKitFlowExecutor: + def __init__(self, flow_script: list[AgentKitFlowStep]): + self.flow = flow_script + + def _execute_step(self, step: AgentKitFlowStep, context: ConversationContext): + """ + Execute a single flow step. + + Args: + step (dict): The step to execute. + + Returns: + Any: The result of the step execution. + """ + step_type = step["type"] + if step_type == "call": + return self._execute_call(step, context) + elif step_type == "if_else": + return self._execute_if_else(step, context) + elif step_type == "loop": + return self._execute_loop(step) + elif step_type == "break_loop": + return self._execute_break_loop(step) + else: + raise ValueError(f"Unsupported step type: {step_type}") + + def _execute_call(self, step: dict, context: ConversationContext): + """ + Execute a 'call' step. + + Args: + step (dict): The call step configuration. 
+ + Returns: + Any: The result of the call. + """ + func_id = step.get("id") + config = {k: self._resolve_value(v) for k, v in step.get("config", {}).items()} + output_map = step.get("output_map", {}) + + # Simulate calling a function (replace with actual function calls if needed) + logging.info(f"Calling function {func_id} with config: {config}") + result = {key: f"mocked_value_{key}" for key in output_map.keys()} # Mocked output + + # Map outputs to variables + for key, var_name in output_map.items(): + self.variables[var_name] = result.get(key) + return result + + def _execute_if_else(self, step: dict, context: ConversationContext): + """ + Execute an 'if_else' step. + + Args: + step (dict): The if_else step configuration. + + Returns: + Any: The result of the executed branch. + """ + condition = step.get("condition") + condition_input = {k: self._resolve_value(v) for k, v in step.get("condition_input", {}).items()} + + # Evaluate the condition + condition_result = eval(condition, {}, condition_input) + + # Execute the appropriate branch + branch = step["true_branch"] if condition_result else step["false_branch"] + for sub_step in branch: + self._execute_step(sub_step, context) + + def _execute_loop(self, step: dict, context: ConversationContext): + """ + Execute a 'loop' step. + + Args: + step (dict): The loop step configuration. + + Returns: + Any: The result of the loop execution. + """ + loop_num = step.get("loop_num") + index_var = step.get("index_var") + loop_body = step.get("loop_body", []) + + for i in range(loop_num): + if index_var: + context.state[index_var] = i + + for sub_step in loop_body: + self._execute_step(sub_step, context) + + + def _execute_break_loop(self): + """ + Execute a 'break_loop' step. + + Args: + step (dict): The break_loop step configuration. + + Returns: + None + """ + raise BreakLoopInterrupt() + + + def execute(self, context: ConversationContext): + """ + Execute the entire flow. 
+ + Returns: + None + """ + try: + for step in self.flow: + self._execute_step(step) + except BreakLoopInterrupt: + pass \ No newline at end of file diff --git a/agentkit/types.py b/agentkit/types.py index 27da6ba..7dcebee 100644 --- a/agentkit/types.py +++ b/agentkit/types.py @@ -49,7 +49,46 @@ class LLMFunctionResponse(TypedDict): """Directly return the response to the user, without further processing.""" -OnProgressCallback = Callable[[int, int], Any] +class AgentKitFlowStep(TypedDict): + type: str + comment: str + + +class AgentKitFlowCall(AgentKitFlowStep): + # type: "call" + id: str + comment: str + config: dict[str, Any] + output_map: dict[str, str] + + +class AgentKitFlowIfElse(AgentKitFlowStep): + # type: "if_else" + condition: str + condition_input: dict[str, Any] + true_branch: list + false_branch: list + + +class AgentKitLoop(AgentKitFlowStep): + # type: "loop" + loop_num: int + loop_body: list + index_var: Optional[str] + + +class AgentKitBreakLoop(AgentKitFlowStep): + # type: "break_loop" + pass + + +class AgentKitSetVar(AgentKitFlowStep): + # type: "set_var" + var_name: str + var_value_expr: str + + +OnProgressCallback = Callable[[str, int, int], Any] OnTaskStartCallback = Callable[[str, str], Any] OnTaskProgressCallback = Callable[[int, Optional[int], Optional[str]], Any] diff --git a/events/chat_complete_event.py b/events/chat_complete_event.py new file mode 100644 index 0000000..c22b683 --- /dev/null +++ b/events/chat_complete_event.py @@ -0,0 +1,31 @@ +from __future__ import annotations +from typing import Callable +from type_defs.chat_complete_task import ChatCompleteServiceResponse +from utils.program import run_listeners + +class ChatCompleteEvent: + def __init__(self): + self.on_tool_running: list[Callable[[str, str], None]] = [] + self.on_tool_output: list[Callable[[str], None]] = [] + self.on_message_output: list[Callable[[str], None]] = [] + + async def emit_tool_running(self, tool_name: str, running_state: str = "") -> None: + await run_listeners(self.on_tool_running, tool_name, running_state) + + async def emit_tool_output(self, output: str) -> None: + await run_listeners(self.on_tool_output, output) + + async def emit_message_output(self, output: str) -> None: + await run_listeners(self.on_message_output, output) + +class ChatCompleteTaskEvent(ChatCompleteEvent): + def __init__(self): + super().__init__() + self.on_finished: list[Callable[[str | None], None]] = [] + self.on_error: list[Callable[[str], None]] = [] + + async def emit_finished(self, result: ChatCompleteServiceResponse | None) -> None: + await run_listeners(self.on_finished, result) + + async def emit_error(self, ex: Exception) -> None: + await run_listeners(self.on_error, ex) \ No newline at end of file diff --git a/requirements-text2vec.txt b/requirements-text2vec.txt deleted file mode 100644 index 820eef8..0000000 --- a/requirements-text2vec.txt +++ /dev/null @@ -1,3 +0,0 @@ -text2vec>=1.2.9 ---index-url https://download.pytorch.org/whl/cpu -torch \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9fe4fc7..04cfa3f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ PyJWT==2.6.0 aiohttp-sse-client2==0.3.0 OpenCC==1.1.1 event-emitter-asyncio==1.0.4 -toml==0.10.2 \ No newline at end of file +toml==0.10.2 +ollama==0.4.4 \ No newline at end of file diff --git a/server/controller/ChatComplete.py b/server/controller/ChatComplete.py index 87c0647..177dcf0 100644 --- a/server/controller/ChatComplete.py +++ b/server/controller/ChatComplete.py @@ -7,10 
 from server.model.base import clone_model
 from server.model.chat_complete.bot_persona import BotPersonaHelper
 from server.model.toolbox_ui.conversation import ConversationHelper
+from type_defs.chat_complete_task import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse
 from utils.local import noawait
 from aiohttp import web
 from server.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
-from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse, calculate_point_usage
+from service.chat_complete import calculate_point_usage
 from service.database import DatabaseService
 from service.mediawiki_api import MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
 from service.tiktoken import TikTokenService
@@ -248,6 +249,8 @@ class ChatComplete:
     @staticmethod
     @utils.web.token_auth
     async def get_point_usage(request: web.Request):
+        """Calculate the points required for a ChatComplete request."""
+
         params = await utils.web.get_param(request, {
             "question": {
                 "type": str,
@@ -299,6 +302,8 @@ class ChatComplete:
     @staticmethod
     @utils.web.token_auth
     async def get_persona_list(request: web.Request):
+        """Get the bot persona list."""
+
         params = await utils.web.get_param(request, {
             "page": {
                 "type": int,
@@ -336,6 +341,8 @@ class ChatComplete:
     @staticmethod
     @utils.web.token_auth
     async def get_persona_info(request: web.Request):
+        """Get bot persona info."""
+
         params = await utils.web.get_param(request, {
             "id": {
                 "type": int,
@@ -371,6 +378,8 @@ class ChatComplete:
     @staticmethod
     @utils.web.token_auth
     async def start_chat_complete(request: web.Request):
+        """Start ChatComplete."""
+
         params = await utils.web.get_param(request, {
             "title": {
                 "type": str,
@@ -471,6 +480,7 @@ class ChatComplete:
     @staticmethod
     @utils.web.token_auth
     async def chat_complete_stream(request: web.Request):
+        """Stream ChatComplete results."""
         if not utils.web.is_websocket(request):
             return await utils.web.api_response(-1, error={
                 "code": "websocket-required",
@@ -542,9 +552,28 @@ class ChatComplete:
                     await ws.close()
            else:
                async def on_closed():
-                    task.on_message.remove(on_message)
-                    task.on_finished.remove(on_finished)
-                    task.on_error.remove(on_error)
+                    task.events.on_tool_running.remove(on_tool_running)
+                    task.events.on_tool_output.remove(on_tool_output)
+                    task.events.on_message_output.remove(on_message)
+                    task.events.on_finished.remove(on_finished)
+                    task.events.on_error.remove(on_error)
+
+                async def on_tool_running(tool_name: str, running_state: str):
+                    try:
+                        await ws.send_json({
+                            'event': 'tool_run',
+                            'status': 1,
+                            'tool_name': tool_name,
+                            'running_state': running_state,
+                        })
+                    except ConnectionResetError:
+                        await on_closed()
+
+                async def on_tool_output(output: str):
+                    try:
+                        await ws.send_str(">" + output)
+                    except ConnectionResetError:
+                        await on_closed()
 
                 async def on_message(delta_message: str):
                     try:
@@ -552,15 +581,18 @@ class ChatComplete:
                     except ConnectionResetError:
                         await on_closed()
 
-                async def on_finished(result: ChatCompleteServiceResponse):
+                async def on_finished(result: ChatCompleteServiceResponse | None):
                     try:
                         ignored_keys = ["message"]
                         response_result = {
                             "point_usage": task.point_usage,
                         }
-                        for k, v in result.items():
-                            if k not in ignored_keys:
-                                response_result[k] = v
+
+                        if result:
+                            for k, v in result.items():
+                                if k not in ignored_keys:
+                                    response_result[k] = v
+
                         await ws.send_json({
                             'event': 'finished',
                             'status': 1,
@@ -587,9 +619,11 @@ class ChatComplete:
                     except ConnectionResetError:
                         await on_closed()
 
-                task.on_message.append(on_message)
-                task.on_finished.append(on_finished)
-                task.on_error.append(on_error)
+                task.events.on_tool_running.append(on_tool_running)
+                task.events.on_tool_output.append(on_tool_output)
+                task.events.on_message_output.append(on_message)
+                task.events.on_finished.append(on_finished)
+                task.events.on_error.append(on_error)
 
                 # Send received message
                 await ws.send_json({
diff --git a/server/controller/task/ChatCompleteTask.py b/server/controller/task/ChatCompleteTask.py
index a3a953f..c516318 100644
--- a/server/controller/task/ChatCompleteTask.py
+++ b/server/controller/task/ChatCompleteTask.py
@@ -2,14 +2,14 @@ from __future__ import annotations
 import sys
 import time
 import traceback
+from events.chat_complete_event import ChatCompleteEvent, ChatCompleteTaskEvent
 from libs.config import Config
 from server.model.chat_complete.bot_persona import BotPersonaHelper
+from type_defs.chat_complete_task import ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse
 from utils.local import noawait
 from typing import Optional, Callable, Union
 from service.chat_complete import (
     ChatCompleteService,
-    ChatCompleteServicePrepareResponse,
-    ChatCompleteServiceResponse,
     calculate_point_usage,
 )
 from service.database import DatabaseService
@@ -30,9 +28,7 @@ class ChatCompleteTask:
         self, dbs: DatabaseService, user_id: int, page_title: str, is_system=False
     ):
         self.task_id = utils.web.generate_uuid()
-        self.on_message: list[Callable] = []
-        self.on_finished: list[Callable] = []
-        self.on_error: list[Callable] = []
+        self.events: ChatCompleteTaskEvent = ChatCompleteTaskEvent()
 
         self.chunks: list[str] = []
         self.chat_complete_service: ChatCompleteService
@@ -111,45 +109,10 @@ class ChatCompleteTask:
             await self.end()
             raise e
 
-    async def _on_message(self, delta_message: str):
-        self.chunks.append(delta_message)
-
-        for callback in self.on_message:
-            try:
-                await callback(delta_message)
-            except Exception as e:
-                print(
-                    "Error while processing on_message callback: %s" % e,
-                    file=sys.stderr,
-                )
-                traceback.print_exc()
-
-    async def _on_finished(self):
-        for callback in self.on_finished:
-            try:
-                await callback(self.result)
-            except Exception as e:
-                print(
-                    "Error while processing on_finished callback: %s" % e,
-                    file=sys.stderr,
-                )
-                traceback.print_exc()
-
-    async def _on_error(self, err: Exception):
-        self.error = err
-        for callback in self.on_error:
-            try:
-                await callback(err)
-            except Exception as e:
-                print(
-                    "Error while processing on_error callback: %s" % e, file=sys.stderr
-                )
-                traceback.print_exc()
-
     async def run(self) -> ChatCompleteServiceResponse:
         chat_complete_tasks[self.task_id] = self
         try:
-            chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
+            chat_res = await self.chat_complete.finish_chat_complete(event_container=self.events)
 
             await self.chat_complete.set_latest_point_usage(self.point_usage)
 
@@ -160,7 +123,7 @@ class ChatCompleteTask:
                 self.transatcion_id, self.point_usage  # TODO: 根据实际使用Tokens扣除积分
             )
 
-            await self._on_finished()
+            await self.events.emit_finished(chat_res)
         except Exception as e:
             err_msg = f"Error while processing chat complete request: {e}"
 
@@ -172,7 +135,7 @@ class ChatCompleteTask:
                     self.transatcion_id, error=err_msg
                 )
 
-            await self._on_error(e)
+            await self.events.emit_error(e)
 
         finally:
             await self.end()
diff --git a/service/chat_complete.py b/service/chat_complete.py
index 9f1a401..bc098f5 100644
--- a/service/chat_complete.py
+++ b/service/chat_complete.py
@@ -4,6 +4,7 @@ import time
 import traceback
 
 from typing import Optional, Tuple, TypedDict
+from events.chat_complete_event import ChatCompleteEvent
 from server.model.chat_complete.bot_persona import BotPersonaHelper
 from server.model.chat_complete.conversation import (
     ConversationChunkHelper,
@@ -13,6 +14,7 @@ import sys
 from server.model.toolbox_ui.conversation import ConversationHelper, ConversationModel
 
 from libs.config import Config
+from type_defs.chat_complete_task import ChatCompleteQuestionTooLongException, ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse
 import utils.config, utils.web
 
 from aiohttp import web
@@ -24,27 +26,6 @@ from service.mediawiki_api import MediaWikiApi
 from service.openai_api import OpenAIApi, OpenAIApiTypeInvalidException
 from service.tiktoken import TikTokenService
 
-class ChatCompleteQuestionTooLongException(Exception):
-    def __init__(self, tokens_limit: int, tokens_current: int):
-        super().__init__(f"Question too long: {tokens_current} > {tokens_limit}")
-        self.tokens_limit = tokens_limit
-        self.tokens_current = tokens_current
-
-class ChatCompleteServicePrepareResponse(TypedDict):
-    extract_doc: list
-    question_tokens: int
-    conversation_id: int
-    chunk_id: int
-    api_id: str
-
-class ChatCompleteServiceResponse(TypedDict):
-    message: str
-    message_tokens: int
-    total_tokens: int
-    finish_reason: str
-    question_message_id: str
-    response_message_id: str
-    delta_data: dict
 
 class ChatCompleteService:
     def __init__(self, dbs: DatabaseService, title: str):
@@ -285,7 +266,7 @@ class ChatCompleteService:
         )
 
     async def finish_chat_complete(
-        self, on_message: Optional[callable] = None
+        self, event_container: Optional[ChatCompleteEvent] = None
     ) -> ChatCompleteServiceResponse:
         delta_data = {}
 
@@ -311,7 +292,6 @@ class ChatCompleteService:
                 doc_prompt = utils.config.get_prompt(
                     "extracted_doc", "prompt", {"content": doc_prompt_content}
                 )
-                message_log.append({"role": "user", "content": doc_prompt})
 
             system_prompt = self.chat_system_prompt
             if system_prompt is None:
@@ -323,13 +303,15 @@ class ChatCompleteService:
         system_prompt = utils.config.format_prompt(system_prompt)
 
         # Start chat complete
-        if on_message is not None:
+        if event_container is not None:
             response = await self.openai_api.chat_complete_stream(
-                self.question, system_prompt, self.model, message_log, on_message
+                self.question, system_prompt, self.model, message_log,
+                doc_prompt=doc_prompt, event_container=event_container
             )
         else:
             response = await self.openai_api.chat_complete(
-                self.question, system_prompt, self.model, message_log
+                self.question, system_prompt, self.model, message_log,
+                doc_prompt=doc_prompt
             )
 
         description = response["message"][0:150]
diff --git a/service/embedding_search.py b/service/embedding_search.py
index 288c218..0e68efc 100644
--- a/service/embedding_search.py
+++ b/service/embedding_search.py
@@ -298,15 +298,13 @@ class EmbeddingSearchService:
         query_doc, token_usage = await self.text_embedding.get_embeddings(query_doc)
         query_embedding = query_doc[0]["embedding"]
 
-        print(query_embedding)
-
         if query_embedding is None:
             return [], token_usage
 
         res = await self.page_index.search_text_embedding(
             query_embedding, in_collection, limit, self.page_id
         )
-        print(res)
+
         if res:
             filtered = []
             for one in res:
diff --git a/service/openai_api.py b/service/openai_api.py
index 1b0af29..d808054 100644
--- a/service/openai_api.py
+++ b/service/openai_api.py
@@ -9,6 +9,7 @@ import numpy as np
 from aiohttp_sse_client2 import client as sse_client
 
 from service.tiktoken import TikTokenService
+from events.chat_complete_event import ChatCompleteEvent
 
 AZURE_CHATCOMPLETE_API_VERSION = "2023-07-01-preview"
 AZURE_EMBEDDING_API_VERSION = "2023-05-15"
@@ -164,7 +165,7 @@ class OpenAIApi:
 
         return (doc_list, token_usage)
 
-    async def make_message_list(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = []) -> list[ChatCompleteMessageLog]:
+    async def make_message_list(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], doc_prompt: str = None) -> list[ChatCompleteMessageLog]:
         summaryContent = None
 
         messageList: list[ChatCompleteMessageLog] = []
@@ -177,13 +178,17 @@ class OpenAIApi:
 
         if summaryContent:
             system_prompt += "\n\n" + summaryContent
 
-        messageList.insert(0, ChatCompleteMessageLog(role="assistant", content=system_prompt))
+        messageList.insert(0, ChatCompleteMessageLog(role="system", content=system_prompt))
+
+        if doc_prompt:
+            question = doc_prompt + "\n\n" + question
+
         messageList.append(ChatCompleteMessageLog(role="user", content=question))
 
         return messageList
 
-    async def chat_complete(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], user = None):
-        messageList = await self.make_message_list(question, system_prompt, conversation)
+    async def chat_complete(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], doc_prompt: str = None, user = None):
+        messageList = await self.make_message_list(question, system_prompt, conversation, doc_prompt)
 
         url = self.get_url("chat/completions")
@@ -231,10 +236,11 @@ class OpenAIApi:
 
         return None
 
-    async def chat_complete_stream(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], on_message = None, user = None):
+    async def chat_complete_stream(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], doc_prompt: str = None,
+                                   event_container: ChatCompleteEvent = None, user = None):
         tiktoken = await TikTokenService.create()
 
-        messageList = await self.make_message_list(question, system_prompt, conversation)
+        messageList = await self.make_message_list(question, system_prompt, conversation, doc_prompt)
 
         prompt_tokens = 0
         for message in messageList:
@@ -264,6 +270,9 @@ class OpenAIApi:
         res_message: list[str] = []
         finish_reason = None
 
+        # Whether the model is currently in deep reasoning
+        in_reasoning = False
+
         async with sse_client.EventSource(
             url,
             option={
@@ -296,9 +305,9 @@ class OpenAIApi:
                         finish_reason = choice["finish_reason"]
 
                     delta_content = choice["delta"]
-                    if "content" in delta_content:
+                    if "content" in delta_content and delta_content["content"] is not None:
                         delta_message: str = delta_content["content"]
-                        
+
                         # Skip empty lines before content
                         if not content_started:
                             if delta_message.replace("\n", "") == "":
@@ -311,9 +320,24 @@ class OpenAIApi:
 
                         # if config.DEBUG:
                         #     print(delta_message, end="", flush=True)
 
-                        if on_message is not None:
-                            await on_message(delta_message)
-
+                        await event_container.emit_message_output(delta_message)
+                    elif "reasoning_content" in delta_content and delta_content["reasoning_content"] is not None:
+                        # Handle DeepSeek-style deep reasoning content
+                        delta_message: str = delta_content["reasoning_content"]
+
+                        if not in_reasoning:
+                            await event_container.emit_tool_running("深度思考")
+                            in_reasoning = True
+
+                        # Skip empty lines before content
+                        if not content_started:
+                            if delta_message.replace("\n", "") == "":
+                                continue
+                            else:
+                                content_started = True
+
+                        await event_container.emit_tool_output(delta_message)
+
                     if finish_reason is not None:
                         break
diff --git a/service/text_embedding.py b/service/text_embedding.py
index 7f49f06..00ee67e 100644
--- a/service/text_embedding.py
+++ b/service/text_embedding.py
@@ -6,26 +6,24 @@ import asyncio
 import random
 import threading
 from typing import Callable, Optional, TypedDict
-import torch
-from text2vec import SentenceModel
+import ollama
 
 from utils.local import loop
 from service.openai_api import OpenAIApi
 from service.tiktoken import TikTokenService
 
-BERT_EMBEDDING_QUEUE_TIMEOUT = 1
+LOCAL_EMBEDDING_QUEUE_TIMEOUT = 1
 
-class Text2VecEmbeddingQueueTaskInfo(TypedDict):
+class OllamaEmbeddingQueueTaskInfo(TypedDict):
     task_id: int
     text: str
-    embedding: torch.Tensor
+    embedding: list
 
-class Text2VecEmbeddingQueue:
+class OllamaEmbeddingQueue:
     def __init__(self, model: str) -> None:
         self.model_name = model
-        self.embedding_model: SentenceModel = None
-        self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {}
-        self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = []
+        self.task_map: dict[int, OllamaEmbeddingQueueTaskInfo] = {}
+        self.task_list: list[OllamaEmbeddingQueueTaskInfo] = []
 
         self.lock = threading.Lock()
         self.thread: Optional[threading.Thread] = None
@@ -33,7 +31,7 @@ class Text2VecEmbeddingQueue:
 
 
     def post_init(self):
-        self.embedding_model = SentenceModel(self.model_name)
+        pass
 
 
     async def get_embeddings(self, text: str):
@@ -81,13 +79,13 @@ class Text2VecEmbeddingQueue:
                     task = self.task_list.pop(0)
 
             if task is not None:
-                embeddings = self.embedding_model.encode([task["text"]])
+                res = ollama.embeddings(model=self.model_name, prompt=task["text"])
 
                 with self.lock:
-                    task["embedding"] = embeddings[0]
+                    task["embedding"] = res.embedding
 
                 last_task_time = time.time()
-            elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
+            elif last_task_time is not None and current_time > last_task_time + LOCAL_EMBEDDING_QUEUE_TIMEOUT:
                 self.thread = None
                 self.running = False
                 running = False
@@ -107,7 +105,7 @@ class TextEmbeddingService:
     def __init__(self):
         self.tiktoken: TikTokenService = None
 
-        self.text2vec_queue: Text2VecEmbeddingQueue = None
+        self.ollama_queue: OllamaEmbeddingQueue = None
         self.openai_api: OpenAIApi = None
 
     @staticmethod
@@ -120,22 +118,22 @@ class TextEmbeddingService:
     async def init(self):
         self.tiktoken = await TikTokenService.create()
 
-        self.embedding_type = Config.get("embedding.type", "text2vec")
+        self.embedding_type = Config.get("embedding.type", "ollama")
 
-        if self.embedding_type == "text2vec":
-            embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese")
-            self.text2vec_queue = Text2VecEmbeddingQueue(model=embedding_model)
+        if self.embedding_type == "ollama":
+            embedding_model = Config.get("embedding.embedding_model", "shaw/dmeta-embedding-zh")
+            self.ollama_queue = OllamaEmbeddingQueue(model=embedding_model)
+            await loop.run_in_executor(None, self.ollama_queue.post_init)
         elif self.embedding_type == "openai":
             api_id = Config.get("embedding.api_id")
             self.openai_api: OpenAIApi = await OpenAIApi.create(api_id)
 
-        await loop.run_in_executor(None, self.text2vec_queue.post_init)
 
-    async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
+    async def get_ollama_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
         total_token_usage = 0
         for index, doc in enumerate(doc_list):
             text = doc["text"]
-            embedding = await self.text2vec_queue.get_embeddings(text)
+            embedding = await self.ollama_queue.get_embeddings(text)
             doc["embedding"] = embedding
 
             total_token_usage += await self.tiktoken.get_tokens(text)
@@ -165,7 +163,7 @@ class TextEmbeddingService:
             text = " ".join(new_lines)
             doc["text"] = text
 
-        if self.embedding_type == "text2vec":
-            return await self.get_text2vec_embeddings(res_doc_list, on_index_progress)
+        if self.embedding_type == "ollama":
+            return await self.get_ollama_embeddings(res_doc_list, on_index_progress)
         elif self.embedding_type == "openai":
             return await self.openai_api.get_embeddings(res_doc_list, on_index_progress)
\ No newline at end of file
diff --git a/service/tiktoken.py b/service/tiktoken.py
index 17dd4cf..789fded 100644
--- a/service/tiktoken.py
+++ b/service/tiktoken.py
@@ -16,7 +16,7 @@ class TikTokenService:
         self.enc: tiktoken_async.Encoding = None
 
     async def init(self):
-        self.enc = await tiktoken_async.encoding_for_model("gpt-3.5-turbo")
+        self.enc = await tiktoken_async.encoding_for_model("gpt-4o")
 
     async def get_tokens(self, text: str):
         encoded = self.enc.encode(text)
diff --git a/test/text2vec_embedding_queue.py b/test/text2vec_embedding_queue.py
index b9ce694..5623472 100644
--- a/test/text2vec_embedding_queue.py
+++ b/test/text2vec_embedding_queue.py
@@ -2,12 +2,12 @@ import asyncio
 import time
 import base as _
 from utils.local import loop, noawait
-from service.text_embedding import Text2VecEmbeddingQueue
+from service.text_embedding import OllamaEmbeddingQueue
 
 
 async def main():
     embedding_list = []
     queue = []
-    text2vec_queue = Text2VecEmbeddingQueue("shibing624/text2vec-base-chinese")
+    text2vec_queue = OllamaEmbeddingQueue("shaw/dmeta-embedding-zh")
 
     start_time = time.time()
diff --git a/type_defs/chat_complete_task.py b/type_defs/chat_complete_task.py
new file mode 100644
index 0000000..633bc79
--- /dev/null
+++ b/type_defs/chat_complete_task.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+from typing_extensions import TypedDict
+
+
+class ChatCompleteQuestionTooLongException(Exception):
+    def __init__(self, tokens_limit: int, tokens_current: int):
+        super().__init__(f"Question too long: {tokens_current} > {tokens_limit}")
+        self.tokens_limit = tokens_limit
+        self.tokens_current = tokens_current
+
+class ChatCompleteServicePrepareResponse(TypedDict):
+    extract_doc: list
+    question_tokens: int
+    conversation_id: int
+    chunk_id: int
+    api_id: str
+
+class ChatCompleteServiceResponse(TypedDict):
+    message: str
+    message_tokens: int
+    total_tokens: int
+    finish_reason: str
+    question_message_id: str
+    response_message_id: str
+    delta_data: dict
\ No newline at end of file
diff --git a/utils/program.py b/utils/program.py
new file mode 100644
index 0000000..b6765a1
--- /dev/null
+++ b/utils/program.py
@@ -0,0 +1,18 @@
+import asyncio
+import sys
+import traceback
+
+
+async def run_listeners(listeners: list, *args, **kwargs) -> None:
+    for listener in listeners:
+        try:
+            res = listener(*args, **kwargs)
+            if asyncio.iscoroutine(res):
+                await res
+        except Exception as ex:
+            print(
+                "Error while processing callback: %s" % ex,
+                file=sys.stderr,
+            )
+            traceback.print_exc()
+    return None
\ No newline at end of file
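
Usage sketch (not part of the patch): a minimal example of how the new `AgentKitFlowExecutor` might be driven, assuming the modules land as `agentkit.flow` and `agentkit.context`, and assuming the state-based `output_map` mapping and `"$"`-prefixed input resolution shown in `agentkit/flow.py` above. The `call` step is still mocked inside the executor, so mapped outputs come back as `mocked_value_*` placeholders.

```python
from agentkit.context import ConversationContext
from agentkit.flow import AgentKitFlowExecutor

# A small flow: call a (mocked) function, then loop three times,
# breaking out early once the loop index reaches 1.
flow_script = [
    {
        "type": "call",
        "id": "search_wiki",
        "config": {"query": "$question"},          # "$question" is resolved from context.state
        "output_map": {"result": "wiki_result"},   # store the call output in context.state["wiki_result"]
    },
    {
        "type": "loop",
        "loop_num": 3,
        "index_var": "i",
        "loop_body": [
            {
                "type": "if_else",
                "condition": "i >= 1",
                "condition_input": {"i": "$i"},
                "true_branch": [{"type": "break_loop"}],
                "false_branch": [],
            },
        ],
    },
]

context = ConversationContext()
context.state["question"] = "What is AgentKit?"

executor = AgentKitFlowExecutor(flow_script)
executor.execute(context)

print(context.state["wiki_result"])  # -> "mocked_value_result" (placeholder from the mocked call)
print(context.state["i"])            # -> 1, the iteration where break_loop fired
```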