From ec1c4efca94d06fb25364a4ab53b9e9dbf954779 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Sun, 25 May 2025 23:09:01 +0800 Subject: [PATCH] r2 --- .../datasource/__base/datasource_plugin.py | 2 +- .../datasource/__base/datasource_provider.py | 4 +- .../local_file/local_file_plugin.py | 2 +- .../online_document/online_document_plugin.py | 2 +- .../website_crawl/website_crawl_plugin.py | 2 +- .../website_crawl/website_crawl_provider.py | 2 +- api/core/plugin/impl/datasource.py | 69 +++++++--- .../workflow/nodes/datasource/__init__.py | 2 +- .../nodes/datasource/datasource_node.py | 130 ++++++++++-------- .../workflow/nodes/datasource/entities.py | 27 +--- .../knowledge_index/knowledge_index_node.py | 5 +- api/core/workflow/nodes/node_mapping.py | 10 ++ 12 files changed, 147 insertions(+), 110 deletions(-) diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index d8681b6491..5a13d17843 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -20,7 +20,7 @@ class DatasourcePlugin(ABC): self.runtime = runtime @abstractmethod - def datasource_provider_type(self) -> DatasourceProviderType: + def datasource_provider_type(self) -> str: """ returns the type of the datasource provider """ diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index bae39dc8c7..045ca64872 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -9,10 +9,10 @@ from core.tools.errors import ToolProviderCredentialValidationError class DatasourcePluginProviderController(ABC): - entity: DatasourceProviderEntityWithPlugin + entity: DatasourceProviderEntityWithPlugin | None tenant_id: str - def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None: + def __init__(self, entity: DatasourceProviderEntityWithPlugin | None, tenant_id: str) -> None: self.entity = entity self.tenant_id = tenant_id diff --git a/api/core/datasource/local_file/local_file_plugin.py b/api/core/datasource/local_file/local_file_plugin.py index 45f4777f44..82da10d663 100644 --- a/api/core/datasource/local_file/local_file_plugin.py +++ b/api/core/datasource/local_file/local_file_plugin.py @@ -24,5 +24,5 @@ class LocalFileDatasourcePlugin(DatasourcePlugin): self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - def datasource_provider_type(self) -> DatasourceProviderType: + def datasource_provider_type(self) -> str: return DatasourceProviderType.LOCAL_FILE diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index b809ad6abf..f94031656e 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -69,5 +69,5 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): provider_type=provider_type, ) - def datasource_provider_type(self) -> DatasourceProviderType: + def datasource_provider_type(self) -> str: return DatasourceProviderType.ONLINE_DOCUMENT diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index b1b6489197..e8256b3282 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -49,5 +49,5 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): provider_type=provider_type, ) - def datasource_provider_type(self) -> DatasourceProviderType: + def datasource_provider_type(self) -> str: return DatasourceProviderType.WEBSITE_CRAWL diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index 95f05fcee0..d9043702d2 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -10,7 +10,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon plugin_unique_identifier: str def __init__( - self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + self, entity: DatasourceProviderEntityWithPlugin | None, plugin_id: str, plugin_unique_identifier: str, tenant_id: str ) -> None: super().__init__(entity, tenant_id) self.plugin_id = plugin_id diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index aa8b1ad4d6..645e067e4c 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -6,7 +6,7 @@ from core.datasource.entities.datasource_entities import ( GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, GetOnlineDocumentPagesResponse, - GetWebsiteCrawlResponse, + GetWebsiteCrawlResponse, DatasourceProviderEntity, ) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import ( @@ -17,7 +17,7 @@ from core.plugin.impl.base import BasePluginClient class PluginDatasourceManager(BasePluginClient): - def fetch_datasource_providers(self, tenant_id: str) -> list[DatasourceProviderApiEntity]: + def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]: """ Fetch datasource providers for the given tenant. """ @@ -46,12 +46,15 @@ class PluginDatasourceManager(BasePluginClient): # for datasource in provider.declaration.datasources: # datasource.identity.provider = provider.declaration.identity.name - return [DatasourceProviderApiEntity(**self._get_local_file_datasource_provider())] + return [PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())] def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: """ Fetch datasource provider for the given tenant and plugin. """ + if provider == "langgenius/file/file": + return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) + tool_provider_id = ToolProviderID(provider) def transformer(json_response: dict[str, Any]) -> dict: @@ -218,6 +221,7 @@ class PluginDatasourceManager(BasePluginClient): "X-Plugin-ID": tool_provider_id.plugin_id, "Content-Type": "application/json", }, + ) for resp in response: @@ -228,27 +232,48 @@ class PluginDatasourceManager(BasePluginClient): def _get_local_file_datasource_provider(self) -> dict[str, Any]: return { "id": "langgenius/file/file", - "author": "langgenius", - "name": "langgenius/file/file", "plugin_id": "langgenius/file", + "provider": "langgenius", "plugin_unique_identifier": "langgenius/file:0.0.1@dify", - "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, - "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", - "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, - "type": "datasource", - "team_credentials": {}, - "is_team_authorization": False, - "allow_delete": True, - "datasources": [ - { + "declaration": { + "identity": { "author": "langgenius", - "name": "upload_file", - "label": {"en_US": "File", "zh_Hans": "File", "pt_BR": "File", "ja_JP": "File"}, - "description": {"en_US": "File", "zh_Hans": "File", "pt_BR": "File", "ja_JP": "File."}, + "name": "langgenius/file/file", + "label": { + "zh_Hans": "File", + "en_US": "File", + "pt_BR": "File", + "ja_JP": "File" + }, + "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", + "description": { + "zh_Hans": "File", + "en_US": "File", + "pt_BR": "File", + "ja_JP": "File" + } + }, + "credentials_schema": [], + "provider_type": "local_file", + "datasources": [{ + "identity": { + "author": "langgenius", + "name": "local_file", + "provider": "langgenius", + "label": { + "zh_Hans": "File", + "en_US": "File", + "pt_BR": "File", + "ja_JP": "File" + } + }, "parameters": [], - "labels": ["search"], - "output_schema": None, - } - ], - "labels": ["search"], + "description": { + "zh_Hans": "File", + "en_US": "File", + "pt_BR": "File", + "ja_JP": "File" + } + }] + } } diff --git a/api/core/workflow/nodes/datasource/__init__.py b/api/core/workflow/nodes/datasource/__init__.py index cee9e5a895..f6ec44cb77 100644 --- a/api/core/workflow/nodes/datasource/__init__.py +++ b/api/core/workflow/nodes/datasource/__init__.py @@ -1,3 +1,3 @@ -from .tool_node import ToolNode +from .datasource_node import DatasourceNode __all__ = ["DatasourceNode"] diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index f5e34f5998..198e167341 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -40,14 +40,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): node_data = cast(DatasourceNodeData, self.node_data) variable_pool = self.graph_runtime_state.variable_pool - + datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value]) + if not datasource_type: + raise DatasourceNodeError("Datasource type is not set") + datasource_type = datasource_type.value + datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value]) + if not datasource_info: + raise DatasourceNodeError("Datasource info is not set") + datasource_info = datasource_info.value # get datasource runtime try: from core.datasource.datasource_manager import DatasourceManager - datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value]) - datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value]) if datasource_type is None: raise DatasourceNodeError("Datasource type is not set") @@ -84,47 +89,55 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): ) try: - if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: - datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: GetOnlineDocumentPageContentResponse = ( - datasource_runtime._get_online_document_page_content( - user_id=self.user_id, - datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), - provider_type=datasource_runtime.datasource_provider_type(), + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: GetOnlineDocumentPageContentResponse = ( + datasource_runtime._get_online_document_page_content( + user_id=self.user_id, + datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), + provider_type=datasource_type, + ) ) - ) - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - "online_document": online_document_result.result.model_dump(), - "datasource_type": datasource_runtime.datasource_provider_type, - }, + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "online_document": online_document_result.result.model_dump(), + "datasource_type": datasource_type, + }, + ) ) - ) - elif ( - datasource_runtime.datasource_provider_type in ( - DatasourceProviderType.WEBSITE_CRAWL, - DatasourceProviderType.LOCAL_FILE, - ) - ): - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - "website": datasource_info, - "datasource_type": datasource_runtime.datasource_provider_type, - }, + case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "website": datasource_info, + "datasource_type": datasource_type, + }, + ) + ) + case DatasourceProviderType.LOCAL_FILE: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "file": datasource_info, + "datasource_type": datasource_runtime.datasource_provider_type, + }, + ) + ) + case _: + raise DatasourceNodeError( + f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}" ) - ) - else: - raise DatasourceNodeError( - f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}" - ) except PluginDaemonClientSideError as e: yield RunCompletedEvent( run_result=NodeRunResult( @@ -170,23 +183,24 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters} result: dict[str, Any] = {} - for parameter_name in node_data.datasource_parameters: - parameter = datasource_parameters_dictionary.get(parameter_name) - if not parameter: - result[parameter_name] = None - continue - datasource_input = node_data.datasource_parameters[parameter_name] - if datasource_input.type == "variable": - variable = variable_pool.get(datasource_input.value) - if variable is None: - raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist") - parameter_value = variable.value - elif datasource_input.type in {"mixed", "constant"}: - segment_group = variable_pool.convert_template(str(datasource_input.value)) - parameter_value = segment_group.log if for_log else segment_group.text - else: - raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'") - result[parameter_name] = parameter_value + if node_data.datasource_parameters: + for parameter_name in node_data.datasource_parameters: + parameter = datasource_parameters_dictionary.get(parameter_name) + if not parameter: + result[parameter_name] = None + continue + datasource_input = node_data.datasource_parameters[parameter_name] + if datasource_input.type == "variable": + variable = variable_pool.get(datasource_input.value) + if variable is None: + raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist") + parameter_value = variable.value + elif datasource_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(datasource_input.value)) + parameter_value = segment_group.log if for_log else segment_group.text + else: + raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'") + result[parameter_name] = parameter_value return result diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 68aa9fa34c..212184bb81 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Union +from typing import Any, Literal, Union, Optional from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo @@ -9,30 +9,17 @@ from core.workflow.nodes.base.entities import BaseNodeData class DatasourceEntity(BaseModel): provider_id: str provider_name: str # redundancy - datasource_name: str - tool_label: str # redundancy - datasource_configurations: dict[str, Any] + provider_type: str + datasource_name: Optional[str] = "local_file" + datasource_configurations: dict[str, Any] | None = None plugin_unique_identifier: str | None = None # redundancy - @field_validator("tool_configurations", mode="before") - @classmethod - def validate_tool_configurations(cls, value, values: ValidationInfo): - if not isinstance(value, dict): - raise ValueError("tool_configurations must be a dictionary") - - for key in values.data.get("tool_configurations", {}): - value = values.data.get("tool_configurations", {}).get(key) - if not isinstance(value, str | int | float | bool): - raise ValueError(f"{key} must be a string") - - return value - class DatasourceNodeData(BaseNodeData, DatasourceEntity): class DatasourceInput(BaseModel): # TODO: check this type - value: Union[Any, list[str]] - type: Literal["mixed", "variable", "constant"] + value: Optional[Union[Any, list[str]]] = None + type: Optional[Literal["mixed", "variable", "constant"]] = None @field_validator("type", mode="before") @classmethod @@ -51,4 +38,4 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity): raise ValueError("value must be a string, int, float, or bool") return typ - datasource_parameters: dict[str, DatasourceInput] + datasource_parameters: dict[str, DatasourceInput] | None = None diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index dac541621a..803ecc765f 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -19,6 +19,7 @@ from .entities import KnowledgeIndexNodeData from .exc import ( KnowledgeIndexNodeError, ) +from ..base import BaseNode logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ default_retrieval_model = { } -class KnowledgeIndexNode(LLMNode): +class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): _node_data_cls = KnowledgeIndexNodeData # type: ignore _node_type = NodeType.KNOWLEDGE_INDEX @@ -44,7 +45,7 @@ class KnowledgeIndexNode(LLMNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, - error="Query variable is not object type.", + error="Index chunk variable is not object type.", ) chunks = variable.value variables = {"chunks": chunks} diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index 1f1be59542..e328c20096 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -4,12 +4,14 @@ from core.workflow.nodes.agent.agent_node import AgentNode from core.workflow.nodes.answer import AnswerNode from core.workflow.nodes.base import BaseNode from core.workflow.nodes.code import CodeNode +from core.workflow.nodes.datasource.datasource_node import DatasourceNode from core.workflow.nodes.document_extractor import DocumentExtractorNode from core.workflow.nodes.end import EndNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.http_request import HttpRequestNode from core.workflow.nodes.if_else import IfElseNode from core.workflow.nodes.iteration import IterationNode, IterationStartNode +from core.workflow.nodes.knowledge_index import KnowledgeIndexNode from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode from core.workflow.nodes.list_operator import ListOperatorNode from core.workflow.nodes.llm import LLMNode @@ -119,4 +121,12 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { LATEST_VERSION: AgentNode, "1": AgentNode, }, + NodeType.DATASOURCE: { + LATEST_VERSION: DatasourceNode, + "1": DatasourceNode, + }, + NodeType.KNOWLEDGE_INDEX: { + LATEST_VERSION: KnowledgeIndexNode, + "1": KnowledgeIndexNode, + }, }