diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
index 6676deb63a..1a4e9240b6 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
@@ -161,7 +161,7 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
         args = parser.parse_args()
         dataset = DatasetService.create_empty_rag_pipeline_dataset(
             tenant_id=current_user.current_tenant_id,
-            rag_pipeline_dataset_create_entity=args,
+            rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args),
         )
         return marshal(dataset, dataset_detail_fields), 201
 
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index bdd40fcabe..f6238bf143 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -8,7 +8,6 @@ from flask_restful.inputs import int_range  # type: ignore
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
-from models.model import EndUser
 import services
 from configs import dify_config
 from controllers.console import api
@@ -40,6 +39,7 @@ from libs.helper import TimestampField, uuid_value
 from libs.login import current_user, login_required
 from models.account import Account
 from models.dataset import Pipeline
+from models.model import EndUser
 from services.errors.app import WorkflowHashNotEqualError
 from services.errors.llm import InvokeRateLimitError
 from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
@@ -242,7 +242,7 @@ class DraftRagPipelineRunApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         parser.add_argument("datasource_type", type=str, required=True, location="json")
-        parser.add_argument("datasource_info", type=list, required=True, location="json")
+        parser.add_argument("datasource_info_list", type=list, required=True, location="json")
         parser.add_argument("start_node_id", type=str, required=True, location="json")
         args = parser.parse_args()
 
@@ -320,6 +320,9 @@ class RagPipelineDatasourceNodeRunApi(Resource):
         inputs = args.get("inputs")
         if inputs == None:
             raise ValueError("missing inputs")
+        datasource_type = args.get("datasource_type")
+        if datasource_type is None:
+            raise ValueError("missing datasource_type")
 
         rag_pipeline_service = RagPipelineService()
         result = rag_pipeline_service.run_datasource_workflow_node(
@@ -327,7 +330,7 @@ class RagPipelineDatasourceNodeRunApi(Resource):
             node_id=node_id,
             user_inputs=inputs,
             account=current_user,
-            datasource_type=args.get("datasource_type"),
+            datasource_type=datasource_type,
         )
 
         return result
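Callers of `DraftRagPipelineRunApi` must follow the `datasource_info` → `datasource_info_list` rename above. A minimal sketch of the expected JSON body after this change — only the argument names are taken from the reqparse definitions; all values are illustrative placeholders:

```python
# Hypothetical draft-run payload; key names come from the parser above,
# the values are made up for illustration.
draft_run_payload = {
    "inputs": {},                           # required dict, may be empty
    "datasource_type": "online_document",   # example value, not exhaustive
    "datasource_info_list": [               # renamed from "datasource_info"
        {"workspace_id": "...", "page_id": "..."},
    ],
    "start_node_id": "datasource-node-1",
}
```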
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index c1aa9747d2..ccc227f3f4 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -32,6 +32,7 @@ from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerat
 from extensions.ext_database import db
 from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.dataset import Document, Pipeline
+from models.model import AppMode
 from services.dataset_service import DocumentService
 
 logger = logging.getLogger(__name__)
@@ -91,7 +92,7 @@
         streaming: bool = True,
         call_depth: int = 0,
         workflow_thread_pool_id: Optional[str] = None,
-    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
+    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
         # convert to app config
         pipeline_config = PipelineConfigManager.get_pipeline_config(
             pipeline=pipeline,
@@ -107,19 +108,22 @@
         for datasource_info in datasource_info_list:
             workflow_run_id = str(uuid.uuid4())
             document_id = None
+            dataset = pipeline.dataset
+            if not dataset:
+                raise ValueError("Dataset not found")
             if invoke_from == InvokeFrom.PUBLISHED:
                 position = DocumentService.get_documents_position(pipeline.dataset_id)
                 document = self._build_document(
                     tenant_id=pipeline.tenant_id,
                     dataset_id=pipeline.dataset_id,
-                    built_in_field_enabled=pipeline.dataset.built_in_field_enabled,
+                    built_in_field_enabled=dataset.built_in_field_enabled,
                     datasource_type=datasource_type,
                     datasource_info=datasource_info,
                     created_from="rag-pipeline",
                     position=position,
                     account=user,
                     batch=batch,
-                    document_form=pipeline.dataset.chunk_structure,
+                    document_form=dataset.chunk_structure,
                 )
                 db.session.add(document)
                 db.session.commit()
@@ -127,10 +131,12 @@
             # init application generate entity
             application_generate_entity = RagPipelineGenerateEntity(
                 task_id=str(uuid.uuid4()),
-                pipline_config=pipeline_config,
+                app_config=pipeline_config,
+                pipeline_config=pipeline_config,
                 datasource_type=datasource_type,
                 datasource_info=datasource_info,
-                dataset_id=pipeline.dataset_id,
+                dataset_id=dataset.id,
+                start_node_id=start_node_id,
                 batch=batch,
                 document_id=document_id,
                 inputs=self._prepare_user_inputs(
@@ -160,17 +166,28 @@
                 app_id=application_generate_entity.app_config.app_id,
                 triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
             )
-
-            return self._generate(
-                pipeline=pipeline,
-                workflow=workflow,
-                user=user,
-                application_generate_entity=application_generate_entity,
-                invoke_from=invoke_from,
-                workflow_node_execution_repository=workflow_node_execution_repository,
-                streaming=streaming,
-                workflow_thread_pool_id=workflow_thread_pool_id,
-            )
+            if invoke_from == InvokeFrom.DEBUGGER:
+                return self._generate(
+                    pipeline=pipeline,
+                    workflow=workflow,
+                    user=user,
+                    application_generate_entity=application_generate_entity,
+                    invoke_from=invoke_from,
+                    workflow_node_execution_repository=workflow_node_execution_repository,
+                    streaming=streaming,
+                    workflow_thread_pool_id=workflow_thread_pool_id,
+                )
+            else:
+                self._generate(
+                    pipeline=pipeline,
+                    workflow=workflow,
+                    user=user,
+                    application_generate_entity=application_generate_entity,
+                    invoke_from=invoke_from,
+                    workflow_node_execution_repository=workflow_node_execution_repository,
+                    streaming=streaming,
+                    workflow_thread_pool_id=workflow_thread_pool_id,
+                )
 
     def _generate(
         self,
@@ -201,7 +218,7 @@
             task_id=application_generate_entity.task_id,
             user_id=application_generate_entity.user_id,
             invoke_from=application_generate_entity.invoke_from,
-            app_mode=pipeline.mode,
+            app_mode=AppMode.RAG_PIPELINE,
         )
 
         # new thread
@@ -256,12 +273,18 @@
             raise ValueError("inputs is required")
 
         # convert to app config
-        app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
+        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
 
         # init application generate entity
-        application_generate_entity = WorkflowAppGenerateEntity(
+        application_generate_entity = RagPipelineGenerateEntity(
             task_id=str(uuid.uuid4()),
-            app_config=app_config,
+            app_config=pipeline_config,
+            pipeline_config=pipeline_config,
+            datasource_type=args["datasource_type"],
+            datasource_info=args["datasource_info"],
+            dataset_id=pipeline.dataset_id,
+            batch=args["batch"],
+            document_id=args["document_id"],
             inputs={},
             files=[],
             user_id=user.id,
@@ -288,7 +311,7 @@
         )
 
         return self._generate(
-            app_model=app_model,
+            pipeline=pipeline,
             workflow=workflow,
             user=user,
             invoke_from=InvokeFrom.DEBUGGER,
@@ -299,7 +322,7 @@
     def single_loop_generate(
         self,
-        app_model: App,
+        pipeline: Pipeline,
         workflow: Workflow,
         node_id: str,
         user: Account | EndUser,
         args: Mapping[str, Any],
@@ -323,7 +346,7 @@
             raise ValueError("inputs is required")
 
         # convert to app config
-        app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
+        app_config = WorkflowAppConfigManager.get_app_config(pipeline=pipeline, workflow=workflow)
 
         # init application generate entity
         application_generate_entity = WorkflowAppGenerateEntity(
@@ -353,7 +376,7 @@
         )
 
         return self._generate(
-            app_model=app_model,
+            pipeline=pipeline,
             workflow=workflow,
             user=user,
             invoke_from=InvokeFrom.DEBUGGER,
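With the branch on `invoke_from` above, `generate()` returns the task result only for debugger runs and falls through to `None` after firing published runs, which is why the return annotation gains `None`. A self-contained sketch of the caller-side consequence — `fake_generate` is a stand-in invented here, not the real `PipelineGenerator.generate()` signature:

```python
from collections.abc import Generator

def fake_generate(invoke_from: str, streaming: bool = True):
    """Stand-in mirroring the branch above: published runs return None."""
    if invoke_from == "published":
        return None  # runs are kicked off per datasource_info; nothing to stream

    def stream() -> Generator[str, None, None]:
        yield "event: workflow_started"

    return stream()

result = fake_generate(invoke_from="debugger")
if result is not None:  # only debugger invocations hand back a stream
    print(list(result))
```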
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py
index 1395a47d88..80b724dd20 100644
--- a/api/core/app/apps/pipeline/pipeline_runner.py
+++ b/api/core/app/apps/pipeline/pipeline_runner.py
@@ -1,5 +1,6 @@
 import logging
-from typing import Optional, cast
+from collections.abc import Mapping
+from typing import Any, Optional, cast
 
 from configs import dify_config
 from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -12,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
 from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from models.dataset import Pipeline
@@ -100,6 +102,8 @@
                 SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
                 SystemVariableKey.BATCH: self.application_generate_entity.batch,
                 SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
+                SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
+                SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
             }
 
             variable_pool = VariablePool(
@@ -110,7 +114,10 @@
             )
 
             # init graph
-            graph = self._init_graph(graph_config=workflow.graph_dict)
+            graph = self._init_rag_pipeline_graph(
+                graph_config=workflow.graph_dict,
+                start_node_id=self.application_generate_entity.start_node_id,
+            )
 
             # RUN WORKFLOW
             workflow_entry = WorkflowEntry(
@@ -152,3 +159,43 @@
 
         # return workflow
         return workflow
+
+    def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
+        """
+        Init pipeline graph
+        """
+        if "nodes" not in graph_config or "edges" not in graph_config:
+            raise ValueError("nodes or edges not found in workflow graph")
+
+        if not isinstance(graph_config.get("nodes"), list):
+            raise ValueError("nodes in workflow graph must be a list")
+
+        if not isinstance(graph_config.get("edges"), list):
+            raise ValueError("edges in workflow graph must be a list")
+        nodes = graph_config.get("nodes", [])
+        edges = graph_config.get("edges", [])
+        real_run_nodes = []
+        real_edges = []
+        exclude_node_ids = []
+        for node in nodes:
+            node_id = node.get("id")
+            node_type = node.get("data", {}).get("type", "")
+            if node_type == "datasource":
+                if start_node_id != node_id:
+                    exclude_node_ids.append(node_id)
+                    continue
+            real_run_nodes.append(node)
+        for edge in edges:
+            if edge.get("source") in exclude_node_ids:
+                continue
+            real_edges.append(edge)
+        graph_config = dict(graph_config)
+        graph_config["nodes"] = real_run_nodes
+        graph_config["edges"] = real_edges
+        # init graph
+        graph = Graph.init(graph_config=graph_config)
+
+        if not graph:
+            raise ValueError("graph not found in workflow")
+
+        return graph
\ No newline at end of file
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index d730704f48..4565d37d5b 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -233,14 +233,14 @@
     """
     RAG Pipeline Application Generate Entity.
     """
-
-    # app config
-    pipline_config: WorkflowUIBasedAppConfig
+    # pipeline config
+    pipeline_config: WorkflowUIBasedAppConfig
     datasource_type: str
     datasource_info: Mapping[str, Any]
     dataset_id: str
     batch: str
-    document_id: str
+    document_id: Optional[str] = None
+    start_node_id: Optional[str] = None
 
 class SingleIterationRunEntity(BaseModel):
     """
diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py
index 34d17c880a..0e210c1389 100644
--- a/api/core/workflow/enums.py
+++ b/api/core/workflow/enums.py
@@ -18,3 +18,5 @@ class SystemVariableKey(StrEnum):
     DOCUMENT_ID = "document_id"
     BATCH = "batch"
     DATASET_ID = "dataset_id"
+    DATASOURCE_TYPE = "datasource_type"
+    DATASOURCE_INFO = "datasource_info"
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index 8e5b1e7142..7062fc4565 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -140,7 +140,8 @@
             (
                 node_config.get("id")
                 for node_config in root_node_configs
-                if node_config.get("data", {}).get("type", "") == NodeType.START.value
+                if node_config.get("data", {}).get("type", "") == NodeType.START.value
+                or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
             ),
             None,
         )
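The runner-side filtering and the `Graph.init` root-node change work together: exactly one datasource node survives as the graph root, and edges leaving excluded datasource nodes are dropped. A self-contained toy run of the same keep/exclude logic (node ids and types are invented for illustration):

```python
# Toy illustration of what _init_rag_pipeline_graph keeps: with two
# datasource roots and start_node_id="ds2", ds1 and its edge are removed.
graph_config = {
    "nodes": [
        {"id": "ds1", "data": {"type": "datasource"}},
        {"id": "ds2", "data": {"type": "datasource"}},
        {"id": "extract", "data": {"type": "document-extractor"}},
    ],
    "edges": [
        {"source": "ds1", "target": "extract"},
        {"source": "ds2", "target": "extract"},
    ],
}
start_node_id = "ds2"
excluded = {
    n["id"]
    for n in graph_config["nodes"]
    if n["data"]["type"] == "datasource" and n["id"] != start_node_id
}
nodes = [n for n in graph_config["nodes"] if n["id"] not in excluded]
edges = [e for e in graph_config["edges"] if e["source"] not in excluded]
assert [n["id"] for n in nodes] == ["ds2", "extract"]
assert edges == [{"source": "ds2", "target": "extract"}]
```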
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index 612c5a5a74..f5e34f5998 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -6,11 +6,8 @@ from core.datasource.entities.datasource_entities import (
     DatasourceProviderType,
     GetOnlineDocumentPageContentRequest,
     GetOnlineDocumentPageContentResponse,
-    GetWebsiteCrawlRequest,
-    GetWebsiteCrawlResponse,
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
-from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
 from core.file import File
 from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.variables.segments import ArrayAnySegment
@@ -42,22 +39,23 @@
         """
         node_data = cast(DatasourceNodeData, self.node_data)
-
-        # fetch datasource icon
-        datasource_info = {
-            "provider_id": node_data.provider_id,
-            "plugin_unique_identifier": node_data.plugin_unique_identifier,
-        }
+        variable_pool = self.graph_runtime_state.variable_pool
 
         # get datasource runtime
         try:
             from core.datasource.datasource_manager import DatasourceManager
 
+            datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
+
+            datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
+            if datasource_type is None:
+                raise DatasourceNodeError("Datasource type is not set")
+
             datasource_runtime = DatasourceManager.get_datasource_runtime(
                 provider_id=node_data.provider_id,
                 datasource_name=node_data.datasource_name,
                 tenant_id=self.tenant_id,
-                datasource_type=DatasourceProviderType(node_data.provider_type),
+                datasource_type=DatasourceProviderType(datasource_type),
             )
         except DatasourceNodeError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
@@ -75,12 +73,12 @@
         datasource_parameters = datasource_runtime.entity.parameters
         parameters = self._generate_parameters(
             datasource_parameters=datasource_parameters,
-            variable_pool=self.graph_runtime_state.variable_pool,
+            variable_pool=variable_pool,
             node_data=self.node_data,
         )
         parameters_for_log = self._generate_parameters(
             datasource_parameters=datasource_parameters,
-            variable_pool=self.graph_runtime_state.variable_pool,
+            variable_pool=variable_pool,
             node_data=self.node_data,
             for_log=True,
         )
@@ -106,20 +104,19 @@
                     },
                 )
             )
-        elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL:
-            datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-            website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
-                user_id=self.user_id,
-                datasource_parameters=GetWebsiteCrawlRequest(**parameters),
-                provider_type=datasource_runtime.datasource_provider_type(),
+        elif (
+            datasource_runtime.datasource_provider_type in (
+                DatasourceProviderType.WEBSITE_CRAWL,
+                DatasourceProviderType.LOCAL_FILE,
             )
+        ):
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
                     inputs=parameters_for_log,
                     metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
                     outputs={
-                        "website": website_crawl_result.result.model_dump(),
+                        "website": datasource_info,
                         "datasource_type": datasource_runtime.datasource_provider_type,
                     },
                 )
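The node now resolves the datasource type and info from the `sys.*` entries that `PipelineRunner` places in the variable pool (see the `SystemVariableKey` additions), instead of trusting the static node config. A stub sketch of that lookup contract — `StubVariablePool` is invented here; the real pool lives in `core.workflow.entities.variable_pool`:

```python
# Stand-in showing the selector shape DatasourceNode relies on.
system_inputs = {
    "datasource_type": "website_crawl",                 # SystemVariableKey.DATASOURCE_TYPE
    "datasource_info": {"url": "https://example.com"},  # SystemVariableKey.DATASOURCE_INFO
}

class StubVariablePool:
    def __init__(self, sys_vars: dict):
        self._sys = sys_vars

    def get(self, selector: list):
        # mirrors variable_pool.get(["sys", <key>]) as used by the node
        assert selector[0] == "sys"
        return self._sys.get(selector[1])

pool = StubVariablePool(system_inputs)
assert pool.get(["sys", "datasource_type"]) == "website_crawl"
```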
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 8a87964276..62a16c56ce 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -6,7 +6,7 @@ import random
 import time
 import uuid
 from collections import Counter
-from typing import Any, Optional
+from typing import Any, Optional, cast
 
 from flask_login import current_user
 from sqlalchemy import func, select
@@ -298,13 +298,14 @@
             description=rag_pipeline_dataset_create_entity.description,
             permission=rag_pipeline_dataset_create_entity.permission,
             provider="vendor",
-            runtime_mode="rag_pipeline",
+            runtime_mode="rag-pipeline",
             icon_info=rag_pipeline_dataset_create_entity.icon_info,
         )
         with Session(db.engine) as session:
             rag_pipeline_dsl_service = RagPipelineDslService(session)
+            account = cast(Account, current_user)
             rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
-                account=current_user,
+                account=account,
                 import_mode=ImportMode.YAML_CONTENT.value,
                 yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
                 dataset=dataset,
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 3bee0538ab..08bb10b5d4 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -59,12 +59,12 @@ class RagPipelineService:
             if not result.get("pipeline_templates") and language != "en-US":
                 template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                 result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
-            return result.get("pipeline_templates")
+            return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])]
         else:
             mode = "customized"
             retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
             result = retrieval_instance.get_pipeline_templates(language)
-            return result.get("pipeline_templates")
+            return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])]
 
     @classmethod
     def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
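Returning `PipelineBuiltInTemplate` / `PipelineCustomizedTemplate` instances instead of raw dicts means malformed template rows now fail at construction time rather than leaking untyped data to callers. A sketch with a stand-in model (the real template fields are defined elsewhere in the service layer; the two-field model below only demonstrates the validation behaviour):

```python
from pydantic import BaseModel

class PipelineBuiltInTemplateStub(BaseModel):
    id: str
    name: str

raw = {"pipeline_templates": [{"id": "t1", "name": "Web crawl"}]}
templates = [
    PipelineBuiltInTemplateStub(**t) for t in raw.get("pipeline_templates", [])
]
assert templates[0].name == "Web crawl"  # bad rows would raise ValidationError here
```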
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index 19c7d37f6e..acd364f6cd 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -97,11 +97,6 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus:
 class RagPipelinePendingData(BaseModel):
     import_mode: str
     yaml_content: str
-    name: str | None
-    description: str | None
-    icon_type: str | None
-    icon: str | None
-    icon_background: str | None
     pipeline_id: str | None
 
 
@@ -302,10 +297,6 @@
         dataset.runtime_mode = "rag_pipeline"
         dataset.chunk_structure = knowledge_configuration.chunk_structure
         if knowledge_configuration.index_method.indexing_technique == "high_quality":
-            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                knowledge_configuration.index_method.embedding_setting.embedding_provider_name,  # type: ignore
-                knowledge_configuration.index_method.embedding_setting.embedding_model_name,  # type: ignore
-            )
             dataset_collection_binding = (
                 db.session.query(DatasetCollectionBinding)
                 .filter(
@@ -445,10 +436,28 @@
         dataset.runtime_mode = "rag_pipeline"
         dataset.chunk_structure = knowledge_configuration.chunk_structure
         if knowledge_configuration.index_method.indexing_technique == "high_quality":
-            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                knowledge_configuration.index_method.embedding_setting.embedding_provider_name,  # type: ignore
-                knowledge_configuration.index_method.embedding_setting.embedding_model_name,  # type: ignore
+            dataset_collection_binding = (
+                db.session.query(DatasetCollectionBinding)
+                .filter(
+                    DatasetCollectionBinding.provider_name
+                    == knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
+                    DatasetCollectionBinding.model_name
+                    == knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                    DatasetCollectionBinding.type == "dataset",
+                )
+                .order_by(DatasetCollectionBinding.created_at)
+                .first()
             )
+
+            if not dataset_collection_binding:
+                dataset_collection_binding = DatasetCollectionBinding(
+                    provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
+                    model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                    collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
+                    type="dataset",
+                )
+                db.session.add(dataset_collection_binding)
+                db.session.commit()
             dataset_collection_binding_id = dataset_collection_binding.id
             dataset.collection_binding_id = dataset_collection_binding_id
             dataset.embedding_model = (
@@ -602,7 +611,6 @@
             rag_pipeline_service.sync_draft_workflow(
                 pipeline=pipeline,
                 graph=workflow_data.get("graph", {}),
-                features=workflow_data.get("features", {}),
                 unique_hash=unique_hash,
                 account=account,
                 environment_variables=environment_variables,
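Both import paths now share the same get-or-create shape for `DatasetCollectionBinding`. Distilled into a stand-alone helper (a simplified sketch: the model class and session are assumed, and, as in the hunk itself, two concurrent imports could still race between the query and the commit):

```python
# Generic get-or-create shape used above; `model` stands in for
# DatasetCollectionBinding and `make_collection_name` for the
# uuid-based Dataset.gen_collection_name_by_id call.
def get_or_create_binding(session, model, provider_name, model_name, make_collection_name):
    binding = (
        session.query(model)
        .filter(
            model.provider_name == provider_name,
            model.model_name == model_name,
            model.type == "dataset",
        )
        .order_by(model.created_at)
        .first()
    )
    if binding is None:
        binding = model(
            provider_name=provider_name,
            model_name=model_name,
            collection_name=make_collection_name(),
            type="dataset",
        )
        session.add(binding)
        session.commit()
    return binding
```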