diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index 5a13d17843..50c7249fe4 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from configs import dify_config from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, @@ -10,14 +11,17 @@ from core.datasource.entities.datasource_entities import ( class DatasourcePlugin(ABC): entity: DatasourceEntity runtime: DatasourceRuntime + icon: str def __init__( self, entity: DatasourceEntity, runtime: DatasourceRuntime, + icon: str, ) -> None: self.entity = entity self.runtime = runtime + self.icon = icon @abstractmethod def datasource_provider_type(self) -> str: @@ -30,4 +34,8 @@ class DatasourcePlugin(ABC): return self.__class__( entity=self.entity.model_copy(), runtime=runtime, + icon=self.icon, ) + + def get_icon_url(self, tenant_id: str) -> str: + return f"{dify_config.CONSOLE_API_URL}/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={self.icon}" # noqa: E501 diff --git a/api/core/datasource/local_file/local_file_plugin.py b/api/core/datasource/local_file/local_file_plugin.py index 82da10d663..070a89cb2f 100644 --- a/api/core/datasource/local_file/local_file_plugin.py +++ b/api/core/datasource/local_file/local_file_plugin.py @@ -8,7 +8,6 @@ from core.datasource.entities.datasource_entities import ( class LocalFileDatasourcePlugin(DatasourcePlugin): tenant_id: str - icon: str plugin_unique_identifier: str def __init__( @@ -19,10 +18,12 @@ class LocalFileDatasourcePlugin(DatasourcePlugin): icon: str, plugin_unique_identifier: str, ) -> None: - super().__init__(entity, runtime) + super().__init__(entity, runtime, icon) self.tenant_id = tenant_id - self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier def datasource_provider_type(self) -> str: return DatasourceProviderType.LOCAL_FILE + + def get_icon_url(self, tenant_id: str) -> str: + return self.icon diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index c1e015fd3a..98ea15e3fc 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -15,7 +15,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager class OnlineDocumentDatasourcePlugin(DatasourcePlugin): tenant_id: str - icon: str plugin_unique_identifier: str entity: DatasourceEntity runtime: DatasourceRuntime @@ -28,9 +27,8 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): icon: str, plugin_unique_identifier: str, ) -> None: - super().__init__(entity, runtime) + super().__init__(entity, runtime, icon) self.tenant_id = tenant_id - self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier def get_online_document_pages( diff --git a/api/core/datasource/online_drive/online_drive_plugin.py b/api/core/datasource/online_drive/online_drive_plugin.py index f0e3cb38f9..64715226cc 100644 --- a/api/core/datasource/online_drive/online_drive_plugin.py +++ b/api/core/datasource/online_drive/online_drive_plugin.py @@ -15,7 +15,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager class OnlineDriveDatasourcePlugin(DatasourcePlugin): tenant_id: str - icon: str plugin_unique_identifier: str entity: DatasourceEntity runtime: DatasourceRuntime @@ -28,9 +27,8 @@ class OnlineDriveDatasourcePlugin(DatasourcePlugin): icon: str, plugin_unique_identifier: str, ) -> None: - super().__init__(entity, runtime) + super().__init__(entity, runtime, icon) self.tenant_id = tenant_id - self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier def online_drive_browse_files( diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index d0e442f31a..087ac65a7a 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -13,7 +13,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): tenant_id: str - icon: str plugin_unique_identifier: str entity: DatasourceEntity runtime: DatasourceRuntime @@ -26,9 +25,8 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): icon: str, plugin_unique_identifier: str, ) -> None: - super().__init__(entity, runtime) + super().__init__(entity, runtime, icon) self.tenant_id = tenant_id - self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier def get_website_crawl( diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 2b3741c543..5dc89c23d7 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -75,14 +75,17 @@ class DatasourceNode(Node): node_data = self._node_data variable_pool = self.graph_runtime_state.variable_pool - datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value]) - if not datasource_type: + datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value]) + if not datasource_type_segement: raise DatasourceNodeError("Datasource type is not set") - datasource_type = datasource_type.value - datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value]) - if not datasource_info: + datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None + datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value]) + if not datasource_info_segement: raise DatasourceNodeError("Datasource info is not set") - datasource_info = datasource_info.value + datasource_info_value = datasource_info_segement.value + if not isinstance(datasource_info_value, dict): + raise DatasourceNodeError("Invalid datasource info format") + datasource_info: dict[str, Any] = datasource_info_value # get datasource runtime try: from core.datasource.datasource_manager import DatasourceManager @@ -96,6 +99,7 @@ class DatasourceNode(Node): tenant_id=self.tenant_id, datasource_type=DatasourceProviderType.value_of(datasource_type), ) + datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id) except DatasourceNodeError as e: yield StreamCompletedEvent( node_run_result=NodeRunResult( @@ -123,7 +127,7 @@ class DatasourceNode(Node): tenant_id=self.tenant_id, provider=node_data.provider_name, plugin_id=node_data.plugin_id, - credential_id=datasource_info.get("credential_id"), + credential_id=datasource_info.get("credential_id", ""), ) match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: @@ -134,9 +138,9 @@ class DatasourceNode(Node): datasource_runtime.get_online_document_page_content( user_id=self.user_id, datasource_parameters=GetOnlineDocumentPageContentRequest( - workspace_id=datasource_info.get("workspace_id"), - page_id=datasource_info.get("page").get("page_id"), - type=datasource_info.get("page").get("type"), + workspace_id=datasource_info.get("workspace_id", ""), + page_id=datasource_info.get("page", {}).get("page_id", ""), + type=datasource_info.get("page", {}).get("type", ""), ), provider_type=datasource_type, ) @@ -154,7 +158,7 @@ class DatasourceNode(Node): datasource_runtime.online_drive_download_file( user_id=self.user_id, request=OnlineDriveDownloadFileRequest( - id=datasource_info.get("id"), + id=datasource_info.get("id", ""), bucket=datasource_info.get("bucket"), ), provider_type=datasource_type, diff --git a/api/models/workflow.py b/api/models/workflow.py index 10f77a1914..4e6e57e916 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -834,7 +834,9 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo provider_type=tool_info["provider_type"], provider_id=tool_info["provider_id"], ) - + elif self.node_type == NodeType.DATASOURCE.value and "datasource_info" in self.execution_metadata_dict: + datasource_info = self.execution_metadata_dict["datasource_info"] + extras["icon"] = datasource_info["icon"] return extras def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index 4ed0f6d0a4..14c2e429c3 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -7,7 +7,6 @@ from configs import dify_config from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType -from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval logger = logging.getLogger(__name__)