diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 3f30ceab43..7b53983c65 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -145,7 +145,7 @@ class DatasetDocumentListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, dataset_id):
+    def get(self, dataset_id: str):
         dataset_id = str(dataset_id)
         page = request.args.get("page", default=1, type=int)
         limit = request.args.get("limit", default=20, type=int)
@@ -153,7 +153,7 @@ class DatasetDocumentListApi(Resource):
         sort = request.args.get("sort", default="-created_at", type=str)
         # "yes", "true", "t", "y", "1" convert to True, while others convert to False.
         try:
-            fetch_val = request.args.get("fetch", default="false")
+            fetch_val = request.args.get("fetch", default=False)
             if isinstance(fetch_val, bool):
                 fetch = fetch_val
             else:
@@ -250,7 +250,7 @@ class DatasetDocumentListApi(Resource):
     @marshal_with(dataset_and_document_fields)
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_rate_limit_check("knowledge")
-    def post(self, dataset_id):
+    def post(self, dataset_id: str):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)

@@ -308,7 +308,7 @@ class DatasetDocumentListApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_rate_limit_check("knowledge")
-    def delete(self, dataset_id):
+    def delete(self, dataset_id: str):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if dataset is None:
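Note on the "fetch" hunk: the default swaps the string "false" for the boolean False, so an absent query parameter now takes the isinstance branch and skips string parsing entirely. A minimal standalone sketch of the resulting logic, with invented names (the accepted truthy strings are the ones the endpoint's comment lists):

    # Sketch only, not the project's code: how default=False interacts with
    # the isinstance(fetch_val, bool) branch in the handler above.
    TRUTHY = {"yes", "true", "t", "y", "1"}

    def parse_fetch(fetch_val: bool | str) -> bool:
        if isinstance(fetch_val, bool):
            return fetch_val  # absent parameter -> default=False lands here
        return fetch_val.strip().lower() in TRUTHY

    assert parse_fetch(False) is False
    assert parse_fetch("1") is True
    assert parse_fetch("no") is False
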
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 964de0a863..282159415e 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -119,9 +119,6 @@ class DraftRagPipelineApi(Resource):
         else:
             abort(415)

-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
         try:
             environment_variables_list = args.get("environment_variables") or []
             environment_variables = [
@@ -164,9 +161,6 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
         if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()
@@ -201,9 +195,6 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
         if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()
@@ -238,9 +229,6 @@ class DraftRagPipelineRunApi(Resource):
         if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         parser.add_argument("datasource_type", type=str, required=True, location="json")
@@ -275,9 +263,6 @@ class PublishedRagPipelineRunApi(Resource):
         if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         parser.add_argument("datasource_type", type=str, required=True, location="json")
@@ -396,10 +381,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
+
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         parser.add_argument("datasource_type", type=str, required=True, location="json")
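Note on the deletions above: the removed guards were unreachable, because once "isinstance(current_user, Account) or not current_user.is_editor" has either raised Forbidden or passed, the user is already known to be an Account, so a second identical isinstance check can never fire. A toy demonstration under stand-in classes (Account and EndUser here are placeholders, not the real models):

    class Account:
        is_editor = True

    class EndUser:
        pass

    def allowed(user) -> bool:
        # the single combined guard that the diff keeps
        return isinstance(user, Account) and user.is_editor

    assert allowed(Account())
    assert not allowed(EndUser())  # rejected here; the deleted second check was dead code
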
diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py
index f676374e5f..fa78943080 100644
--- a/api/controllers/service_api/dataset/dataset.py
+++ b/api/controllers/service_api/dataset/dataset.py
@@ -200,7 +200,7 @@ class DatasetListApi(DatasetApiResource):
             401: "Unauthorized - invalid API token",
         }
     )
-    def get(self, tenant_id):
+    def get(self, tenant_id: str):
         """Resource for getting datasets."""
         page = request.args.get("page", default=1, type=int)
         limit = request.args.get("limit", default=20, type=int)
diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py
index e4214a16ad..ff35eb59ff 100644
--- a/api/controllers/service_api/dataset/error.py
+++ b/api/controllers/service_api/dataset/error.py
@@ -47,3 +47,8 @@ class DatasetInUseError(BaseHTTPException):
     error_code = "dataset_in_use"
     description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
     code = 409
+
+class PipelineRunError(BaseHTTPException):
+    error_code = "pipeline_run_error"
+    description = "Pipeline run error."
+    code = 400
\ No newline at end of file
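Note on PipelineRunError: it follows the existing convention in this module, where a BaseHTTPException subclass pairs a machine-readable error_code with an HTTP status. A hedged, self-contained sketch of that pattern; the werkzeug base and the stand-in BaseHTTPException below are assumptions for illustration, not Dify's exact wiring:

    from werkzeug.exceptions import HTTPException

    class BaseHTTPException(HTTPException):  # stand-in for the project's base class
        error_code = "unknown"

    class PipelineRunError(BaseHTTPException):
        error_code = "pipeline_run_error"
        description = "Pipeline run error."
        code = 400

    def run_pipeline_or_400(run):
        try:
            return run()
        except ValueError:
            raise PipelineRunError()  # Flask renders HTTPException subclasses as HTTP responses
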
elif value.get("value") and isinstance(value.get("value"), list): + last_part = value.get("value")[-1] user_input_variables_keys.append(last_part) - elif value.get("value") and isinstance(value.get("value"), list): - last_part = value.get("value")[-1] - user_input_variables_keys.append(last_part) - for key, value in variables_map.items(): - if key in user_input_variables_keys: - user_input_variables.append(value) + for key, value in variables_map.items(): + if key in user_input_variables_keys: + user_input_variables.append(value) + + # get credentials + datasource_provider_service: DatasourceProviderService = DatasourceProviderService() + credentials: list[dict[Any, Any]] = datasource_provider_service.list_datasource_credentials( + tenant_id=tenant_id, + provider=datasource_node_data.get("provider_name"), + plugin_id=datasource_node_data.get("plugin_id") + ) + credential_info_list: list[Any] = [] + for credential in credentials: + credential_info_list.append({ + "id": credential.get("id"), + "name": credential.get("name"), + "type": credential.get("type"), + "is_default": credential.get("is_default"), + }) datasource_plugins.append({ "node_id": datasource_node.get("id"), "plugin_id": datasource_node_data.get("plugin_id"), "provider_name": datasource_node_data.get("provider_name"), - "provider_type": datasource_node_data.get("provider_type"), + "datasource_type": datasource_node_data.get("provider_type"), "title": datasource_node_data.get("title"), "user_input_variables": user_input_variables, + "credentials": credential_info_list, }) return datasource_plugins + + def get_pipeline(self, tenant_id: str, dataset_id: str) -> Pipeline: + """ + Get pipeline + """ + dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("Dataset not found") + pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first() + if not pipeline: + raise ValueError("Pipeline not found") + return pipeline