add icon return for tool node in workflow event stream

This commit is contained in:
takatost 2024-03-28 17:26:09 +08:00
parent 4235baf493
commit 0a0d9565ac
5 changed files with 108 additions and 46 deletions

View File

@ -196,6 +196,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)

View File

@ -162,6 +162,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._handle_node_start(event)
yield self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)

View File

@ -229,6 +229,7 @@ class NodeStartStreamResponse(StreamResponse):
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
created_at: int
extras: dict = {}
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str

View File

@ -1,7 +1,7 @@
import json
import time
from datetime import datetime
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
@ -23,7 +23,9 @@ from core.app.entities.task_entities import (
)
from core.file.file_obj import FileVar
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.account import Account
@ -321,15 +323,18 @@ class WorkflowCycleManage:
)
)
def _workflow_node_start_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution) \
-> NodeStartStreamResponse:
"""
Workflow node start to stream response.
:param event: queue node started event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
return NodeStartStreamResponse(
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeStartStreamResponse.Data(
@ -344,6 +349,17 @@ class WorkflowCycleManage:
)
)
# extras logic
if event.node_type == NodeType.TOOL:
node_data = cast(ToolNodeData, event.node_data)
response.data.extras['icon'] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id
)
return response
def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
-> NodeFinishStreamResponse:
"""

View File

@ -4,6 +4,8 @@ import mimetypes
from os import listdir, path
from typing import Any, Union
from flask import current_app
from core.agent.entities import AgentToolEntity
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.model_runtime.entities.message_entities import PromptMessage
@ -43,15 +45,16 @@ logger = logging.getLogger(__name__)
_builtin_providers = {}
_builtin_tools_labels = {}
class ToolManager:
@staticmethod
def invoke(
provider: str,
tool_id: str,
tool_name: str,
tool_parameters: dict[str, Any],
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
provider: str,
tool_id: str,
tool_name: str,
tool_parameters: dict[str, Any],
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
) -> list[ToolInvokeMessage]:
"""
invoke the assistant
@ -80,7 +83,7 @@ class ToolManager:
provider_entity = provider_class()
return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
@staticmethod
def get_builtin_provider(provider: str) -> BuiltinToolProviderController:
global _builtin_providers
@ -96,9 +99,9 @@ class ToolManager:
if provider not in _builtin_providers:
raise ToolProviderNotFoundError(f'builtin provider {provider} not found')
return _builtin_providers[provider]
@staticmethod
def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool:
"""
@ -113,10 +116,10 @@ class ToolManager:
tool = provider_controller.get_tool(tool_name)
return tool
@staticmethod
def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
-> Union[BuiltinTool, ApiTool]:
-> Union[BuiltinTool, ApiTool]:
"""
get the tool
@ -137,11 +140,11 @@ class ToolManager:
raise NotImplementedError('app provider not implemented')
else:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str,
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str,
agent_callback: DifyAgentCallbackHandler = None) \
-> Union[BuiltinTool, ApiTool]:
-> Union[BuiltinTool, ApiTool]:
"""
get the tool runtime
@ -170,7 +173,7 @@ class ToolManager:
if builtin_provider is None:
raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
# decrypt the credentials
credentials = builtin_provider.credentials
controller = ToolManager.get_builtin_provider(provider_name)
@ -183,11 +186,11 @@ class ToolManager:
'credentials': decrypted_credentials,
'runtime_parameters': {}
}, agent_callback=agent_callback)
elif provider_type == 'api':
if tenant_id is None:
raise ValueError('tenant id is required for api provider')
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
# decrypt the credentials
@ -232,7 +235,8 @@ class ToolManager:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter_rule.options))
if parameter_value not in options:
raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
raise ValueError(
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
# convert tool parameter config to correct type
try:
@ -249,7 +253,8 @@ class ToolManager:
parameter_value = int(parameter_value)
elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN:
parameter_value = bool(parameter_value)
elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT,
ToolParameter.ToolParameterType.STRING]:
parameter_value = str(parameter_value)
elif parameter_rule.type == ToolParameter.ToolParameterType:
parameter_value = str(parameter_value)
@ -259,12 +264,14 @@ class ToolManager:
return parameter_value
@staticmethod
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity,
agent_callback: DifyAgentCallbackHandler) -> Tool:
"""
get the agent tool runtime
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name,
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id,
tool_name=agent_tool.tool_name,
tenant_id=tenant_id,
agent_callback=agent_callback
)
@ -275,7 +282,7 @@ class ToolManager:
# save tool parameter to tool entity memory
value = ToolManager._init_runtime_parameter(parameter, agent_tool.tool_parameters)
runtime_parameters[parameter.name] = value
# decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id,
@ -335,11 +342,12 @@ class ToolManager:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider)
absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', provider_controller.identity.icon)
absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets',
provider_controller.identity.icon)
# check if the icon exists
if not path.exists(absolute_path):
raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')
# get the mime type
mime_type, _ = mimetypes.guess_type(absolute_path)
mime_type = mime_type or 'application/octet-stream'
@ -353,7 +361,7 @@ class ToolManager:
# use cache first
if len(_builtin_providers) > 0:
return list(_builtin_providers.values())
builtin_providers: list[BuiltinToolProviderController] = []
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
if provider.startswith('__'):
@ -367,7 +375,7 @@ class ToolManager:
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController)
builtin_providers.append(provider_class())
@ -378,7 +386,7 @@ class ToolManager:
_builtin_tools_labels[tool.identity.name] = tool.identity.label
return builtin_providers
@staticmethod
def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]:
"""
@ -403,7 +411,7 @@ class ToolManager:
model_providers.append(ModelToolProviderController.from_db(configuration))
return model_providers
@staticmethod
def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController:
"""
@ -419,7 +427,7 @@ class ToolManager:
configuration = configurations.get(provider_name)
if configuration is None:
raise ToolProviderNotFoundError(f'model provider {provider_name} not found')
return ModelToolProviderController.from_db(configuration)
@staticmethod
@ -438,13 +446,13 @@ class ToolManager:
if tool_name not in _builtin_tools_labels:
return None
return _builtin_tools_labels[tool_name]
@staticmethod
def user_list_providers(
user_id: str,
tenant_id: str,
user_id: str,
tenant_id: str,
) -> list[UserToolProvider]:
result_providers: dict[str, UserToolProvider] = {}
@ -454,8 +462,9 @@ class ToolManager:
# get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
find_db_builtin_provider = lambda provider: next((x for x in db_builtin_providers if x.provider == provider), None)
find_db_builtin_provider = lambda provider: next((x for x in db_builtin_providers if x.provider == provider),
None)
# append builtin providers
for provider in builtin_providers:
@ -478,7 +487,7 @@ class ToolManager:
# get db api providers
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
filter(ApiToolProvider.tenant_id == tenant_id).all()
for db_api_provider in db_api_providers:
provider_controller = ToolTransformService.api_provider_to_controller(
db_provider=db_api_provider,
@ -490,9 +499,10 @@ class ToolManager:
result_providers[db_api_provider.name] = user_provider
return BuiltinToolProviderSort.sort(list(result_providers.values()))
@staticmethod
def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[ApiBasedToolProviderController, dict[str, Any]]:
def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[
ApiBasedToolProviderController, dict[str, Any]]:
"""
get the api provider
@ -507,14 +517,15 @@ class ToolManager:
if provider is None:
raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
controller = ApiBasedToolProviderController.from_db(
provider, ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
provider,
ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
)
controller.load_bundled_tools(provider.tools)
return controller, provider.credentials
@staticmethod
def user_get_api_provider(provider: str, tenant_id: str) -> dict:
"""
@ -530,7 +541,7 @@ class ToolManager:
if provider is None:
raise ValueError(f'you have not added provider {provider}')
try:
credentials = json.loads(provider.credentials_str) or {}
except:
@ -562,4 +573,36 @@ class ToolManager:
'description': provider.description,
'credentials': masked_credentials,
'privacy_policy': provider.privacy_policy
}))
}))
@staticmethod
def get_tool_icon(tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:
"""
get the tool icon
:param tenant_id: the id of the tenant
:param provider_type: the type of the provider
:param provider_id: the id of the provider
:return:
"""
provider_type = provider_type
provider_id = provider_id
if provider_type == 'builtin':
return (current_app.config.get("CONSOLE_API_URL")
+ "/console/api/workspaces/current/tool-provider/builtin/"
+ provider_id
+ "/icon")
elif provider_type == 'api':
try:
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.id == provider_id
)
return json.loads(provider.icon)
except:
return {
"background": "#252525",
"content": "\ud83d\ude01"
}
else:
raise ValueError(f"provider type {provider_type} not found")