mirror of https://github.com/langgenius/dify.git
add icon return for tool node in workflow event stream
This commit is contained in:
parent
4235baf493
commit
0a0d9565ac
|
|
@ -196,6 +196,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
yield from self._generate_stream_outputs_when_node_started()
|
||||
|
||||
yield self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
|
|
|
|||
|
|
@ -162,6 +162,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
workflow_node_execution = self._handle_node_start(event)
|
||||
yield self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
|
|
|
|||
|
|
@ -229,6 +229,7 @@ class NodeStartStreamResponse(StreamResponse):
|
|||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Optional[dict] = None
|
||||
created_at: int
|
||||
extras: dict = {}
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||
workflow_run_id: str
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
|
|
@ -23,7 +23,9 @@ from core.app.entities.task_entities import (
|
|||
)
|
||||
from core.file.file_obj import FileVar
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
|
|
@ -321,15 +323,18 @@ class WorkflowCycleManage:
|
|||
)
|
||||
)
|
||||
|
||||
def _workflow_node_start_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
|
||||
def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution) \
|
||||
-> NodeStartStreamResponse:
|
||||
"""
|
||||
Workflow node start to stream response.
|
||||
:param event: queue node started event
|
||||
:param task_id: task id
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:return:
|
||||
"""
|
||||
return NodeStartStreamResponse(
|
||||
response = NodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
data=NodeStartStreamResponse.Data(
|
||||
|
|
@ -344,6 +349,17 @@ class WorkflowCycleManage:
|
|||
)
|
||||
)
|
||||
|
||||
# extras logic
|
||||
if event.node_type == NodeType.TOOL:
|
||||
node_data = cast(ToolNodeData, event.node_data)
|
||||
response.data.extras['icon'] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
|
||||
-> NodeFinishStreamResponse:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import mimetypes
|
|||
from os import listdir, path
|
||||
from typing import Any, Union
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
|
|
@ -43,15 +45,16 @@ logger = logging.getLogger(__name__)
|
|||
_builtin_providers = {}
|
||||
_builtin_tools_labels = {}
|
||||
|
||||
|
||||
class ToolManager:
|
||||
@staticmethod
|
||||
def invoke(
|
||||
provider: str,
|
||||
tool_id: str,
|
||||
tool_name: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
credentials: dict[str, Any],
|
||||
prompt_messages: list[PromptMessage],
|
||||
provider: str,
|
||||
tool_id: str,
|
||||
tool_name: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
credentials: dict[str, Any],
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> list[ToolInvokeMessage]:
|
||||
"""
|
||||
invoke the assistant
|
||||
|
|
@ -80,7 +83,7 @@ class ToolManager:
|
|||
provider_entity = provider_class()
|
||||
|
||||
return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_provider(provider: str) -> BuiltinToolProviderController:
|
||||
global _builtin_providers
|
||||
|
|
@ -96,9 +99,9 @@ class ToolManager:
|
|||
|
||||
if provider not in _builtin_providers:
|
||||
raise ToolProviderNotFoundError(f'builtin provider {provider} not found')
|
||||
|
||||
|
||||
return _builtin_providers[provider]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool:
|
||||
"""
|
||||
|
|
@ -113,10 +116,10 @@ class ToolManager:
|
|||
tool = provider_controller.get_tool(tool_name)
|
||||
|
||||
return tool
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
"""
|
||||
get the tool
|
||||
|
||||
|
|
@ -137,11 +140,11 @@ class ToolManager:
|
|||
raise NotImplementedError('app provider not implemented')
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str,
|
||||
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str,
|
||||
agent_callback: DifyAgentCallbackHandler = None) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
|
|
@ -170,7 +173,7 @@ class ToolManager:
|
|||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
|
||||
|
||||
|
||||
# decrypt the credentials
|
||||
credentials = builtin_provider.credentials
|
||||
controller = ToolManager.get_builtin_provider(provider_name)
|
||||
|
|
@ -183,11 +186,11 @@ class ToolManager:
|
|||
'credentials': decrypted_credentials,
|
||||
'runtime_parameters': {}
|
||||
}, agent_callback=agent_callback)
|
||||
|
||||
|
||||
elif provider_type == 'api':
|
||||
if tenant_id is None:
|
||||
raise ValueError('tenant id is required for api provider')
|
||||
|
||||
|
||||
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
|
||||
|
||||
# decrypt the credentials
|
||||
|
|
@ -232,7 +235,8 @@ class ToolManager:
|
|||
# check if tool_parameter_config in options
|
||||
options = list(map(lambda x: x.value, parameter_rule.options))
|
||||
if parameter_value not in options:
|
||||
raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
|
||||
raise ValueError(
|
||||
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
|
||||
|
||||
# convert tool parameter config to correct type
|
||||
try:
|
||||
|
|
@ -249,7 +253,8 @@ class ToolManager:
|
|||
parameter_value = int(parameter_value)
|
||||
elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
parameter_value = bool(parameter_value)
|
||||
elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
|
||||
elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT,
|
||||
ToolParameter.ToolParameterType.STRING]:
|
||||
parameter_value = str(parameter_value)
|
||||
elif parameter_rule.type == ToolParameter.ToolParameterType:
|
||||
parameter_value = str(parameter_value)
|
||||
|
|
@ -259,12 +264,14 @@ class ToolManager:
|
|||
return parameter_value
|
||||
|
||||
@staticmethod
|
||||
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
|
||||
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity,
|
||||
agent_callback: DifyAgentCallbackHandler) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
"""
|
||||
tool_entity = ToolManager.get_tool_runtime(
|
||||
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name,
|
||||
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id,
|
||||
tool_name=agent_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
agent_callback=agent_callback
|
||||
)
|
||||
|
|
@ -275,7 +282,7 @@ class ToolManager:
|
|||
# save tool parameter to tool entity memory
|
||||
value = ToolManager._init_runtime_parameter(parameter, agent_tool.tool_parameters)
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
|
||||
# decrypt runtime parameters
|
||||
encryption_manager = ToolParameterConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -335,11 +342,12 @@ class ToolManager:
|
|||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider)
|
||||
|
||||
absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', provider_controller.identity.icon)
|
||||
absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets',
|
||||
provider_controller.identity.icon)
|
||||
# check if the icon exists
|
||||
if not path.exists(absolute_path):
|
||||
raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')
|
||||
|
||||
|
||||
# get the mime type
|
||||
mime_type, _ = mimetypes.guess_type(absolute_path)
|
||||
mime_type = mime_type or 'application/octet-stream'
|
||||
|
|
@ -353,7 +361,7 @@ class ToolManager:
|
|||
# use cache first
|
||||
if len(_builtin_providers) > 0:
|
||||
return list(_builtin_providers.values())
|
||||
|
||||
|
||||
builtin_providers: list[BuiltinToolProviderController] = []
|
||||
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
|
||||
if provider.startswith('__'):
|
||||
|
|
@ -367,7 +375,7 @@ class ToolManager:
|
|||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=BuiltinToolProviderController)
|
||||
builtin_providers.append(provider_class())
|
||||
|
||||
|
|
@ -378,7 +386,7 @@ class ToolManager:
|
|||
_builtin_tools_labels[tool.identity.name] = tool.identity.label
|
||||
|
||||
return builtin_providers
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]:
|
||||
"""
|
||||
|
|
@ -403,7 +411,7 @@ class ToolManager:
|
|||
model_providers.append(ModelToolProviderController.from_db(configuration))
|
||||
|
||||
return model_providers
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController:
|
||||
"""
|
||||
|
|
@ -419,7 +427,7 @@ class ToolManager:
|
|||
configuration = configurations.get(provider_name)
|
||||
if configuration is None:
|
||||
raise ToolProviderNotFoundError(f'model provider {provider_name} not found')
|
||||
|
||||
|
||||
return ModelToolProviderController.from_db(configuration)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -438,13 +446,13 @@ class ToolManager:
|
|||
|
||||
if tool_name not in _builtin_tools_labels:
|
||||
return None
|
||||
|
||||
|
||||
return _builtin_tools_labels[tool_name]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def user_list_providers(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
) -> list[UserToolProvider]:
|
||||
result_providers: dict[str, UserToolProvider] = {}
|
||||
|
||||
|
|
@ -454,8 +462,9 @@ class ToolManager:
|
|||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
|
||||
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
find_db_builtin_provider = lambda provider: next((x for x in db_builtin_providers if x.provider == provider), None)
|
||||
|
||||
find_db_builtin_provider = lambda provider: next((x for x in db_builtin_providers if x.provider == provider),
|
||||
None)
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
|
|
@ -478,7 +487,7 @@ class ToolManager:
|
|||
# get db api providers
|
||||
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
|
||||
filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
|
||||
for db_api_provider in db_api_providers:
|
||||
provider_controller = ToolTransformService.api_provider_to_controller(
|
||||
db_provider=db_api_provider,
|
||||
|
|
@ -490,9 +499,10 @@ class ToolManager:
|
|||
result_providers[db_api_provider.name] = user_provider
|
||||
|
||||
return BuiltinToolProviderSort.sort(list(result_providers.values()))
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[ApiBasedToolProviderController, dict[str, Any]]:
|
||||
def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[
|
||||
ApiBasedToolProviderController, dict[str, Any]]:
|
||||
"""
|
||||
get the api provider
|
||||
|
||||
|
|
@ -507,14 +517,15 @@ class ToolManager:
|
|||
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
|
||||
|
||||
|
||||
controller = ApiBasedToolProviderController.from_db(
|
||||
provider, ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
||||
provider,
|
||||
ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
||||
)
|
||||
controller.load_bundled_tools(provider.tools)
|
||||
|
||||
return controller, provider.credentials
|
||||
|
||||
|
||||
@staticmethod
|
||||
def user_get_api_provider(provider: str, tenant_id: str) -> dict:
|
||||
"""
|
||||
|
|
@ -530,7 +541,7 @@ class ToolManager:
|
|||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider}')
|
||||
|
||||
|
||||
try:
|
||||
credentials = json.loads(provider.credentials_str) or {}
|
||||
except:
|
||||
|
|
@ -562,4 +573,36 @@ class ToolManager:
|
|||
'description': provider.description,
|
||||
'credentials': masked_credentials,
|
||||
'privacy_policy': provider.privacy_policy
|
||||
}))
|
||||
}))
|
||||
|
||||
@staticmethod
|
||||
def get_tool_icon(tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:
|
||||
"""
|
||||
get the tool icon
|
||||
|
||||
:param tenant_id: the id of the tenant
|
||||
:param provider_type: the type of the provider
|
||||
:param provider_id: the id of the provider
|
||||
:return:
|
||||
"""
|
||||
provider_type = provider_type
|
||||
provider_id = provider_id
|
||||
if provider_type == 'builtin':
|
||||
return (current_app.config.get("CONSOLE_API_URL")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
+ provider_id
|
||||
+ "/icon")
|
||||
elif provider_type == 'api':
|
||||
try:
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.id == provider_id
|
||||
)
|
||||
return json.loads(provider.icon)
|
||||
except:
|
||||
return {
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
|
|
|||
Loading…
Reference in New Issue