From b2b95412b9ab7c94fc9b9a36ebe4bc089133e503 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 13 Jun 2025 15:04:22 +0800 Subject: [PATCH] r2 --- .../entities/datasource_entities.py | 28 ++- .../online_document/online_document_plugin.py | 7 +- .../datasource/utils/message_transformer.py | 6 +- api/core/file/enums.py | 1 + api/core/plugin/impl/datasource.py | 30 ++- .../nodes/datasource/datasource_node.py | 180 ++++++++++++++++-- api/factories/file_factory.py | 47 +++++ 7 files changed, 253 insertions(+), 46 deletions(-) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index adcdcccf83..8d68c80c81 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -15,7 +15,7 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolLabelEnum +from core.tools.entities.tool_entities import ToolLabelEnum, ToolInvokeMessage class DatasourceProviderType(enum.StrEnum): @@ -301,3 +301,29 @@ class GetWebsiteCrawlResponse(BaseModel): """ result: WebSiteInfo = WebSiteInfo(job_id="", status="", web_info_list=[]) + + +class DatasourceInvokeMessage(ToolInvokeMessage): + """ + Datasource Invoke Message. + """ + + class WebsiteCrawlMessage(BaseModel): + """ + Website crawl message + """ + + job_id: str = Field(..., description="The job id") + status: str = Field(..., description="The status of the job") + web_info_list: Optional[list[WebSiteInfoDetail]] = [] + + class OnlineDocumentMessage(BaseModel): + """ + Online document message + """ + + workspace_id: str = Field(..., description="The workspace id") + workspace_name: str = Field(..., description="The workspace name") + workspace_icon: str = Field(..., description="The workspace icon") + total: int = Field(..., description="The total number of documents") + pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") \ No newline at end of file diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index f94031656e..2ab60cae1e 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,10 +1,11 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, Generator from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, + DatasourceInvokeMessage, DatasourceProviderType, GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, @@ -38,7 +39,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): user_id: str, datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> GetOnlineDocumentPagesResponse: + ) -> Generator[DatasourceInvokeMessage, None, None]: manager = PluginDatasourceManager() return manager.get_online_document_pages( @@ -56,7 +57,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): user_id: str, datasource_parameters: GetOnlineDocumentPageContentRequest, provider_type: str, - ) -> GetOnlineDocumentPageContentResponse: + ) -> Generator[DatasourceInvokeMessage, None, None]: manager = PluginDatasourceManager() return manager.get_online_document_page_content( diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index a10030d93b..bd99387e8d 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -39,7 +39,7 @@ class DatasourceFileMessageTransformer: conversation_id=conversation_id, ) - url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}" + url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}" yield DatasourceInvokeMessage( type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, @@ -77,7 +77,7 @@ class DatasourceFileMessageTransformer: filename=filename, ) - url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype)) + url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type)) # check if file is image if "image" in mimetype: @@ -98,7 +98,7 @@ class DatasourceFileMessageTransformer: if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: assert file.related_id is not None - url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) + url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension) if file.type == FileType.IMAGE: yield DatasourceInvokeMessage( type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, diff --git a/api/core/file/enums.py b/api/core/file/enums.py index a50a651dd3..170eb4fc23 100644 --- a/api/core/file/enums.py +++ b/api/core/file/enums.py @@ -20,6 +20,7 @@ class FileTransferMethod(StrEnum): REMOTE_URL = "remote_url" LOCAL_FILE = "local_file" TOOL_FILE = "tool_file" + DATASOURCE_FILE = "datasource_file" @staticmethod def value_of(value): diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 98ee0bb11e..83b1a5760b 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,7 +1,8 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, Generator from core.datasource.entities.datasource_entities import ( + DatasourceInvokeMessage, GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, GetOnlineDocumentPagesResponse, @@ -93,7 +94,7 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> GetWebsiteCrawlResponse: + ) -> Generator[DatasourceInvokeMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ @@ -103,7 +104,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl", - GetWebsiteCrawlResponse, + DatasourceInvokeMessage, data={ "user_id": user_id, "data": { @@ -118,10 +119,7 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: - return resp - - raise Exception("No response from plugin daemon") + yield from response def get_online_document_pages( self, @@ -132,7 +130,7 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> GetOnlineDocumentPagesResponse: + ) -> Generator[DatasourceInvokeMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ @@ -142,7 +140,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages", - GetOnlineDocumentPagesResponse, + DatasourceInvokeMessage, data={ "user_id": user_id, "data": { @@ -157,10 +155,7 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: - return resp - - raise Exception("No response from plugin daemon") + yield from response def get_online_document_page_content( self, @@ -171,7 +166,7 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], datasource_parameters: GetOnlineDocumentPageContentRequest, provider_type: str, - ) -> GetOnlineDocumentPageContentResponse: + ) -> Generator[DatasourceInvokeMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ @@ -181,7 +176,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content", - GetOnlineDocumentPageContentResponse, + DatasourceInvokeMessage, data={ "user_id": user_id, "data": { @@ -196,10 +191,7 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: - return resp - - raise Exception("No response from plugin daemon") + yield from response def validate_provider_credentials( self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any] diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 2782f2fb4c..bd4a6e3a56 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,13 +1,18 @@ from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Generator, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session from core.datasource.entities.datasource_entities import ( + DatasourceInvokeMessage, DatasourceParameter, DatasourceProviderType, GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer from core.file import File from core.file.enums import FileTransferMethod, FileType from core.plugin.impl.exc import PluginDaemonClientSideError @@ -19,8 +24,11 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.tool.exc import ToolFileError from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db +from factories import file_factory from models.model import UploadFile from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey @@ -36,7 +44,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): _node_data_cls = DatasourceNodeData _node_type = NodeType.DATASOURCE - def _run(self) -> NodeRunResult: + def _run(self) -> Generator: """ Run the datasource node """ @@ -65,13 +73,15 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): datasource_type=DatasourceProviderType.value_of(datasource_type), ) except DatasourceNodeError as e: - return NodeRunResult( + yield RunCompletedEvent( + run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, error=f"Failed to get datasource runtime: {str(e)}", error_type=type(e).__name__, ) + ) # get parameters datasource_parameters = datasource_runtime.entity.parameters @@ -91,25 +101,22 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: GetOnlineDocumentPageContentResponse = ( + online_document_result: Generator[DatasourceInvokeMessage, None, None] = ( datasource_runtime._get_online_document_page_content( user_id=self.user_id, datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), provider_type=datasource_type, ) ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - **online_document_result.result.model_dump(), - "datasource_type": datasource_type, - }, + yield from self._transform_message( + messages=online_document_result, + parameters_for_log=parameters_for_log, + datasource_info=datasource_info, ) + case DatasourceProviderType.WEBSITE_CRAWL: - return NodeRunResult( + yield RunCompletedEvent(run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, @@ -117,7 +124,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): **datasource_info, "datasource_type": datasource_type, }, - ) + )) case DatasourceProviderType.LOCAL_FILE: related_id = datasource_info.get("related_id") if not related_id: @@ -149,7 +156,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): variable_key_list=new_key_list, variable_value=value, ) - return NodeRunResult( + yield RunCompletedEvent(run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, @@ -157,25 +164,25 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): "file_info": datasource_info, "datasource_type": datasource_type, }, - ) + )) case _: raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}") except PluginDaemonClientSideError as e: - return NodeRunResult( + yield RunCompletedEvent(run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, error=f"Failed to transform datasource message: {str(e)}", error_type=type(e).__name__, - ) + )) except DatasourceNodeError as e: - return NodeRunResult( + yield RunCompletedEvent(run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, error=f"Failed to invoke datasource: {str(e)}", error_type=type(e).__name__, - ) + )) def _generate_parameters( self, @@ -279,3 +286,136 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): result = {node_id + "." + key: value for key, value in result.items()} return result + + + + def _transform_message( + self, + messages: Generator[DatasourceInvokeMessage, None, None], + parameters_for_log: dict[str, Any], + datasource_info: dict[str, Any], + ) -> Generator: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=messages, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json: list[dict] = [] + + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + DatasourceInvokeMessage.MessageType.IMAGE_LINK, + DatasourceInvokeMessage.MessageType.BINARY_LINK, + DatasourceInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.DATASOURCE_FILE) + else: + transfer_method = FileTransferMethod.DATASOURCE_FILE + + datasource_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(UploadFile).where(UploadFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"Tool file {datasource_file_id} does not exist") + + mapping = { + "datasource_file_id": datasource_file_id, + "type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + files.append(file) + elif message.type == DatasourceInvokeMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + assert message.meta + + datasource_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(UploadFile).where(UploadFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"datasource file {datasource_file_id} not exists") + + mapping = { + "datasource_file_id": datasource_file_id, + "transfer_method": FileTransferMethod.DATASOURCE_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + ) + elif message.type == DatasourceInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + text += message.message.text + yield RunStreamChunkEvent( + chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] + ) + elif message.type == DatasourceInvokeMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage) + if self.node_type == NodeType.AGENT: + msg_metadata = message.message.json_object.pop("execution_metadata", {}) + agent_execution_metadata = { + key: value + for key, value in msg_metadata.items() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() + } + json.append(message.message.json_object) + elif message.type == DatasourceInvokeMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) + elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield RunStreamChunkEvent( + chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] + ) + else: + variables[variable_name] = variable_value + elif message.type == DatasourceInvokeMessage.MessageType.FILE: + assert message.meta is not None + files.append(message.meta["file"]) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"json": json, "files": files, **variables, "text": text}, + metadata={ + WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info, + }, + inputs=parameters_for_log, + ) + ) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 52f119936f..128041a27d 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -60,6 +60,7 @@ def build_from_mapping( FileTransferMethod.LOCAL_FILE: _build_from_local_file, FileTransferMethod.REMOTE_URL: _build_from_remote_url, FileTransferMethod.TOOL_FILE: _build_from_tool_file, + FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, } build_func = build_functions.get(transfer_method) @@ -302,6 +303,52 @@ def _build_from_tool_file( ) +def _build_from_datasource_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, +) -> File: + datasource_file = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == mapping.get("datasource_file_id"), + UploadFile.tenant_id == tenant_id, + ) + .first() + ) + + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + + detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + + specified_type = mapping.get("type") + + if strict_type_validation and specified_type and detected_file_type.value != specified_type: + raise ValueError("Detected file type does not match the specified type. Please verify the file.") + + file_type = ( + FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type + ) + + return File( + id=mapping.get("id"), + tenant_id=tenant_id, + filename=datasource_file.name, + type=file_type, + transfer_method=transfer_method, + remote_url=datasource_file.source_url, + related_id=datasource_file.id, + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + ) + def _is_file_valid_with_config( *, input_file_type: str,