diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index a83000a0bc..edef4e249b 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -196,6 +196,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc yield from self._generate_stream_outputs_when_node_started() yield self._workflow_node_start_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 3e0a9e5e5c..f926d75968 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -162,6 +162,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa elif isinstance(event, QueueNodeStartedEvent): workflow_node_execution = self._handle_node_start(event) yield self._workflow_node_start_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index b2c80ec22c..29ea7bc935 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -229,6 +229,7 @@ class NodeStartStreamResponse(StreamResponse): predecessor_node_id: Optional[str] = None inputs: Optional[dict] = None created_at: int + extras: dict = {} event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 7077bab2fb..8c532d4a5d 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -1,7 +1,7 @@ import json import time from datetime import datetime -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( @@ -23,7 +23,9 @@ from core.app.entities.task_entities import ( ) from core.file.file_obj import FileVar from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.tool_manager import ToolManager from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable +from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account @@ -321,15 +323,18 @@ class WorkflowCycleManage: ) ) - def _workflow_node_start_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \ + def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution) \ -> NodeStartStreamResponse: """ Workflow node start to stream response. + :param event: queue node started event :param task_id: task id :param workflow_node_execution: workflow node execution :return: """ - return NodeStartStreamResponse( + response = NodeStartStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, data=NodeStartStreamResponse.Data( @@ -344,6 +349,17 @@ class WorkflowCycleManage: ) ) + # extras logic + if event.node_type == NodeType.TOOL: + node_data = cast(ToolNodeData, event.node_data) + response.data.extras['icon'] = ToolManager.get_tool_icon( + tenant_id=self._application_generate_entity.app_config.tenant_id, + provider_type=node_data.provider_type, + provider_id=node_data.provider_id + ) + + return response + def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \ -> NodeFinishStreamResponse: """ diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 3a0278bc01..658eed812c 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -4,6 +4,8 @@ import mimetypes from os import listdir, path from typing import Any, Union +from flask import current_app + from core.agent.entities import AgentToolEntity from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.model_runtime.entities.message_entities import PromptMessage @@ -43,15 +45,16 @@ logger = logging.getLogger(__name__) _builtin_providers = {} _builtin_tools_labels = {} + class ToolManager: @staticmethod def invoke( - provider: str, - tool_id: str, - tool_name: str, - tool_parameters: dict[str, Any], - credentials: dict[str, Any], - prompt_messages: list[PromptMessage], + provider: str, + tool_id: str, + tool_name: str, + tool_parameters: dict[str, Any], + credentials: dict[str, Any], + prompt_messages: list[PromptMessage], ) -> list[ToolInvokeMessage]: """ invoke the assistant @@ -80,7 +83,7 @@ class ToolManager: provider_entity = provider_class() return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages) - + @staticmethod def get_builtin_provider(provider: str) -> BuiltinToolProviderController: global _builtin_providers @@ -96,9 +99,9 @@ class ToolManager: if provider not in _builtin_providers: raise ToolProviderNotFoundError(f'builtin provider {provider} not found') - + return _builtin_providers[provider] - + @staticmethod def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool: """ @@ -113,10 +116,10 @@ class ToolManager: tool = provider_controller.get_tool(tool_name) return tool - + @staticmethod def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \ - -> Union[BuiltinTool, ApiTool]: + -> Union[BuiltinTool, ApiTool]: """ get the tool @@ -137,11 +140,11 @@ class ToolManager: raise NotImplementedError('app provider not implemented') else: raise ToolProviderNotFoundError(f'provider type {provider_type} not found') - + @staticmethod - def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str, + def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str, agent_callback: DifyAgentCallbackHandler = None) \ - -> Union[BuiltinTool, ApiTool]: + -> Union[BuiltinTool, ApiTool]: """ get the tool runtime @@ -170,7 +173,7 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found') - + # decrypt the credentials credentials = builtin_provider.credentials controller = ToolManager.get_builtin_provider(provider_name) @@ -183,11 +186,11 @@ class ToolManager: 'credentials': decrypted_credentials, 'runtime_parameters': {} }, agent_callback=agent_callback) - + elif provider_type == 'api': if tenant_id is None: raise ValueError('tenant id is required for api provider') - + api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name) # decrypt the credentials @@ -232,7 +235,8 @@ class ToolManager: # check if tool_parameter_config in options options = list(map(lambda x: x.value, parameter_rule.options)) if parameter_value not in options: - raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") + raise ValueError( + f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") # convert tool parameter config to correct type try: @@ -249,7 +253,8 @@ class ToolManager: parameter_value = int(parameter_value) elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN: parameter_value = bool(parameter_value) - elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: + elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT, + ToolParameter.ToolParameterType.STRING]: parameter_value = str(parameter_value) elif parameter_rule.type == ToolParameter.ToolParameterType: parameter_value = str(parameter_value) @@ -259,12 +264,14 @@ class ToolManager: return parameter_value @staticmethod - def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool: + def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, + agent_callback: DifyAgentCallbackHandler) -> Tool: """ get the agent tool runtime """ tool_entity = ToolManager.get_tool_runtime( - provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name, + provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, + tool_name=agent_tool.tool_name, tenant_id=tenant_id, agent_callback=agent_callback ) @@ -275,7 +282,7 @@ class ToolManager: # save tool parameter to tool entity memory value = ToolManager._init_runtime_parameter(parameter, agent_tool.tool_parameters) runtime_parameters[parameter.name] = value - + # decrypt runtime parameters encryption_manager = ToolParameterConfigurationManager( tenant_id=tenant_id, @@ -335,11 +342,12 @@ class ToolManager: # get provider provider_controller = ToolManager.get_builtin_provider(provider) - absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', provider_controller.identity.icon) + absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', + provider_controller.identity.icon) # check if the icon exists if not path.exists(absolute_path): raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found') - + # get the mime type mime_type, _ = mimetypes.guess_type(absolute_path) mime_type = mime_type or 'application/octet-stream' @@ -353,7 +361,7 @@ class ToolManager: # use cache first if len(_builtin_providers) > 0: return list(_builtin_providers.values()) - + builtin_providers: list[BuiltinToolProviderController] = [] for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): if provider.startswith('__'): @@ -367,7 +375,7 @@ class ToolManager: provider_class = load_single_subclass_from_source( module_name=f'core.tools.provider.builtin.{provider}.{provider}', script_path=path.join(path.dirname(path.realpath(__file__)), - 'provider', 'builtin', provider, f'{provider}.py'), + 'provider', 'builtin', provider, f'{provider}.py'), parent_type=BuiltinToolProviderController) builtin_providers.append(provider_class()) @@ -378,7 +386,7 @@ class ToolManager: _builtin_tools_labels[tool.identity.name] = tool.identity.label return builtin_providers - + @staticmethod def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]: """ @@ -403,7 +411,7 @@ class ToolManager: model_providers.append(ModelToolProviderController.from_db(configuration)) return model_providers - + @staticmethod def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController: """ @@ -419,7 +427,7 @@ class ToolManager: configuration = configurations.get(provider_name) if configuration is None: raise ToolProviderNotFoundError(f'model provider {provider_name} not found') - + return ModelToolProviderController.from_db(configuration) @staticmethod @@ -438,13 +446,13 @@ class ToolManager: if tool_name not in _builtin_tools_labels: return None - + return _builtin_tools_labels[tool_name] - + @staticmethod def user_list_providers( - user_id: str, - tenant_id: str, + user_id: str, + tenant_id: str, ) -> list[UserToolProvider]: result_providers: dict[str, UserToolProvider] = {} @@ -454,8 +462,9 @@ class ToolManager: # get db builtin providers db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ filter(BuiltinToolProvider.tenant_id == tenant_id).all() - - find_db_builtin_provider = lambda provider: next((x for x in db_builtin_providers if x.provider == provider), None) + + find_db_builtin_provider = lambda provider: next((x for x in db_builtin_providers if x.provider == provider), + None) # append builtin providers for provider in builtin_providers: @@ -478,7 +487,7 @@ class ToolManager: # get db api providers db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ filter(ApiToolProvider.tenant_id == tenant_id).all() - + for db_api_provider in db_api_providers: provider_controller = ToolTransformService.api_provider_to_controller( db_provider=db_api_provider, @@ -490,9 +499,10 @@ class ToolManager: result_providers[db_api_provider.name] = user_provider return BuiltinToolProviderSort.sort(list(result_providers.values())) - + @staticmethod - def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[ApiBasedToolProviderController, dict[str, Any]]: + def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[ + ApiBasedToolProviderController, dict[str, Any]]: """ get the api provider @@ -507,14 +517,15 @@ class ToolManager: if provider is None: raise ToolProviderNotFoundError(f'api provider {provider_id} not found') - + controller = ApiBasedToolProviderController.from_db( - provider, ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE + provider, + ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE ) controller.load_bundled_tools(provider.tools) return controller, provider.credentials - + @staticmethod def user_get_api_provider(provider: str, tenant_id: str) -> dict: """ @@ -530,7 +541,7 @@ class ToolManager: if provider is None: raise ValueError(f'you have not added provider {provider}') - + try: credentials = json.loads(provider.credentials_str) or {} except: @@ -562,4 +573,36 @@ class ToolManager: 'description': provider.description, 'credentials': masked_credentials, 'privacy_policy': provider.privacy_policy - })) \ No newline at end of file + })) + + @staticmethod + def get_tool_icon(tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]: + """ + get the tool icon + + :param tenant_id: the id of the tenant + :param provider_type: the type of the provider + :param provider_id: the id of the provider + :return: + """ + provider_type = provider_type + provider_id = provider_id + if provider_type == 'builtin': + return (current_app.config.get("CONSOLE_API_URL") + + "/console/api/workspaces/current/tool-provider/builtin/" + + provider_id + + "/icon") + elif provider_type == 'api': + try: + provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.id == provider_id + ) + return json.loads(provider.icon) + except: + return { + "background": "#252525", + "content": "\ud83d\ude01" + } + else: + raise ValueError(f"provider type {provider_type} not found")