diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
index 912981db01..d2136f771b 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
@@ -41,7 +41,8 @@ class DatasourcePluginOauthApi(Resource):
         if not plugin_oauth_config:
             raise NotFound()
         oauth_handler = OAuthHandler()
-        redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
+        redirect_url = (f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?"
+                        f"provider={provider}&plugin_id={plugin_id}")
         system_credentials = plugin_oauth_config.system_credentials
         if system_credentials:
             system_credentials["redirect_url"] = redirect_url
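Note that the wrapped f-string above still interpolates `provider` and `plugin_id` into the query string verbatim. A variant built on `urllib.parse.urlencode` would also survive values containing reserved characters such as the `/` inside plugin IDs; a minimal sketch with a hypothetical helper name, not part of this change:

```python
from urllib.parse import urlencode


def build_datasource_callback_url(console_web_url: str, provider: str, plugin_id: str) -> str:
    # urlencode percent-escapes reserved characters (e.g. "/" in "org/plugin" IDs).
    query = urlencode({"provider": provider, "plugin_id": plugin_id})
    return f"{console_web_url}/oauth/datasource/callback?{query}"
```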
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index c97b3b1d92..616803247c 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -8,7 +8,6 @@ from flask_restful.inputs import int_range  # type: ignore
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
-from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 import services
 from configs import dify_config
 from controllers.console import api
@@ -24,6 +23,7 @@ from controllers.console.wraps import (
 )
 from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
 from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.model_runtime.utils.encoders import jsonable_encoder
 from extensions.ext_database import db
@@ -302,87 +302,87 @@ class PublishedRagPipelineRunApi(Resource):
         raise InvokeRateLimitHttpError(ex.description)
 
 
-class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_rag_pipeline
-    def post(self, pipeline: Pipeline, node_id: str):
-        """
-        Run rag pipeline datasource
-        """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("datasource_type", type=str, required=True, location="json")
-        args = parser.parse_args()
-
-        job_id = args.get("job_id")
-        if job_id == None:
-            raise ValueError("missing job_id")
-        datasource_type = args.get("datasource_type")
-        if datasource_type == None:
-            raise ValueError("missing datasource_type")
-
-        rag_pipeline_service = RagPipelineService()
-        result = rag_pipeline_service.run_datasource_workflow_node_status(
-            pipeline=pipeline,
-            node_id=node_id,
-            job_id=job_id,
-            account=current_user,
-            datasource_type=datasource_type,
-            is_published=True
-        )
-
-        return result
+# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
+#     @setup_required
+#     @login_required
+#     @account_initialization_required
+#     @get_rag_pipeline
+#     def post(self, pipeline: Pipeline, node_id: str):
+#         """
+#         Run rag pipeline datasource
+#         """
+#         # The role of the current user in the ta table must be admin, owner, or editor
+#         if not current_user.is_editor:
+#             raise Forbidden()
+#
+#         if not isinstance(current_user, Account):
+#             raise Forbidden()
+#
+#         parser = reqparse.RequestParser()
+#         parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
+#         parser.add_argument("datasource_type", type=str, required=True, location="json")
+#         args = parser.parse_args()
+#
+#         job_id = args.get("job_id")
+#         if job_id == None:
+#             raise ValueError("missing job_id")
+#         datasource_type = args.get("datasource_type")
+#         if datasource_type == None:
+#             raise ValueError("missing datasource_type")
+#
+#         rag_pipeline_service = RagPipelineService()
+#         result = rag_pipeline_service.run_datasource_workflow_node_status(
+#             pipeline=pipeline,
+#             node_id=node_id,
+#             job_id=job_id,
+#             account=current_user,
+#             datasource_type=datasource_type,
+#             is_published=True
+#         )
+#
+#         return result
 
 
-class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_rag_pipeline
-    def post(self, pipeline: Pipeline, node_id: str):
-        """
-        Run rag pipeline datasource
-        """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("datasource_type", type=str, required=True, location="json")
-        args = parser.parse_args()
-
-        job_id = args.get("job_id")
-        if job_id == None:
-            raise ValueError("missing job_id")
-        datasource_type = args.get("datasource_type")
-        if datasource_type == None:
-            raise ValueError("missing datasource_type")
-
-        rag_pipeline_service = RagPipelineService()
-        result = rag_pipeline_service.run_datasource_workflow_node_status(
-            pipeline=pipeline,
-            node_id=node_id,
-            job_id=job_id,
-            account=current_user,
-            datasource_type=datasource_type,
-            is_published=False
-        )
-
-        return result
-
+# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
+#     @setup_required
+#     @login_required
+#     @account_initialization_required
+#     @get_rag_pipeline
+#     def post(self, pipeline: Pipeline, node_id: str):
+#         """
+#         Run rag pipeline datasource
+#         """
+#         # The role of the current user in the ta table must be admin, owner, or editor
+#         if not current_user.is_editor:
+#             raise Forbidden()
+#
+#         if not isinstance(current_user, Account):
+#             raise Forbidden()
+#
+#         parser = reqparse.RequestParser()
+#         parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
+#         parser.add_argument("datasource_type", type=str, required=True, location="json")
+#         args = parser.parse_args()
+#
+#         job_id = args.get("job_id")
+#         if job_id == None:
+#             raise ValueError("missing job_id")
+#         datasource_type = args.get("datasource_type")
+#         if datasource_type == None:
+#             raise ValueError("missing datasource_type")
+#
+#         rag_pipeline_service = RagPipelineService()
+#         result = rag_pipeline_service.run_datasource_workflow_node_status(
+#             pipeline=pipeline,
+#             node_id=node_id,
+#             job_id=job_id,
+#             account=current_user,
+#             datasource_type=datasource_type,
+#             is_published=False
+#         )
+#
+#         return result
+#
 
 
 class RagPipelinePublishedDatasourceNodeRunApi(Resource):
     @setup_required
@@ -425,7 +425,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
         return result
 
 
-class RagPipelineDrafDatasourceNodeRunApi(Resource):
+class RagPipelineDraftDatasourceNodeRunApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -447,22 +447,28 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource):
         args = parser.parse_args()
 
         inputs = args.get("inputs")
-        if inputs == None:
+        if inputs is None:
             raise ValueError("missing inputs")
         datasource_type = args.get("datasource_type")
-        if datasource_type == None:
+        if datasource_type is None:
             raise ValueError("missing datasource_type")
 
         rag_pipeline_service = RagPipelineService()
-        return helper.compact_generate_response(rag_pipeline_service.run_datasource_workflow_node(
-            pipeline=pipeline,
-            node_id=node_id,
-            user_inputs=inputs,
-            account=current_user,
-            datasource_type=datasource_type,
-            is_published=False
-        )
-        )
+        try:
+            return helper.compact_generate_response(
+                PipelineGenerator.convert_to_event_stream(
+                    rag_pipeline_service.run_datasource_workflow_node(
+                        pipeline=pipeline,
+                        node_id=node_id,
+                        user_inputs=inputs,
+                        account=current_user,
+                        datasource_type=datasource_type,
+                        is_published=False
+                    )
+                )
+            )
+        except Exception as e:
+            raise InternalServerError(str(e))
 
 
 class RagPipelinePublishedNodeRunApi(Resource):
@@ -981,17 +987,17 @@ api.add_resource(
     RagPipelinePublishedDatasourceNodeRunApi,
     "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run",
 )
-api.add_resource(
-    RagPipelinePublishedDatasourceNodeRunStatusApi,
-    "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run-status",
-)
-api.add_resource(
-    RagPipelineDraftDatasourceNodeRunStatusApi,
-    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run-status",
-)
+# api.add_resource(
+#     RagPipelinePublishedDatasourceNodeRunStatusApi,
+#     "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run-status",
+# )
+# api.add_resource(
+#     RagPipelineDraftDatasourceNodeRunStatusApi,
+#     "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run-status",
+# )
 api.add_resource(
-    RagPipelineDrafDatasourceNodeRunApi,
+    RagPipelineDraftDatasourceNodeRunApi,
     "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run",
 )
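The draft run endpoint now returns a streamed response. `PipelineGenerator.convert_to_event_stream` is presumably a thin adapter that frames each JSON string yielded by `run_datasource_workflow_node` as a server-sent event; a minimal sketch of that framing, assumed rather than taken from this diff:

```python
from collections.abc import Generator


def convert_to_event_stream(source: Generator[str, None, None]) -> Generator[str, None, None]:
    # Wrap each JSON chunk from the service layer as one SSE "data:" frame.
    for chunk in source:
        yield f"data: {chunk}\n\n"
```

`helper.compact_generate_response` would then presumably send these frames back with a `text/event-stream` content type.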
diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py
index 99bb597721..f170d0ee3f 100644
--- a/api/core/app/apps/common/workflow_response_converter.py
+++ b/api/core/app/apps/common/workflow_response_converter.py
@@ -186,7 +186,10 @@ class WorkflowResponseConverter:
         elif event.node_type == NodeType.DATASOURCE:
             node_data = cast(DatasourceNodeData, event.node_data)
             manager = PluginDatasourceManager()
-            provider_entity = manager.fetch_datasource_provider(self._application_generate_entity.app_config.tenant_id, f"{node_data.plugin_id}/{node_data.provider_name}")
+            provider_entity = manager.fetch_datasource_provider(
+                self._application_generate_entity.app_config.tenant_id,
+                f"{node_data.plugin_id}/{node_data.provider_name}"
+            )
             response.data.extras["icon"] = provider_entity.declaration.identity.icon
 
         return response
diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py
index 6704d4e73a..51296b64d2 100644
--- a/api/core/datasource/datasource_file_manager.py
+++ b/api/core/datasource/datasource_file_manager.py
@@ -14,7 +14,7 @@ from configs import dify_config
 from core.helper import ssrf_proxy
 from extensions.ext_database import db
 from extensions.ext_storage import storage
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
 from models.model import MessageFile, UploadFile
 from models.tools import ToolFile
 
@@ -86,7 +86,7 @@ class DatasourceFileManager:
             size=len(file_binary),
             extension=extension,
             mime_type=mimetype,
-            created_by_role=CreatedByRole.ACCOUNT,
+            created_by_role=CreatorUserRole.ACCOUNT,
             created_by=user_id,
             used=False,
             hash=hashlib.sha3_256(file_binary).hexdigest(),
@@ -133,7 +133,7 @@ class DatasourceFileManager:
             size=len(blob),
             extension=extension,
             mime_type=mimetype,
-            created_by_role=CreatedByRole.ACCOUNT,
+            created_by_role=CreatorUserRole.ACCOUNT,
             created_by=user_id,
             used=False,
             hash=hashlib.sha3_256(blob).hexdigest(),
@@ -239,6 +239,6 @@ class DatasourceFileManager:
 
 
 # init tool_file_parser
-from core.file.datasource_file_parser import datasource_file_manager
-
-datasource_file_manager["manager"] = DatasourceFileManager
+# from core.file.datasource_file_parser import datasource_file_manager
+#
+# datasource_file_manager["manager"] = DatasourceFileManager
diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index b9a0c1f150..9b72966b50 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -298,28 +298,3 @@ class WebsiteCrawlMessage(BaseModel):
 
 class DatasourceMessage(ToolInvokeMessage):
     pass
-
-class DatasourceInvokeMessage(ToolInvokeMessage):
-    """
-    Datasource Invoke Message.
-    """
-
-    class WebsiteCrawlMessage(BaseModel):
-        """
-        Website crawl message
-        """
-
-        job_id: str = Field(..., description="The job id")
-        status: str = Field(..., description="The status of the job")
-        web_info_list: Optional[list[WebSiteInfoDetail]] = []
-
-    class OnlineDocumentMessage(BaseModel):
-        """
-        Online document message
-        """
-
-        workspace_id: str = Field(..., description="The workspace id")
-        workspace_name: str = Field(..., description="The workspace name")
-        workspace_icon: str = Field(..., description="The workspace icon")
-        total: int = Field(..., description="The total number of documents")
-        pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py
index db73d9a64b..c1e015fd3a 100644
--- a/api/core/datasource/online_document/online_document_plugin.py
+++ b/api/core/datasource/online_document/online_document_plugin.py
@@ -5,7 +5,7 @@ from core.datasource.__base.datasource_plugin import DatasourcePlugin
 from core.datasource.__base.datasource_runtime import DatasourceRuntime
 from core.datasource.entities.datasource_entities import (
     DatasourceEntity,
-    DatasourceInvokeMessage,
+    DatasourceMessage,
     DatasourceProviderType,
     GetOnlineDocumentPageContentRequest,
     OnlineDocumentPagesMessage,
@@ -33,7 +33,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
         self.icon = icon
         self.plugin_unique_identifier = plugin_unique_identifier
 
-    def _get_online_document_pages(
+    def get_online_document_pages(
         self,
         user_id: str,
         datasource_parameters: Mapping[str, Any],
@@ -51,12 +51,12 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
             provider_type=provider_type,
         )
 
-    def _get_online_document_page_content(
+    def get_online_document_page_content(
         self,
         user_id: str,
         datasource_parameters: GetOnlineDocumentPageContentRequest,
         provider_type: str,
-    ) -> Generator[DatasourceInvokeMessage, None, None]:
+    ) -> Generator[DatasourceMessage, None, None]:
         manager = PluginDatasourceManager()
 
         return manager.get_online_document_page_content(
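Dropping the leading underscores makes the page-listing and page-content calls part of the plugin's public surface, matching how `DatasourceNode` and `RagPipelineService` now invoke them. A hypothetical call site (the runtime object, account, and parameter shape are assumptions, not part of this diff):

```python
# Sketch only: datasource_runtime is an OnlineDocumentDatasourcePlugin obtained
# elsewhere, e.g. via DatasourceManager.get_datasource_runtime(...).
pages = datasource_runtime.get_online_document_pages(
    user_id=account.id,
    datasource_parameters={"workspace_id": "..."},  # assumed parameter shape
    provider_type=datasource_runtime.datasource_provider_type(),
)
for message in pages:
    print(message.result)  # each message is an OnlineDocumentPagesMessage
```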
b/api/core/datasource/utils/message_transformer.py
index bd99387e8d..9bc57235d8 100644
--- a/api/core/datasource/utils/message_transformer.py
+++ b/api/core/datasource/utils/message_transformer.py
@@ -4,7 +4,7 @@ from mimetypes import guess_extension
 from typing import Optional
 
 from core.datasource.datasource_file_manager import DatasourceFileManager
-from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
+from core.datasource.entities.datasource_entities import DatasourceMessage
 from core.file import File, FileTransferMethod, FileType
 
 logger = logging.getLogger(__name__)
@@ -14,23 +14,23 @@ class DatasourceFileMessageTransformer:
     @classmethod
     def transform_datasource_invoke_messages(
         cls,
-        messages: Generator[DatasourceInvokeMessage, None, None],
+        messages: Generator[DatasourceMessage, None, None],
         user_id: str,
         tenant_id: str,
         conversation_id: Optional[str] = None,
-    ) -> Generator[DatasourceInvokeMessage, None, None]:
+    ) -> Generator[DatasourceMessage, None, None]:
         """
         Transform datasource message and handle file download
        """
         for message in messages:
-            if message.type in {DatasourceInvokeMessage.MessageType.TEXT, DatasourceInvokeMessage.MessageType.LINK}:
+            if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}:
                 yield message
-            elif message.type == DatasourceInvokeMessage.MessageType.IMAGE and isinstance(
-                message.message, DatasourceInvokeMessage.TextMessage
+            elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance(
+                message.message, DatasourceMessage.TextMessage
             ):
                 # try to download image
                 try:
-                    assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
+                    assert isinstance(message.message, DatasourceMessage.TextMessage)
 
                     file = DatasourceFileManager.create_file_by_url(
                         user_id=user_id,
@@ -41,20 +41,20 @@ class DatasourceFileMessageTransformer:
 
                     url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"
 
-                    yield DatasourceInvokeMessage(
-                        type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
-                        message=DatasourceInvokeMessage.TextMessage(text=url),
+                    yield DatasourceMessage(
+                        type=DatasourceMessage.MessageType.IMAGE_LINK,
+                        message=DatasourceMessage.TextMessage(text=url),
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
                 except Exception as e:
-                    yield DatasourceInvokeMessage(
-                        type=DatasourceInvokeMessage.MessageType.TEXT,
-                        message=DatasourceInvokeMessage.TextMessage(
+                    yield DatasourceMessage(
+                        type=DatasourceMessage.MessageType.TEXT,
+                        message=DatasourceMessage.TextMessage(
                             text=f"Failed to download image: {message.message.text}: {e}"
                         ),
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
-            elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
+            elif message.type == DatasourceMessage.MessageType.BLOB:
                 # get mime type and save blob to storage
                 meta = message.meta or {}
 
@@ -63,7 +63,7 @@ class DatasourceFileMessageTransformer:
                 filename = meta.get("file_name", None)
 
                 # if message is str, encode it to bytes
-                if not isinstance(message.message, DatasourceInvokeMessage.BlobMessage):
+                if not isinstance(message.message, DatasourceMessage.BlobMessage):
                     raise ValueError("unexpected message type")
 
                 # FIXME: should do a type check here.
@@ -81,18 +81,18 @@ class DatasourceFileMessageTransformer:
 
                 # check if file is image
                 if "image" in mimetype:
-                    yield DatasourceInvokeMessage(
-                        type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
-                        message=DatasourceInvokeMessage.TextMessage(text=url),
+                    yield DatasourceMessage(
+                        type=DatasourceMessage.MessageType.IMAGE_LINK,
+                        message=DatasourceMessage.TextMessage(text=url),
                         meta=meta.copy() if meta is not None else {},
                     )
                 else:
-                    yield DatasourceInvokeMessage(
-                        type=DatasourceInvokeMessage.MessageType.BINARY_LINK,
-                        message=DatasourceInvokeMessage.TextMessage(text=url),
+                    yield DatasourceMessage(
+                        type=DatasourceMessage.MessageType.BINARY_LINK,
+                        message=DatasourceMessage.TextMessage(text=url),
                         meta=meta.copy() if meta is not None else {},
                     )
-            elif message.type == DatasourceInvokeMessage.MessageType.FILE:
+            elif message.type == DatasourceMessage.MessageType.FILE:
                 meta = message.meta or {}
                 file = meta.get("file", None)
                 if isinstance(file, File):
@@ -100,15 +100,15 @@ class DatasourceFileMessageTransformer:
                     assert file.related_id is not None
                     url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
                     if file.type == FileType.IMAGE:
-                        yield DatasourceInvokeMessage(
-                            type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
-                            message=DatasourceInvokeMessage.TextMessage(text=url),
+                        yield DatasourceMessage(
+                            type=DatasourceMessage.MessageType.IMAGE_LINK,
+                            message=DatasourceMessage.TextMessage(text=url),
                             meta=meta.copy() if meta is not None else {},
                         )
                     else:
-                        yield DatasourceInvokeMessage(
-                            type=DatasourceInvokeMessage.MessageType.LINK,
-                            message=DatasourceInvokeMessage.TextMessage(text=url),
+                        yield DatasourceMessage(
+                            type=DatasourceMessage.MessageType.LINK,
+                            message=DatasourceMessage.TextMessage(text=url),
                             meta=meta.copy() if meta is not None else {},
                         )
                 else:
diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py
index 1625670165..d0e442f31a 100644
--- a/api/core/datasource/website_crawl/website_crawl_plugin.py
+++ b/api/core/datasource/website_crawl/website_crawl_plugin.py
@@ -5,7 +5,6 @@ from core.datasource.__base.datasource_plugin import DatasourcePlugin
 from core.datasource.__base.datasource_runtime import DatasourceRuntime
 from core.datasource.entities.datasource_entities import (
     DatasourceEntity,
-    DatasourceInvokeMessage,
     DatasourceProviderType,
     WebsiteCrawlMessage,
 )
diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py
index 06ee00c688..66469b43b4 100644
--- a/api/core/plugin/impl/datasource.py
+++ b/api/core/plugin/impl/datasource.py
@@ -2,7 +2,7 @@ from collections.abc import Generator, Mapping
 from typing import Any
 
 from core.datasource.entities.datasource_entities import (
-    DatasourceInvokeMessage,
+    DatasourceMessage,
     GetOnlineDocumentPageContentRequest,
     OnlineDocumentPagesMessage,
     WebsiteCrawlMessage,
@@ -164,7 +164,7 @@ class PluginDatasourceManager(BasePluginClient):
         credentials: dict[str, Any],
         datasource_parameters: GetOnlineDocumentPageContentRequest,
         provider_type: str,
-    ) -> Generator[DatasourceInvokeMessage, None, None]:
+    ) -> Generator[DatasourceMessage, None, None]:
         """
         Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
         """
@@ -174,7 +174,7 @@ class PluginDatasourceManager(BasePluginClient):
         response = self._request_with_plugin_daemon_response_stream(
             "POST",
             f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content",
-            DatasourceInvokeMessage,
+            DatasourceMessage,
             data={
                 "user_id": user_id,
                 "data": {
""" @@ -174,7 +174,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content", - DatasourceInvokeMessage, + DatasourceMessage, data={ "user_id": user_id, "data": { diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 34a86555f7..03047c0545 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -188,8 +188,6 @@ class ToolInvokeMessage(BaseModel): FILE = "file" LOG = "log" BLOB_CHUNK = "blob_chunk" - WEBSITE_CRAWL = "website_crawl" - ONLINE_DOCUMENT = "online_document" type: MessageType = MessageType.TEXT """ diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 0d8a4ee821..fbf591eb8f 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -277,4 +277,8 @@ InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | Bas class DatasourceRunEvent(BaseModel): status: str = Field(..., description="status") - result: dict[str, Any] = Field(..., description="result") + data: Mapping[str,Any] | list = Field(..., description="result") + total: Optional[int] = Field(..., description="total") + completed: Optional[int] = Field(..., description="completed") + time_consuming: Optional[float] = Field(..., description="time consuming") + diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 240eeeb725..0e3decc7b4 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -5,7 +5,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.datasource.entities.datasource_entities import ( - DatasourceInvokeMessage, + DatasourceMessage, DatasourceParameter, DatasourceProviderType, GetOnlineDocumentPageContentRequest, @@ -100,8 +100,8 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: Generator[DatasourceInvokeMessage, None, None] = ( - datasource_runtime._get_online_document_page_content( + online_document_result: Generator[DatasourceMessage, None, None] = ( + datasource_runtime.get_online_document_page_content( user_id=self.user_id, datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), provider_type=datasource_type, @@ -290,7 +290,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): def _transform_message( self, - messages: Generator[DatasourceInvokeMessage, None, None], + messages: Generator[DatasourceMessage, None, None], parameters_for_log: dict[str, Any], datasource_info: dict[str, Any], ) -> Generator: @@ -313,11 +313,11 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): for message in message_stream: if message.type in { - DatasourceInvokeMessage.MessageType.IMAGE_LINK, - DatasourceInvokeMessage.MessageType.BINARY_LINK, - DatasourceInvokeMessage.MessageType.IMAGE, + DatasourceMessage.MessageType.IMAGE_LINK, + DatasourceMessage.MessageType.BINARY_LINK, + DatasourceMessage.MessageType.IMAGE, }: - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + assert isinstance(message.message, DatasourceMessage.TextMessage) url = message.message.text if 
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index 240eeeb725..0e3decc7b4 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -5,7 +5,7 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from core.datasource.entities.datasource_entities import (
-    DatasourceInvokeMessage,
+    DatasourceMessage,
     DatasourceParameter,
     DatasourceProviderType,
     GetOnlineDocumentPageContentRequest,
@@ -100,8 +100,8 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         match datasource_type:
             case DatasourceProviderType.ONLINE_DOCUMENT:
                 datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                online_document_result: Generator[DatasourceInvokeMessage, None, None] = (
-                    datasource_runtime._get_online_document_page_content(
+                online_document_result: Generator[DatasourceMessage, None, None] = (
+                    datasource_runtime.get_online_document_page_content(
                         user_id=self.user_id,
                         datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
                         provider_type=datasource_type,
@@ -290,7 +290,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 
     def _transform_message(
         self,
-        messages: Generator[DatasourceInvokeMessage, None, None],
+        messages: Generator[DatasourceMessage, None, None],
         parameters_for_log: dict[str, Any],
         datasource_info: dict[str, Any],
     ) -> Generator:
@@ -313,11 +313,11 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 
         for message in message_stream:
             if message.type in {
-                DatasourceInvokeMessage.MessageType.IMAGE_LINK,
-                DatasourceInvokeMessage.MessageType.BINARY_LINK,
-                DatasourceInvokeMessage.MessageType.IMAGE,
+                DatasourceMessage.MessageType.IMAGE_LINK,
+                DatasourceMessage.MessageType.BINARY_LINK,
+                DatasourceMessage.MessageType.IMAGE,
             }:
-                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
+                assert isinstance(message.message, DatasourceMessage.TextMessage)
 
                 url = message.message.text
                 if message.meta:
@@ -344,9 +344,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                         tenant_id=self.tenant_id,
                     )
                     files.append(file)
-            elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
+            elif message.type == DatasourceMessage.MessageType.BLOB:
                 # get tool file id
-                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
+                assert isinstance(message.message, DatasourceMessage.TextMessage)
                 assert message.meta
 
                 datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
@@ -367,14 +367,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                         tenant_id=self.tenant_id,
                     )
                 )
-            elif message.type == DatasourceInvokeMessage.MessageType.TEXT:
-                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
+            elif message.type == DatasourceMessage.MessageType.TEXT:
+                assert isinstance(message.message, DatasourceMessage.TextMessage)
                 text += message.message.text
                 yield RunStreamChunkEvent(
                     chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
                 )
-            elif message.type == DatasourceInvokeMessage.MessageType.JSON:
-                assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage)
+            elif message.type == DatasourceMessage.MessageType.JSON:
+                assert isinstance(message.message, DatasourceMessage.JsonMessage)
                 if self.node_type == NodeType.AGENT:
                     msg_metadata = message.message.json_object.pop("execution_metadata", {})
                     agent_execution_metadata = {
@@ -383,13 +383,13 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                         if key in WorkflowNodeExecutionMetadataKey.__members__.values()
                     }
                 json.append(message.message.json_object)
-            elif message.type == DatasourceInvokeMessage.MessageType.LINK:
-                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
+            elif message.type == DatasourceMessage.MessageType.LINK:
+                assert isinstance(message.message, DatasourceMessage.TextMessage)
                 stream_text = f"Link: {message.message.text}\n"
                 text += stream_text
                 yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
-            elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE:
-                assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage)
+            elif message.type == DatasourceMessage.MessageType.VARIABLE:
+                assert isinstance(message.message, DatasourceMessage.VariableMessage)
                 variable_name = message.message.variable_name
                 variable_value = message.message.variable_value
                 if message.message.stream:
@@ -404,7 +404,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     )
                 else:
                     variables[variable_name] = variable_value
-            elif message.type == DatasourceInvokeMessage.MessageType.FILE:
+            elif message.type == DatasourceMessage.MessageType.FILE:
                 assert message.meta is not None
                 files.append(message.meta["file"])
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 9a37f0e51c..aaecc7b989 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -127,7 +127,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                     inputs=parameters_for_log,
                     metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                     error=f"Failed to transform tool message: {str(e)}",
-                    error_type=type(e).__name__, PipelineGenerator.convert_to_event_strea
+                    error_type=type(e).__name__,
                 )
             )
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index 64fa97197d..80e903bd46 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -32,7 +32,11 @@ class DatasourceProviderService:
         :param credentials:
         """
         credential_valid = self.provider_manager.validate_provider_credentials(
-            tenant_id=tenant_id, user_id=current_user.id, provider=provider, plugin_id=plugin_id, credentials=credentials
+            tenant_id=tenant_id,
+            user_id=current_user.id,
+            provider=provider,
+            plugin_id=plugin_id,
+            credentials=credentials
         )
         if credential_valid:
             # Get all provider configurations of the current workspace
@@ -104,7 +108,8 @@ class DatasourceProviderService:
         for datasource_provider in datasource_providers:
             encrypted_credentials = datasource_provider.encrypted_credentials
             # Get provider credential secret variables
-            credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}")
+            credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
+                                                                        provider_id=f"{plugin_id}/{provider}")
 
             # Obfuscate provider credentials
             copy_credentials = encrypted_credentials.copy()
@@ -144,7 +149,8 @@ class DatasourceProviderService:
         for datasource_provider in datasource_providers:
             encrypted_credentials = datasource_provider.encrypted_credentials
             # Get provider credential secret variables
-            credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}")
+            credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
+                                                                        provider_id=f"{plugin_id}/{provider}")
 
             # Obfuscate provider credentials
             copy_credentials = encrypted_credentials.copy()
@@ -161,7 +167,12 @@ class DatasourceProviderService:
 
         return copy_credentials_list
 
-    def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None:
+    def update_datasource_credentials(self,
+                                      tenant_id: str,
+                                      auth_id: str,
+                                      provider: str,
+                                      plugin_id: str,
+                                      credentials: dict) -> None:
         """
         update datasource credentials.
""" diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 1d61677bea..a5f2135100 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -15,7 +15,6 @@ import contexts from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( - DatasourceInvokeMessage, DatasourceProviderType, OnlineDocumentPagesMessage, WebsiteCrawlMessage, @@ -423,70 +422,71 @@ class RagPipelineService: return workflow_node_execution - def run_datasource_workflow_node_status( - self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool - ) -> dict: - """ - Run published workflow datasource - """ - if is_published: - # fetch published workflow by app_model - workflow = self.get_published_workflow(pipeline=pipeline) - else: - workflow = self.get_draft_workflow(pipeline=pipeline) - if not workflow: - raise ValueError("Workflow not initialized") - - # run draft workflow node - datasource_node_data = None - start_at = time.perf_counter() - datasource_nodes = workflow.graph_dict.get("nodes", []) - for datasource_node in datasource_nodes: - if datasource_node.get("id") == node_id: - datasource_node_data = datasource_node.get("data", {}) - break - if not datasource_node_data: - raise ValueError("Datasource node data not found") - - from core.datasource.datasource_manager import DatasourceManager - - datasource_runtime = DatasourceManager.get_datasource_runtime( - provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", - datasource_name=datasource_node_data.get("datasource_name"), - tenant_id=pipeline.tenant_id, - datasource_type=DatasourceProviderType(datasource_type), - ) - datasource_provider_service = DatasourceProviderService() - credentials = datasource_provider_service.get_real_datasource_credentials( - tenant_id=pipeline.tenant_id, - provider=datasource_node_data.get('provider_name'), - plugin_id=datasource_node_data.get('plugin_id'), - ) - if credentials: - datasource_runtime.runtime.credentials = credentials[0].get("credentials") - match datasource_type: - - case DatasourceProviderType.WEBSITE_CRAWL: - datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_results: list[WebsiteCrawlMessage] = [] - for website_message in datasource_runtime.get_website_crawl( - user_id=account.id, - datasource_parameters={"job_id": job_id}, - provider_type=datasource_runtime.datasource_provider_type(), - ): - website_crawl_results.append(website_message) - return { - "result": [result for result in website_crawl_results.result], - "status": website_crawl_results.result.status, - "provider_type": datasource_node_data.get("provider_type"), - } - case _: - raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + # def run_datasource_workflow_node_status( + # self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, + # datasource_type: str, is_published: bool + # ) -> dict: + # """ + # Run published workflow datasource + # """ + # if is_published: + # # fetch published workflow by app_model + # workflow = self.get_published_workflow(pipeline=pipeline) + # else: + # workflow = self.get_draft_workflow(pipeline=pipeline) + # if not workflow: + # raise ValueError("Workflow not initialized") + # + # # run draft workflow node + # datasource_node_data = None + # 
start_at = time.perf_counter() + # datasource_nodes = workflow.graph_dict.get("nodes", []) + # for datasource_node in datasource_nodes: + # if datasource_node.get("id") == node_id: + # datasource_node_data = datasource_node.get("data", {}) + # break + # if not datasource_node_data: + # raise ValueError("Datasource node data not found") + # + # from core.datasource.datasource_manager import DatasourceManager + # + # datasource_runtime = DatasourceManager.get_datasource_runtime( + # provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", + # datasource_name=datasource_node_data.get("datasource_name"), + # tenant_id=pipeline.tenant_id, + # datasource_type=DatasourceProviderType(datasource_type), + # ) + # datasource_provider_service = DatasourceProviderService() + # credentials = datasource_provider_service.get_real_datasource_credentials( + # tenant_id=pipeline.tenant_id, + # provider=datasource_node_data.get('provider_name'), + # plugin_id=datasource_node_data.get('plugin_id'), + # ) + # if credentials: + # datasource_runtime.runtime.credentials = credentials[0].get("credentials") + # match datasource_type: + # + # case DatasourceProviderType.WEBSITE_CRAWL: + # datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) + # website_crawl_results: list[WebsiteCrawlMessage] = [] + # for website_message in datasource_runtime.get_website_crawl( + # user_id=account.id, + # datasource_parameters={"job_id": job_id}, + # provider_type=datasource_runtime.datasource_provider_type(), + # ): + # website_crawl_results.append(website_message) + # return { + # "result": [result for result in website_crawl_results.result], + # "status": website_crawl_results.result.status, + # "provider_type": datasource_node_data.get("provider_type"), + # } + # case _: + # raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") def run_datasource_workflow_node( self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool - ) -> Generator[DatasourceRunEvent, None, None]: + ) -> Generator[str, None, None]: """ Run published workflow datasource """ @@ -533,25 +533,40 @@ class RagPipelineService: match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_online_document_pages( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - ) - for message in online_document_result: - yield DatasourceRunEvent( - status="success", - result=message.model_dump(), + online_document_result: Generator[OnlineDocumentPagesMessage, None, None] =\ + datasource_runtime.get_online_document_pages( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), ) + start_time = time.time() + for message in online_document_result: + end_time = time.time() + online_document_event = DatasourceRunEvent( + status="completed", + data=message.result, + time_consuming=round(end_time - start_time, 2) + ) + yield json.dumps(online_document_event.model_dump()) case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_website_crawl( + 
website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = datasource_runtime.get_website_crawl( user_id=account.id, datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), ) - yield from website_crawl_result + start_time = time.time() + for message in website_crawl_result: + end_time = time.time() + crawl_event = DatasourceRunEvent( + status=message.result.status, + data=message.result.web_info_list, + total=message.result.total, + completed=message.result.completed, + time_consuming = round(end_time - start_time, 2) + ) + yield json.dumps(crawl_event.model_dump()) case _: raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") @@ -952,7 +967,9 @@ class RagPipelineService: if not dataset: raise ValueError("Dataset not found") - max_position = db.session.query(func.max(PipelineCustomizedTemplate.position)).filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar() + max_position = db.session.query( + func.max(PipelineCustomizedTemplate.position)).filter( + PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar() from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
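End to end, `run_datasource_workflow_node` now yields JSON strings and the controller frames them as an event stream. A hypothetical client sketch (the endpoint path follows the routes registered above; the host, IDs, token, and SSE framing are assumptions, not part of this diff):

```python
import json

import requests  # assumed HTTP client in the caller's environment

# Placeholders: host, pipeline ID, node ID, and token are illustrative only.
url = ("https://cloud.example/console/api/rag/pipelines/<pipeline_id>"
       "/workflows/draft/datasource/nodes/<node_id>/run")
payload = {"inputs": {}, "datasource_type": "website_crawl"}
with requests.post(url, json=payload,
                   headers={"Authorization": "Bearer <token>"}, stream=True) as resp:
    for raw in resp.iter_lines():
        if raw.startswith(b"data: "):
            event = json.loads(raw[len(b"data: "):])
            # total/completed are only populated for website-crawl events.
            print(event["status"], event.get("completed"), "/", event.get("total"))
```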