diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
new file mode 100644
index 0000000000..6676deb63a
--- /dev/null
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
@@ -0,0 +1,170 @@
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, marshal, reqparse  # type: ignore
+from werkzeug.exceptions import Forbidden
+
+import services
+from controllers.console import api
+from controllers.console.datasets.error import DatasetNameDuplicateError
+from controllers.console.wraps import (
+    account_initialization_required,
+    cloud_edition_billing_rate_limit_check,
+    setup_required,
+)
+from fields.dataset_fields import dataset_detail_fields
+from libs.login import login_required
+from models.dataset import DatasetPermissionEnum
+from services.dataset_service import DatasetPermissionService, DatasetService
+from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
+
+
+def _validate_name(name):
+    if not name or len(name) < 1 or len(name) > 40:
+        raise ValueError("Name must be between 1 and 40 characters.")
+    return name
+
+
+def _validate_description_length(description):
+    if len(description) > 400:
+        raise ValueError("Description cannot exceed 400 characters.")
+    return description
+
+
+class CreateRagPipelineDatasetApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument(
+            "name",
+            nullable=False,
+            required=True,
+            help="Name is required. Name must be between 1 and 40 characters.",
+            type=_validate_name,
+        )
+        parser.add_argument(
+            "description",
+            type=str,
+            nullable=True,
+            required=False,
+            default="",
+        )
+
+        parser.add_argument(
+            "icon_info",
+            type=dict,
+            nullable=True,
+            required=False,
+            default={},
+        )
+
+        parser.add_argument(
+            "permission",
+            type=str,
+            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
+            nullable=True,
+            required=False,
+            default=DatasetPermissionEnum.ONLY_ME,
+        )
+
+        parser.add_argument(
+            "partial_member_list",
+            type=list,
+            nullable=True,
+            required=False,
+            default=[],
+        )
+
+        parser.add_argument(
+            "yaml_content",
+            type=str,
+            nullable=False,
+            required=True,
+            help="yaml_content is required.",
+        )
+
+        args = parser.parse_args()
+
+        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+        if not current_user.is_dataset_editor:
+            raise Forbidden()
+        rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
+        try:
+            import_info = DatasetService.create_rag_pipeline_dataset(
+                tenant_id=current_user.current_tenant_id,
+                rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
+            )
+            if rag_pipeline_dataset_create_entity.permission == "partial_members":
+                DatasetPermissionService.update_partial_member_list(
+                    current_user.current_tenant_id,
+                    import_info["dataset_id"],
+                    rag_pipeline_dataset_create_entity.partial_member_list,
+                )
+        except services.errors.dataset.DatasetNameDuplicateError:
+            raise DatasetNameDuplicateError()
+
+        return import_info, 201
+
+
+class CreateEmptyRagPipelineDatasetApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
+    def post(self):
+        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+        if not current_user.is_dataset_editor:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument(
+            "name",
+            nullable=False,
+            required=True,
+            help="Name is required. Name must be between 1 and 40 characters.",
+            type=_validate_name,
+        )
+        parser.add_argument(
+            "description",
+            type=str,
+            nullable=True,
+            required=False,
+            default="",
+        )
+
+        parser.add_argument(
+            "icon_info",
+            type=dict,
+            nullable=True,
+            required=False,
+            default={},
+        )
+
+        parser.add_argument(
+            "permission",
+            type=str,
+            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
+            nullable=True,
+            required=False,
+            default=DatasetPermissionEnum.ONLY_ME,
+        )
+
+        parser.add_argument(
+            "partial_member_list",
+            type=list,
+            nullable=True,
+            required=False,
+            default=[],
+        )
+
+        args = parser.parse_args()
+        dataset = DatasetService.create_empty_rag_pipeline_dataset(
+            tenant_id=current_user.current_tenant_id,
+            rag_pipeline_dataset_create_entity=args,
+        )
+        return marshal(dataset, dataset_detail_fields), 201
+
+
+api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
+api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
new file mode 100644
index 0000000000..853aef2e09
--- /dev/null
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
@@ -0,0 +1,147 @@
+from typing import cast
+
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, marshal_with, reqparse  # type: ignore
+from sqlalchemy.orm import Session
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import api
+from controllers.console.datasets.wraps import get_rag_pipeline
+from controllers.console.wraps import (
+    account_initialization_required,
+    setup_required,
+)
+from extensions.ext_database import db
+from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
+from libs.login import login_required
+from models import Account
+from models.dataset import Pipeline
+from services.app_dsl_service import ImportStatus
+from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
+
+
+class RagPipelineImportApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(pipeline_import_fields)
+    def post(self):
+        # Check user role first
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("mode", type=str, required=True, location="json")
+        parser.add_argument("yaml_content", type=str, location="json")
+        parser.add_argument("yaml_url", type=str, location="json")
+        parser.add_argument("name", type=str, location="json")
+        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("icon_type", type=str, location="json")
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
+        parser.add_argument("pipeline_id", type=str, location="json")
+        args = parser.parse_args()
+
+        # Create service with session
+        with Session(db.engine) as session:
+            import_service = RagPipelineDslService(session)
+            # Import the RAG pipeline
+            account = cast(Account, current_user)
+            result = 
import_service.import_rag_pipeline( + account=account, + import_mode=args["mode"], + yaml_content=args.get("yaml_content"), + yaml_url=args.get("yaml_url"), + pipeline_id=args.get("pipeline_id"), + ) + session.commit() + + # Return appropriate status code based on result + status = result.status + if status == ImportStatus.FAILED.value: + return result.model_dump(mode="json"), 400 + elif status == ImportStatus.PENDING.value: + return result.model_dump(mode="json"), 202 + return result.model_dump(mode="json"), 200 + + +class RagPipelineImportConfirmApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(pipeline_import_fields) + def post(self, import_id): + # Check user role first + if not current_user.is_editor: + raise Forbidden() + + # Create service with session + with Session(db.engine) as session: + import_service = RagPipelineDslService(session) + # Confirm import + account = cast(Account, current_user) + result = import_service.confirm_import(import_id=import_id, account=account) + session.commit() + + # Return appropriate status code based on result + if result.status == ImportStatus.FAILED.value: + return result.model_dump(mode="json"), 400 + return result.model_dump(mode="json"), 200 + + +class RagPipelineImportCheckDependenciesApi(Resource): + @setup_required + @login_required + @get_rag_pipeline + @account_initialization_required + @marshal_with(pipeline_import_check_dependencies_fields) + def get(self, pipeline: Pipeline): + if not current_user.is_editor: + raise Forbidden() + + with Session(db.engine) as session: + import_service = RagPipelineDslService(session) + result = import_service.check_dependencies(pipeline=pipeline) + + return result.model_dump(mode="json"), 200 + + +class RagPipelineExportApi(Resource): + @setup_required + @login_required + @get_rag_pipeline + @account_initialization_required + @marshal_with(pipeline_import_check_dependencies_fields) + def get(self, pipeline: Pipeline): + if not current_user.is_editor: + raise Forbidden() + + # Add include_secret params + parser = reqparse.RequestParser() + parser.add_argument("include_secret", type=bool, default=False, location="args") + args = parser.parse_args() + + with Session(db.engine) as session: + export_service = RagPipelineDslService(session) + result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"]) + + return {"data": result}, 200 + + +# Import Rag Pipeline +api.add_resource( + RagPipelineImportApi, + "/rag/pipelines/imports", +) +api.add_resource( + RagPipelineImportConfirmApi, + "/rag/pipelines/imports//confirm", +) +api.add_resource( + RagPipelineImportCheckDependenciesApi, + "/rag/pipelines/imports//check-dependencies", +) +api.add_resource( + RagPipelineExportApi, + "/rag/pipelines//exports", +) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index e67b3c0657..99d3b73d33 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -4,6 +4,7 @@ from typing import cast from flask import abort, request from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -23,12 +24,18 @@ from 
controllers.console.wraps import ( from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom +from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory from fields.workflow_fields import workflow_fields, workflow_pagination_fields -from fields.workflow_run_fields import workflow_run_node_execution_fields +from fields.workflow_run_fields import ( + workflow_run_detail_fields, + workflow_run_node_execution_fields, + workflow_run_node_execution_list_fields, + workflow_run_pagination_fields, +) from libs import helper -from libs.helper import TimestampField +from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.account import Account from models.dataset import Pipeline @@ -36,6 +43,7 @@ from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.rag_pipeline import RagPipelineService +from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError logger = logging.getLogger(__name__) @@ -461,45 +469,6 @@ class DefaultRagPipelineBlockConfigApi(Resource): rag_pipeline_service = RagPipelineService() return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) - -class ConvertToRagPipelineApi(Resource): - @setup_required - @login_required - @account_initialization_required - @get_rag_pipeline - def post(self, pipeline: Pipeline): - """ - Convert basic mode of chatbot app to workflow mode - Convert expert mode of chatbot app to workflow mode - Convert Completion App to Workflow App - """ - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): - raise Forbidden() - - if request.data: - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=False, nullable=True, location="json") - parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") - parser.add_argument("icon", type=str, required=False, nullable=True, location="json") - parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") - args = parser.parse_args() - else: - args = {} - - # convert to workflow mode - rag_pipeline_service = RagPipelineService() - new_app_model = rag_pipeline_service.convert_to_workflow(pipeline=pipeline, account=current_user, args=args) - - # return app id - return { - "new_app_id": new_app_model.id, - } - - class RagPipelineConfigApi(Resource): """Resource for rag pipeline configuration.""" @@ -674,6 +643,85 @@ class RagPipelineSecondStepApi(Resource): ) +class RagPipelineWorkflowRunListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_run_pagination_fields) + def get(self, pipeline: Pipeline): + """ + Get workflow run list + """ + parser = reqparse.RequestParser() + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + args = parser.parse_args() + + 
rag_pipeline_service = RagPipelineService() + result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args) + + return result + + +class RagPipelineWorkflowRunDetailApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_run_detail_fields) + def get(self, pipeline: Pipeline, run_id): + """ + Get workflow run detail + """ + run_id = str(run_id) + + rag_pipeline_service = RagPipelineService() + workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id) + + return workflow_run + + +class RagPipelineWorkflowRunNodeExecutionListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_run_node_execution_list_fields) + def get(self, pipeline: Pipeline, run_id): + """ + Get workflow run node execution list + """ + run_id = str(run_id) + + rag_pipeline_service = RagPipelineService() + node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions( + pipeline=pipeline, + run_id=run_id, + ) + + return {"data": node_executions} + + +class DatasourceListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user = current_user + + tenant_id = user.current_tenant_id + + return jsonable_encoder( + [ + provider.to_dict() + for provider in BuiltinToolManageService.list_rag_pipeline_datasources( + tenant_id, + ) + ] + ) + + api.add_resource( DraftRagPipelineApi, "/rag/pipelines//workflows/draft", @@ -694,10 +742,10 @@ api.add_resource( RagPipelineDraftNodeRunApi, "/rag/pipelines//workflows/draft/nodes//run", ) -api.add_resource( - RagPipelinePublishedNodeRunApi, - "/rag/pipelines//workflows/published/nodes//run", -) +# api.add_resource( +# RagPipelinePublishedNodeRunApi, +# "/rag/pipelines//workflows/published/nodes//run", +# ) api.add_resource( RagPipelineDraftRunIterationNodeApi, @@ -724,11 +772,24 @@ api.add_resource( DefaultRagPipelineBlockConfigApi, "/rag/pipelines//workflows/default-workflow-block-configs/", ) -api.add_resource( - ConvertToRagPipelineApi, - "/rag/pipelines//convert-to-workflow", -) + api.add_resource( RagPipelineByIdApi, "/rag/pipelines//workflows/", ) +api.add_resource( + RagPipelineWorkflowRunListApi, + "/rag/pipelines//workflow-runs", +) +api.add_resource( + RagPipelineWorkflowRunDetailApi, + "/rag/pipelines//workflow-runs/", +) +api.add_resource( + RagPipelineWorkflowRunNodeExecutionListApi, + "/rag/pipelines//workflow-runs//node-executions", +) +api.add_resource( + DatasourceListApi, + "/rag/pipelines/datasources", +) diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index 991bceb422..86bd66a3f9 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -2,13 +2,13 @@ from collections.abc import Generator from typing import Any, Optional from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import ( DatasourceEntity, DatasourceInvokeMessage, DatasourceParameter, DatasourceProviderType, ) -from core.plugin.manager.datasource import PluginDatasourceManager from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -44,7 +44,7 @@ class DatasourcePlugin: datasource_parameters: dict[str, Any], rag_pipeline_id: 
Optional[str] = None, ) -> Generator[DatasourceInvokeMessage, None, None]: - manager = PluginDatasourceManager() + manager = DatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) @@ -64,7 +64,7 @@ class DatasourcePlugin: datasource_parameters: dict[str, Any], rag_pipeline_id: Optional[str] = None, ) -> Generator[DatasourceInvokeMessage, None, None]: - manager = PluginDatasourceManager() + manager = DatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index e9efb7b9dc..ef3382b948 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -4,12 +4,11 @@ from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType from core.entities.provider_entities import ProviderConfig -from core.plugin.manager.tool import PluginToolManager -from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.plugin.impl.tool import PluginToolManager from core.tools.errors import ToolProviderCredentialValidationError -class DatasourcePluginProviderController(BuiltinToolProviderController): +class DatasourcePluginProviderController: entity: DatasourceProviderEntityWithPlugin tenant_id: str plugin_id: str @@ -32,12 +31,21 @@ class DatasourcePluginProviderController(BuiltinToolProviderController): """ return DatasourceProviderType.RAG_PIPELINE + @property + def need_credentials(self) -> bool: + """ + returns whether the provider needs credentials + + :return: whether the provider needs credentials + """ + return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0 + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider """ manager = PluginToolManager() - if not manager.validate_provider_credentials( + if not manager.validate_datasource_credentials( tenant_id=self.tenant_id, user_id=user_id, provider=self.entity.identity.name, @@ -69,7 +77,7 @@ class DatasourcePluginProviderController(BuiltinToolProviderController): plugin_unique_identifier=self.plugin_unique_identifier, ) - def get_datasources(self) -> list[DatasourceTool]: # type: ignore + def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore """ get all datasources """ diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py new file mode 100644 index 0000000000..2d42484a30 --- /dev/null +++ b/api/core/datasource/entities/api_entities.py @@ -0,0 +1,73 @@ +from typing import Literal, Optional + +from pydantic import BaseModel, Field, field_validator + +from core.datasource.entities.datasource_entities import DatasourceParameter +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.__base.tool import ToolParameter +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolProviderType + + +class DatasourceApiEntity(BaseModel): + author: str + name: str # identifier + label: I18nObject # label + description: I18nObject + parameters: Optional[list[ToolParameter]] = None + labels: list[str] = Field(default_factory=list) + 
output_schema: Optional[dict] = None + + +ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] + + +class DatasourceProviderApiEntity(BaseModel): + id: str + author: str + name: str # identifier + description: I18nObject + icon: str | dict + label: I18nObject # label + type: ToolProviderType + masked_credentials: Optional[dict] = None + original_credentials: Optional[dict] = None + is_team_authorization: bool = False + allow_delete: bool = True + plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool") + plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") + datasources: list[DatasourceApiEntity] = Field(default_factory=list) + labels: list[str] = Field(default_factory=list) + + @field_validator("datasources", mode="before") + @classmethod + def convert_none_to_empty_list(cls, v): + return v if v is not None else [] + + def to_dict(self) -> dict: + # ------------- + # overwrite datasource parameter types for temp fix + datasources = jsonable_encoder(self.datasources) + for datasource in datasources: + if datasource.get("parameters"): + for parameter in datasource.get("parameters"): + if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value: + parameter["type"] = "files" + # ------------- + + return { + "id": self.id, + "author": self.author, + "name": self.name, + "plugin_id": self.plugin_id, + "plugin_unique_identifier": self.plugin_unique_identifier, + "description": self.description.to_dict(), + "icon": self.icon, + "label": self.label.to_dict(), + "type": self.type.value, + "team_credentials": self.masked_credentials, + "is_team_authorization": self.is_team_authorization, + "allow_delete": self.allow_delete, + "datasources": datasources, + "labels": self.labels, + } diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 1588cbc3c7..40e753671c 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -5,6 +5,7 @@ from typing import Generic, Optional, TypeVar from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.base import BasePluginEntity @@ -46,6 +47,13 @@ class PluginToolProviderEntity(BaseModel): declaration: ToolProviderEntityWithPlugin +class PluginDatasourceProviderEntity(BaseModel): + provider: str + plugin_unique_identifier: str + plugin_id: str + declaration: DatasourceProviderEntityWithPlugin + + class PluginAgentProviderEntity(BaseModel): provider: str plugin_unique_identifier: str diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 19b26c8fe3..f4360a70de 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -4,7 +4,11 @@ from typing import Any, Optional from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID -from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity +from core.plugin.entities.plugin_daemon import ( + PluginBasicBooleanResponse, + PluginDatasourceProviderEntity, + PluginToolProviderEntity, +) from core.plugin.impl.base import BasePluginClient from 
core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter @@ -41,6 +45,37 @@ class PluginToolManager(BasePluginClient): return response + def fetch_datasources(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]: + """ + Fetch datasources for the given tenant. + """ + + def transformer(json_response: dict[str, Any]) -> dict: + for provider in json_response.get("data", []): + declaration = provider.get("declaration", {}) or {} + provider_name = declaration.get("identity", {}).get("name") + for tool in declaration.get("tools", []): + tool["identity"]["provider"] = provider_name + + return json_response + + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/datasources", + list[PluginToolProviderEntity], + params={"page": 1, "page_size": 256}, + transformer=transformer, + ) + + for provider in response: + provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" + + # override the provider name for each tool to plugin_id/provider_name + for tool in provider.declaration.tools: + tool.identity.provider = provider.declaration.identity.name + + return response + def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: """ Fetch tool provider for the given tenant and plugin. @@ -197,6 +232,36 @@ class PluginToolManager(BasePluginClient): return False + def validate_datasource_credentials( + self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] + ) -> bool: + """ + validate the credentials of the datasource + """ + tool_provider_id = GenericProviderID(provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/validate_credentials", + PluginBasicBooleanResponse, + data={ + "user_id": user_id, + "data": { + "provider": tool_provider_id.provider_name, + "credentials": credentials, + }, + }, + headers={ + "X-Plugin-ID": tool_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp.result + + return False + def get_runtime_parameters( self, tenant_id: str, diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index aa2661fe63..682a32d26f 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast from yarl import URL import contexts +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController @@ -495,6 +496,31 @@ class ToolManager: # get plugin providers yield from cls.list_plugin_providers(tenant_id) + @classmethod + def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]: + """ + list all the datasource providers + """ + manager = PluginToolManager() + provider_entities = manager.fetch_datasources(tenant_id) + return [ + DatasourcePluginProviderController( + entity=provider.declaration, + plugin_id=provider.plugin_id, + plugin_unique_identifier=provider.plugin_unique_identifier, + tenant_id=tenant_id, + ) + for provider in provider_entities + ] + + @classmethod + def list_builtin_datasources(cls, tenant_id: str) -> Generator[DatasourcePluginProviderController, None, None]: + """ + list all the builtin datasources + """ + # get builtin datasources + yield from 
cls.list_datasource_providers(tenant_id) + @classmethod def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: """ diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 635748799b..05661a6cc8 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -10,8 +10,8 @@ class RerankingModelConfig(BaseModel): Reranking Model Config. """ - provider: str - model: str + reranking_provider_name: str + reranking_model_name: str class VectorSetting(BaseModel): diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 67d183c70d..9d34734af7 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -56,6 +56,8 @@ external_knowledge_info_fields = { doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String} +icon_info_fields = {"icon_type": fields.String, "icon": fields.String, "icon_background": fields.String} + dataset_detail_fields = { "id": fields.String, "name": fields.String, @@ -81,6 +83,10 @@ dataset_detail_fields = { "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), "doc_metadata": fields.List(fields.Nested(doc_metadata_fields)), "built_in_field_enabled": fields.Boolean, + "pipeline_id": fields.String, + "runtime_mode": fields.String, + "chunk_structure": fields.String, + "icon_info": fields.Nested(icon_info_fields), } dataset_query_detail_fields = { diff --git a/api/fields/rag_pipeline_fields.py b/api/fields/rag_pipeline_fields.py new file mode 100644 index 0000000000..0bb74e3259 --- /dev/null +++ b/api/fields/rag_pipeline_fields.py @@ -0,0 +1,163 @@ +from flask_restful import fields # type: ignore + +from fields.workflow_fields import workflow_partial_fields +from libs.helper import AppIconUrlField, TimestampField + +pipeline_detail_kernel_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, +} + +related_app_list = { + "data": fields.List(fields.Nested(pipeline_detail_kernel_fields)), + "total": fields.Integer, +} + +app_detail_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon": fields.String, + "icon_background": fields.String, + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "tracing": fields.Raw, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, +} + + +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} + +app_partial_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String(attribute="desc_or_prompt"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, + "tags": fields.List(fields.Nested(tag_fields)), +} + + +app_pagination_fields = { + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": 
fields.List(fields.Nested(app_partial_fields), attribute="items"), +} + +template_fields = { + "name": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "mode": fields.String, +} + +template_list_fields = { + "data": fields.List(fields.Nested(template_fields)), +} + +site_fields = { + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "description": fields.String, + "default_language": fields.String, + "chat_color_theme": fields.String, + "chat_color_theme_inverted": fields.Boolean, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "app_base_url": fields.String, + "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, +} + +deleted_tool_fields = { + "type": fields.String, + "tool_name": fields.String, + "provider_id": fields.String, +} + +app_detail_fields_with_site = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "enable_site": fields.Boolean, + "enable_api": fields.Boolean, + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "site": fields.Nested(site_fields), + "api_base_url": fields.String, + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, +} + + +app_site_fields = { + "app_id": fields.String, + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "default_language": fields.String, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, +} + +leaked_dependency_fields = {"type": fields.String, "value": fields.Raw, "current_identifier": fields.String} + +pipeline_import_fields = { + "id": fields.String, + "status": fields.String, + "pipeline_id": fields.String, + "current_dsl_version": fields.String, + "imported_dsl_version": fields.String, + "error": fields.String, +} + +pipeline_import_check_dependencies_fields = { + "leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)), +} diff --git a/api/models/dataset.py b/api/models/dataset.py index 3c44fb4b45..6d23973bba 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -63,6 +63,10 @@ class Dataset(db.Model): # type: ignore[name-defined] collection_binding_id = db.Column(StringUUID, nullable=True) retrieval_model = db.Column(JSONB, nullable=True) built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + icon_info = db.Column(JSONB, nullable=True) + runtime_mode = 
db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying")) + pipeline_id = db.Column(StringUUID, nullable=True) + chunk_structure = db.Column(db.String(255), nullable=True) @property def dataset_keyword_table(self): diff --git a/api/models/tools.py b/api/models/tools.py index aef1490729..6d08ba61aa 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -51,6 +51,40 @@ class BuiltinToolProvider(Base): return cast(dict, json.loads(self.encrypted_credentials)) +class BuiltinDatasourceProvider(Base): + """ + This table stores the datasource provider information for built-in datasources for each tenant. + """ + + __tablename__ = "tool_builtin_datasource_providers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_builtin_datasource_provider_pkey"), + # one tenant can only have one tool provider with the same name + db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_datasource_provider"), + ) + + # id of the tool provider + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # id of the tenant + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) + # who created this tool provider + user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # name of the tool provider + provider: Mapped[str] = mapped_column(db.String(256), nullable=False) + # credential of the tool provider + encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) + created_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + updated_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + + @property + def credentials(self) -> dict: + return cast(dict, json.loads(self.encrypted_credentials)) + + class ApiToolProvider(Base): """ The table stores the api providers. 
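[Reviewer note, not part of the patch] A minimal usage sketch for the new BuiltinDatasourceProvider model added above. It assumes the usual Flask-SQLAlchemy session exposed by extensions.ext_database and that encrypted_credentials holds a JSON string, as the credentials property implies; the helper name is purely illustrative.

from extensions.ext_database import db
from models.tools import BuiltinDatasourceProvider


def get_datasource_credentials(tenant_id: str, provider: str) -> dict:
    # At most one row per (tenant_id, provider), per the unique constraint on the table.
    record = (
        db.session.query(BuiltinDatasourceProvider)
        .filter_by(tenant_id=tenant_id, provider=provider)
        .first()
    )
    # The credentials property json.loads() the encrypted_credentials column.
    return record.credentials if record else {}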
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index af1c1028cf..42748dbf96 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -52,6 +52,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( RetrievalModel, SegmentUpdateArgs, ) +from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity from services.errors.account import InvalidActionError, NoPermissionError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError @@ -59,6 +60,7 @@ from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureModel, FeatureService +from services.rag_pipeline.rag_pipeline_dsl_service import ImportMode, RagPipelineDslService, RagPipelineImportInfo from services.tag_service import TagService from services.vector_service import VectorService from tasks.batch_clean_document_task import batch_clean_document_task @@ -235,6 +237,63 @@ class DatasetService: db.session.commit() return dataset + @staticmethod + def create_empty_rag_pipeline_dataset( + tenant_id: str, + rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, + ): + # check if dataset name already exists + if Dataset.query.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first(): + raise DatasetNameDuplicateError( + f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." + ) + + dataset = Dataset( + name=rag_pipeline_dataset_create_entity.name, + description=rag_pipeline_dataset_create_entity.description, + permission=rag_pipeline_dataset_create_entity.permission, + provider="vendor", + runtime_mode="rag_pipeline", + icon_info=rag_pipeline_dataset_create_entity.icon_info, + ) + db.session.add(dataset) + db.session.commit() + return dataset + + @staticmethod + def create_rag_pipeline_dataset( + tenant_id: str, + rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, + ): + # check if dataset name already exists + if Dataset.query.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first(): + raise DatasetNameDuplicateError( + f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." 
+ ) + + dataset = Dataset( + name=rag_pipeline_dataset_create_entity.name, + description=rag_pipeline_dataset_create_entity.description, + permission=rag_pipeline_dataset_create_entity.permission, + provider="vendor", + runtime_mode="rag_pipeline", + icon_info=rag_pipeline_dataset_create_entity.icon_info, + ) + + if rag_pipeline_dataset_create_entity.yaml_content: + rag_pipeline_import_info: RagPipelineImportInfo = RagPipelineDslService.import_rag_pipeline( + current_user, ImportMode.YAML_CONTENT, rag_pipeline_dataset_create_entity.yaml_content, dataset + ) + return { + "id": rag_pipeline_import_info.id, + "dataset_id": dataset.id, + "pipeline_id": rag_pipeline_import_info.pipeline_id, + "status": rag_pipeline_import_info.status, + "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version, + "current_dsl_version": rag_pipeline_import_info.current_dsl_version, + "error": rag_pipeline_import_info.error, + } + @staticmethod def get_dataset(dataset_id) -> Optional[Dataset]: dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index d59d47bbce..5f581f1360 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional from pydantic import BaseModel @@ -14,3 +14,100 @@ class PipelineTemplateInfoEntity(BaseModel): name: str description: str icon_info: IconInfo + + +class RagPipelineDatasetCreateEntity(BaseModel): + name: str + description: str + icon_info: IconInfo + permission: str + partial_member_list: list[str] + yaml_content: str + + +class RerankingModelConfig(BaseModel): + """ + Reranking Model Config. + """ + + reranking_provider_name: str + reranking_model_name: str + + +class VectorSetting(BaseModel): + """ + Vector Setting. + """ + + vector_weight: float + embedding_provider_name: str + embedding_model_name: str + + +class KeywordSetting(BaseModel): + """ + Keyword Setting. + """ + + keyword_weight: float + + +class WeightedScoreConfig(BaseModel): + """ + Weighted score Config. + """ + + vector_setting: VectorSetting + keyword_setting: KeywordSetting + + +class EmbeddingSetting(BaseModel): + """ + Embedding Setting. + """ + + embedding_provider_name: str + embedding_model_name: str + + +class EconomySetting(BaseModel): + """ + Economy Setting. + """ + + keyword_number: int + + +class RetrievalSetting(BaseModel): + """ + Retrieval Setting. + """ + + search_method: Literal["semantic_search", "keyword_search", "hybrid_search"] + top_k: int + score_threshold: Optional[float] = 0.5 + score_threshold_enabled: bool = False + reranking_mode: str = "reranking_model" + reranking_enable: bool = True + reranking_model: Optional[RerankingModelConfig] = None + weights: Optional[WeightedScoreConfig] = None + + +class IndexMethod(BaseModel): + """ + Knowledge Index Setting. + """ + + indexing_technique: Literal["high_quality", "economy"] + embedding_setting: EmbeddingSetting + economy_setting: EconomySetting + + +class KnowledgeConfiguration(BaseModel): + """ + Knowledge Configuration. 
+ """ + + chunk_structure: str + index_method: IndexMethod + retrieval_setting: RetrievalSetting diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 79f6e79cf5..1e6447d80f 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1,28 +1,45 @@ import json +import threading import time from collections.abc import Callable, Generator, Sequence from datetime import UTC, datetime from typing import Any, Literal, Optional +from uuid import uuid4 from flask_login import current_user from sqlalchemy import select from sqlalchemy.orm import Session +import contexts from configs import dify_config +from core.model_runtime.utils.encoders import jsonable_encoder +from core.repository.repository_factory import RepositoryFactory +from core.repository.workflow_node_execution_repository import OrderConfig from core.variables.variables import Variable +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base.node import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.event.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore -from models.workflow import Workflow, WorkflowNodeExecution, WorkflowType +from models.enums import CreatedByRole, WorkflowRunTriggeredFrom +from models.workflow import ( + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, + WorkflowType, +) from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.errors.app import WorkflowHashNotEqualError -from services.errors.workflow_service import DraftWorkflowDeletionError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory @@ -180,7 +197,6 @@ class RagPipelineService: *, pipeline: Pipeline, graph: dict, - features: dict, unique_hash: Optional[str], account: Account, environment_variables: Sequence[Variable], @@ -197,9 +213,6 @@ class RagPipelineService: if workflow and workflow.unique_hash != unique_hash: raise WorkflowHashNotEqualError() - # validate features structure - self.validate_features_structure(pipeline=pipeline, features=features) - # create draft workflow if not found if not workflow: workflow = Workflow( @@ -208,7 +221,6 @@ class RagPipelineService: type=WorkflowType.RAG_PIPELINE.value, version="draft", graph=json.dumps(graph), - features=json.dumps(features), created_by=account.id, environment_variables=environment_variables, conversation_variables=conversation_variables, @@ -218,7 +230,6 @@ class RagPipelineService: # update draft workflow if found else: workflow.graph = json.dumps(graph) - workflow.features = json.dumps(features) workflow.updated_by = account.id workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) workflow.environment_variables = environment_variables @@ 
-227,8 +238,8 @@ class RagPipelineService: # commit db session changes db.session.commit() - # trigger app workflow events - app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow) + # trigger workflow events TODO + # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow) # return draft workflow return workflow @@ -269,8 +280,8 @@ class RagPipelineService: # commit db session changes session.add(workflow) - # trigger app workflow events - app_published_workflow_was_updated.send(pipeline, published_workflow=workflow) + # trigger app workflow events TODO + # app_published_workflow_was_updated.send(pipeline, published_workflow=workflow) # return new workflow return workflow @@ -508,46 +519,6 @@ class RagPipelineService: return workflow_node_execution - def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: - """ - Basic mode of chatbot app(expert mode) to workflow - Completion App to Workflow App - - :param app_model: App instance - :param account: Account instance - :param args: dict - :return: - """ - # chatbot convert to workflow mode - workflow_converter = WorkflowConverter() - - if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: - raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") - - # convert to workflow - new_app: App = workflow_converter.convert_to_workflow( - app_model=app_model, - account=account, - name=args.get("name", "Default Name"), - icon_type=args.get("icon_type", "emoji"), - icon=args.get("icon", "🤖"), - icon_background=args.get("icon_background", "#FFEAD5"), - ) - - return new_app - - def validate_features_structure(self, app_model: App, features: dict) -> dict: - if app_model.mode == AppMode.ADVANCED_CHAT.value: - return AdvancedChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=features, only_structure_validate=True - ) - elif app_model.mode == AppMode.WORKFLOW.value: - return WorkflowAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=features, only_structure_validate=True - ) - else: - raise ValueError(f"Invalid app mode: {app_model.mode}") - def update_workflow( self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict ) -> Optional[Workflow]: @@ -578,38 +549,6 @@ class RagPipelineService: return workflow - def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool: - """ - Delete a workflow - - :param session: SQLAlchemy database session - :param workflow_id: Workflow ID - :param tenant_id: Tenant ID - :return: True if successful - :raises: ValueError if workflow not found - :raises: WorkflowInUseError if workflow is in use - :raises: DraftWorkflowDeletionError if workflow is a draft version - """ - stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) - workflow = session.scalar(stmt) - - if not workflow: - raise ValueError(f"Workflow with ID {workflow_id} not found") - - # Check if workflow is a draft version - if workflow.version == "draft": - raise DraftWorkflowDeletionError("Cannot delete draft workflow versions") - - # Check if this workflow is currently referenced by an app - stmt = select(App).where(App.workflow_id == workflow_id) - app = session.scalar(stmt) - if app: - # Cannot delete a workflow that's currently in use by an app - raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'") - - session.delete(workflow) - return True - def 
get_second_step_parameters(self, pipeline: Pipeline, datasource_provider: str) -> dict: """ Get second step parameters of rag pipeline @@ -627,3 +566,101 @@ class RagPipelineService: datasource_provider_variables = pipeline_variables.get(datasource_provider, []) shared_variables = pipeline_variables.get("shared", []) return datasource_provider_variables + shared_variables + + def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination: + """ + Get debug workflow run list + Only return triggered_from == debugging + + :param app_model: app model + :param args: request args + """ + limit = int(args.get("limit", 20)) + + base_query = db.session.query(WorkflowRun).filter( + WorkflowRun.tenant_id == pipeline.tenant_id, + WorkflowRun.app_id == pipeline.id, + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, + ) + + if args.get("last_id"): + last_workflow_run = base_query.filter( + WorkflowRun.id == args.get("last_id"), + ).first() + + if not last_workflow_run: + raise ValueError("Last workflow run not exists") + + workflow_runs = ( + base_query.filter( + WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id + ) + .order_by(WorkflowRun.created_at.desc()) + .limit(limit) + .all() + ) + else: + workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() + + has_more = False + if len(workflow_runs) == limit: + current_page_first_workflow_run = workflow_runs[-1] + rest_count = base_query.filter( + WorkflowRun.created_at < current_page_first_workflow_run.created_at, + WorkflowRun.id != current_page_first_workflow_run.id, + ).count() + + if rest_count > 0: + has_more = True + + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) + + def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]: + """ + Get workflow run detail + + :param app_model: app model + :param run_id: workflow run id + """ + workflow_run = ( + db.session.query(WorkflowRun) + .filter( + WorkflowRun.tenant_id == pipeline.tenant_id, + WorkflowRun.app_id == pipeline.id, + WorkflowRun.id == run_id, + ) + .first() + ) + + return workflow_run + + def get_rag_pipeline_workflow_run_node_executions( + self, + pipeline: Pipeline, + run_id: str, + ) -> list[WorkflowNodeExecution]: + """ + Get workflow run node execution list + """ + workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id) + + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + if not workflow_run: + return [] + + # Use the repository to get the node executions + repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": pipeline.tenant_id, + "app_id": pipeline.id, + "session_factory": db.session.get_bind(), + } + ) + + # Use the repository to get the node executions with ordering + order_config = OrderConfig(order_by=["index"], order_direction="desc") + node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) + + return list(node_executions) diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py new file mode 100644 index 0000000000..80e7c6af0b --- /dev/null +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -0,0 +1,841 @@ +import base64 +import hashlib +import logging +import uuid +from collections.abc import Mapping +from enum import StrEnum +from 
typing import Optional +from urllib.parse import urlparse +from uuid import uuid4 + +import yaml # type: ignore +from Crypto.Cipher import AES +from Crypto.Util.Padding import pad, unpad +from packaging import version +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.helper import ssrf_proxy +from core.model_runtime.utils.encoders import jsonable_encoder +from core.plugin.entities.plugin import PluginDependency +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from core.workflow.nodes.llm.entities import LLMNodeData +from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData +from core.workflow.nodes.tool.entities import ToolNodeData +from extensions.ext_redis import redis_client +from factories import variable_factory +from models import Account +from models.dataset import Dataset, Pipeline +from models.workflow import Workflow +from services.dataset_service import DatasetCollectionBindingService +from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration +from services.plugin.dependencies_analysis import DependenciesAnalysisService +from services.rag_pipeline.rag_pipeline import RagPipelineService + +logger = logging.getLogger(__name__) + +IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" +CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" +IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes +DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB +CURRENT_DSL_VERSION = "0.1.0" + + +class ImportMode(StrEnum): + YAML_CONTENT = "yaml-content" + YAML_URL = "yaml-url" + + +class ImportStatus(StrEnum): + COMPLETED = "completed" + COMPLETED_WITH_WARNINGS = "completed-with-warnings" + PENDING = "pending" + FAILED = "failed" + + +class RagPipelineImportInfo(BaseModel): + id: str + status: ImportStatus + pipeline_id: Optional[str] = None + current_dsl_version: str = CURRENT_DSL_VERSION + imported_dsl_version: str = "" + error: str = "" + dataset_id: Optional[str] = None + + +class CheckDependenciesResult(BaseModel): + leaked_dependencies: list[PluginDependency] = Field(default_factory=list) + + +def _check_version_compatibility(imported_version: str) -> ImportStatus: + """Determine import status based on version comparison""" + try: + current_ver = version.parse(CURRENT_DSL_VERSION) + imported_ver = version.parse(imported_version) + except version.InvalidVersion: + return ImportStatus.FAILED + + # If imported version is newer than current, always return PENDING + if imported_ver > current_ver: + return ImportStatus.PENDING + + # If imported version is older than current's major, return PENDING + if imported_ver.major < current_ver.major: + return ImportStatus.PENDING + + # If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS + if imported_ver.minor < current_ver.minor: + return ImportStatus.COMPLETED_WITH_WARNINGS + + # If imported version equals or is older than current's micro, return COMPLETED + return ImportStatus.COMPLETED + + +class RagPipelinePendingData(BaseModel): + import_mode: str + yaml_content: str + name: str | None + description: str | None + icon_type: str | None + icon: str | None + icon_background: str | None + pipeline_id: str | None + + +class CheckDependenciesPendingData(BaseModel): + dependencies: list[PluginDependency] + pipeline_id: str | None + + 
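[Reviewer note, not part of the patch] A comment-only sketch of how the DSL version gate in _check_version_compatibility above behaves with CURRENT_DSL_VERSION = "0.1.0"; the version strings are illustrative.

# _check_version_compatibility("0.2.0")  -> ImportStatus.PENDING                   (imported version newer than current)
# _check_version_compatibility("0.1.0")  -> ImportStatus.COMPLETED                 (same version)
# _check_version_compatibility("0.0.9")  -> ImportStatus.COMPLETED_WITH_WARNINGS   (older minor version)
# _check_version_compatibility("oops")   -> ImportStatus.FAILED                    (unparsable version string)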
+class RagPipelineDslService: + def __init__(self, session: Session): + self._session = session + + def import_rag_pipeline( + self, + *, + account: Account, + import_mode: str, + yaml_content: Optional[str] = None, + yaml_url: Optional[str] = None, + pipeline_id: Optional[str] = None, + dataset: Optional[Dataset] = None, + ) -> RagPipelineImportInfo: + """Import an app from YAML content or URL.""" + import_id = str(uuid.uuid4()) + + # Validate import mode + try: + mode = ImportMode(import_mode) + except ValueError: + raise ValueError(f"Invalid import_mode: {import_mode}") + + # Get YAML content + content: str = "" + if mode == ImportMode.YAML_URL: + if not yaml_url: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_url is required when import_mode is yaml-url", + ) + try: + parsed_url = urlparse(yaml_url) + if ( + parsed_url.scheme == "https" + and parsed_url.netloc == "github.com" + and parsed_url.path.endswith((".yml", ".yaml")) + ): + yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com") + yaml_url = yaml_url.replace("/blob/", "/") + response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) + response.raise_for_status() + content = response.content.decode() + + if len(content) > DSL_MAX_SIZE: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="File size exceeds the limit of 10MB", + ) + + if not content: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Empty content from url", + ) + except Exception as e: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Error fetching YAML from URL: {str(e)}", + ) + elif mode == ImportMode.YAML_CONTENT: + if not yaml_content: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_content is required when import_mode is yaml-content", + ) + content = yaml_content + + # Process YAML content + try: + # Parse YAML to validate format + data = yaml.safe_load(content) + if not isinstance(data, dict): + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid YAML format: content must be a mapping", + ) + + # Validate and fix DSL version + if not data.get("version"): + data["version"] = "0.1.0" + if not data.get("kind") or data.get("kind") != "rag-pipeline": + data["kind"] = "rag-pipeline" + + imported_version = data.get("version", "0.1.0") + # check if imported_version is a float-like string + if not isinstance(imported_version, str): + raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}") + status = _check_version_compatibility(imported_version) + + # Extract app data + pipeline_data = data.get("pipeline") + if not pipeline_data: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Missing pipeline data in YAML content", + ) + + # If app_id is provided, check if it exists + pipeline = None + if pipeline_id: + stmt = select(Pipeline).where( + Pipeline.id == pipeline_id, + Pipeline.tenant_id == account.current_tenant_id, + ) + pipeline = self._session.scalar(stmt) + + if not pipeline: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Pipeline not found", + ) + + # If major version mismatch, store import info in Redis + if status == ImportStatus.PENDING: + pending_data = RagPipelinePendingData( + import_mode=import_mode, + yaml_content=content, + pipeline_id=pipeline_id, + ) + 
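+ # The pending import is parked in Redis under IMPORT_INFO_REDIS_KEY_PREFIX + import_id
+ # ("app_import_info:<uuid>") for IMPORT_INFO_REDIS_EXPIRY seconds (10 minutes); confirm_import()
+ # later reads it back with redis_client.get(), re-runs the create/update flow, and deletes the key.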
redis_client.setex( + f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}", + IMPORT_INFO_REDIS_EXPIRY, + pending_data.model_dump_json(), + ) + + return RagPipelineImportInfo( + id=import_id, + status=status, + pipeline_id=pipeline_id, + imported_dsl_version=imported_version, + ) + + # Extract dependencies + dependencies = data.get("dependencies", []) + check_dependencies_pending_data = None + if dependencies: + check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies] + + # Create or update app + pipeline = self._create_or_update_pipeline( + pipeline=pipeline, + data=data, + account=account, + dependencies=check_dependencies_pending_data, + ) + # create dataset + name = pipeline.name + description = pipeline.description + icon_type = data.get("rag_pipeline", {}).get("icon_type") + icon = data.get("rag_pipeline", {}).get("icon") + icon_background = data.get("rag_pipeline", {}).get("icon_background") + icon_url = data.get("rag_pipeline", {}).get("icon_url") + workflow = data.get("workflow", {}) + graph = workflow.get("graph", {}) + nodes = graph.get("nodes", []) + dataset_id = None + for node in nodes: + if node.get("data", {}).get("type") == "knowledge_index": + knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) + knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + if not dataset: + dataset = Dataset( + tenant_id=account.current_tenant_id, + name=name, + description=description, + icon_info={ + "type": icon_type, + "icon": icon, + "background": icon_background, + "url": icon_url, + }, + indexing_technique=knowledge_configuration.index_method.indexing_technique, + created_by=account.id, + retrieval_model=knowledge_configuration.retrieval_setting.model_dump(), + runtime_mode="rag_pipeline", + chunk_structure=knowledge_configuration.chunk_structure, + ) + else: + dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique + dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() + dataset.runtime_mode = "rag_pipeline" + dataset.chunk_structure = knowledge_configuration.chunk_structure + if knowledge_configuration.index_method.indexing_technique == "high_quality": + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore + knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore + ) + dataset_collection_binding_id = dataset_collection_binding.id + dataset.collection_binding_id = dataset_collection_binding_id + dataset.embedding_model = ( + knowledge_configuration.index_method.embedding_setting.embedding_model_name + ) + dataset.embedding_model_provider = ( + knowledge_configuration.index_method.embedding_setting.embedding_provider_name + ) + elif knowledge_configuration.index_method.indexing_technique == "economy": + dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number + dataset.pipeline_id = pipeline.id + self._session.add(dataset) + self._session.commit() + dataset_id = dataset.id + if not dataset_id: + raise ValueError("DSL is not valid, please check the Knowledge Index node.") + + return RagPipelineImportInfo( + id=import_id, + status=status, + pipeline_id=pipeline.id, + dataset_id=dataset_id, + imported_dsl_version=imported_version, + ) + + except yaml.YAMLError as e: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + 
error=f"Invalid YAML format: {str(e)}", + ) + + except Exception as e: + logger.exception("Failed to import app") + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def confirm_import(self, *, import_id: str, account: Account) -> RagPipelineImportInfo: + """ + Confirm an import that requires confirmation + """ + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + pending_data = redis_client.get(redis_key) + + if not pending_data: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Import information expired or does not exist", + ) + + try: + if not isinstance(pending_data, str | bytes): + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid import information", + ) + pending_data = RagPipelinePendingData.model_validate_json(pending_data) + data = yaml.safe_load(pending_data.yaml_content) + + pipeline = None + if pending_data.pipeline_id: + stmt = select(Pipeline).where( + Pipeline.id == pending_data.pipeline_id, + Pipeline.tenant_id == account.current_tenant_id, + ) + pipeline = self._session.scalar(stmt) + + # Create or update app + pipeline = self._create_or_update_pipeline( + pipeline=pipeline, + data=data, + account=account, + ) + + # create dataset + name = pipeline.name + description = pipeline.description + icon_type = data.get("rag_pipeline", {}).get("icon_type") + icon = data.get("rag_pipeline", {}).get("icon") + icon_background = data.get("rag_pipeline", {}).get("icon_background") + icon_url = data.get("rag_pipeline", {}).get("icon_url") + workflow = data.get("workflow", {}) + graph = workflow.get("graph", {}) + nodes = graph.get("nodes", []) + dataset_id = None + for node in nodes: + if node.get("data", {}).get("type") == "knowledge_index": + knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) + knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + if not dataset: + dataset = Dataset( + tenant_id=account.current_tenant_id, + name=name, + description=description, + icon_info={ + "type": icon_type, + "icon": icon, + "background": icon_background, + "url": icon_url, + }, + indexing_technique=knowledge_configuration.index_method.indexing_technique, + created_by=account.id, + retrieval_model=knowledge_configuration.retrieval_setting.model_dump(), + runtime_mode="rag_pipeline", + chunk_structure=knowledge_configuration.chunk_structure, + ) + else: + dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique + dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() + dataset.runtime_mode = "rag_pipeline" + dataset.chunk_structure = knowledge_configuration.chunk_structure + if knowledge_configuration.index_method.indexing_technique == "high_quality": + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore + knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore + ) + dataset_collection_binding_id = dataset_collection_binding.id + dataset.collection_binding_id = dataset_collection_binding_id + dataset.embedding_model = ( + knowledge_configuration.index_method.embedding_setting.embedding_model_name + ) + dataset.embedding_model_provider = ( + knowledge_configuration.index_method.embedding_setting.embedding_provider_name + ) + elif knowledge_configuration.index_method.indexing_technique 
== "economy": + dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number + dataset.pipeline_id = pipeline.id + self._session.add(dataset) + self._session.commit() + dataset_id = dataset.id + if not dataset_id: + raise ValueError("DSL is not valid, please check the Knowledge Index node.") + + # Delete import info from Redis + redis_client.delete(redis_key) + + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.COMPLETED, + pipeline_id=pipeline.id, + dataset_id=dataset_id, + current_dsl_version=CURRENT_DSL_VERSION, + imported_dsl_version=data.get("version", "0.1.0"), + ) + + except Exception as e: + logger.exception("Error confirming import") + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def check_dependencies( + self, + *, + pipeline: Pipeline, + ) -> CheckDependenciesResult: + """Check dependencies""" + # Get dependencies from Redis + redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}" + dependencies = redis_client.get(redis_key) + if not dependencies: + return CheckDependenciesResult() + + # Extract dependencies + dependencies = CheckDependenciesPendingData.model_validate_json(dependencies) + + # Get leaked dependencies + leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies( + tenant_id=pipeline.tenant_id, dependencies=dependencies.dependencies + ) + return CheckDependenciesResult( + leaked_dependencies=leaked_dependencies, + ) + + def _create_or_update_pipeline( + self, + *, + pipeline: Optional[Pipeline], + data: dict, + account: Account, + dependencies: Optional[list[PluginDependency]] = None, + ) -> Pipeline: + """Create a new app or update an existing one.""" + pipeline_data = data.get("pipeline", {}) + pipeline_mode = pipeline_data.get("mode") + if not pipeline_mode: + raise ValueError("loss pipeline mode") + # Set icon type + icon_type_value = icon_type or pipeline_data.get("icon_type") + if icon_type_value in ["emoji", "link"]: + icon_type = icon_type_value + else: + icon_type = "emoji" + icon = icon or str(pipeline_data.get("icon", "")) + + if pipeline: + # Update existing pipeline + pipeline.name = pipeline_data.get("name", pipeline.name) + pipeline.description = pipeline_data.get("description", pipeline.description) + pipeline.icon_type = icon_type + pipeline.icon = icon + pipeline.icon_background = pipeline_data.get("icon_background", pipeline.icon_background) + pipeline.updated_by = account.id + else: + if account.current_tenant_id is None: + raise ValueError("Current tenant is not set") + + # Create new app + pipeline = Pipeline() + pipeline.id = str(uuid4()) + pipeline.tenant_id = account.current_tenant_id + pipeline.mode = pipeline_mode.value + pipeline.name = pipeline_data.get("name", "") + pipeline.description = pipeline_data.get("description", "") + pipeline.icon_type = icon_type + pipeline.icon = icon + pipeline.icon_background = pipeline_data.get("icon_background", "#FFFFFF") + pipeline.enable_site = True + pipeline.enable_api = True + pipeline.use_icon_as_answer_icon = pipeline_data.get("use_icon_as_answer_icon", False) + pipeline.created_by = account.id + pipeline.updated_by = account.id + + self._session.add(pipeline) + self._session.commit() + # save dependencies + if dependencies: + redis_client.setex( + f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}", + IMPORT_INFO_REDIS_EXPIRY, + CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(), + ) + + # Initialize pipeline based 
on mode + workflow_data = data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise ValueError("Missing workflow data for rag pipeline") + + environment_variables_list = workflow_data.get("environment_variables", []) + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables", []) + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) + rag_pipeline_variables = [ + variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list + ] + + rag_pipeline_service = RagPipelineService() + current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + if current_draft_workflow: + unique_hash = current_draft_workflow.unique_hash + else: + unique_hash = None + graph = workflow_data.get("graph", {}) + for node in graph.get("nodes", []): + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + decrypted_id + for dataset_id in dataset_ids + if ( + decrypted_id := self.decrypt_dataset_id( + encrypted_data=dataset_id, + tenant_id=pipeline.tenant_id, + ) + ) + ] + rag_pipeline_service.sync_draft_workflow( + pipeline=pipeline, + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), + unique_hash=unique_hash, + account=account, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + + return pipeline + + @classmethod + def export_rag_pipeline_dsl(cls, pipeline: Pipeline, include_secret: bool = False) -> str: + """ + Export pipeline + :param pipeline: Pipeline instance + :param include_secret: Whether include secret variable + :return: + """ + export_data = { + "version": CURRENT_DSL_VERSION, + "kind": "rag_pipeline", + "pipeline": { + "name": pipeline.name, + "mode": pipeline.mode, + "icon": "🤖" if pipeline.icon_type == "image" else pipeline.icon, + "icon_background": "#FFEAD5" if pipeline.icon_type == "image" else pipeline.icon_background, + "description": pipeline.description, + "use_icon_as_answer_icon": pipeline.use_icon_as_answer_icon, + }, + } + + cls._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret) + + return yaml.dump(export_data, allow_unicode=True) # type: ignore + + @classmethod + def _append_workflow_export_data(cls, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None: + """ + Append workflow export data + :param export_data: export data + :param pipeline: Pipeline instance + """ + rag_pipeline_service = RagPipelineService() + workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Missing draft workflow configuration, please check.") + + workflow_dict = workflow.to_dict(include_secret=include_secret) + for node in workflow_dict.get("graph", {}).get("nodes", []): + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id) + for dataset_id in dataset_ids + ] + export_data["workflow"] = workflow_dict + 
dependencies = cls._extract_dependencies_from_workflow(workflow) + export_data["dependencies"] = [ + jsonable_encoder(d.model_dump()) + for d in DependenciesAnalysisService.generate_dependencies( + tenant_id=pipeline.tenant_id, dependencies=dependencies + ) + ] + + @classmethod + def _append_model_config_export_data(cls, export_data: dict, pipeline: Pipeline) -> None: + """ + Append model config export data + :param export_data: export data + :param pipeline: Pipeline instance + """ + app_model_config = pipeline.app_model_config + if not app_model_config: + raise ValueError("Missing app configuration, please check.") + + export_data["model_config"] = app_model_config.to_dict() + dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict()) + export_data["dependencies"] = [ + jsonable_encoder(d.model_dump()) + for d in DependenciesAnalysisService.generate_dependencies( + tenant_id=pipeline.tenant_id, dependencies=dependencies + ) + ] + + @classmethod + def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]: + """ + Extract dependencies from workflow + :param workflow: Workflow instance + :return: dependencies list format like ["langgenius/google"] + """ + graph = workflow.graph_dict + dependencies = cls._extract_dependencies_from_workflow_graph(graph) + return dependencies + + @classmethod + def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]: + """ + Extract dependencies from workflow graph + :param graph: Workflow graph + :return: dependencies list format like ["langgenius/google"] + """ + dependencies = [] + for node in graph.get("nodes", []): + try: + typ = node.get("data", {}).get("type") + match typ: + case NodeType.TOOL.value: + tool_entity = ToolNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), + ) + case NodeType.LLM.value: + llm_entity = LLMNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), + ) + case NodeType.QUESTION_CLASSIFIER.value: + question_classifier_entity = QuestionClassifierNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + question_classifier_entity.model.provider + ), + ) + case NodeType.PARAMETER_EXTRACTOR.value: + parameter_extractor_entity = ParameterExtractorNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + parameter_extractor_entity.model.provider + ), + ) + case NodeType.KNOWLEDGE_RETRIEVAL.value: + knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"]) + if knowledge_retrieval_entity.retrieval_mode == "multiple": + if knowledge_retrieval_entity.multiple_retrieval_config: + if ( + knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode + == "reranking_model" + ): + if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model: + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider + ), + ) + elif ( + knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode + == "weighted_score" + ): + if knowledge_retrieval_entity.multiple_retrieval_config.weights: + vector_setting = ( + knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting + ) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + 
vector_setting.embedding_provider_name + ), + ) + elif knowledge_retrieval_entity.retrieval_mode == "single": + model_config = knowledge_retrieval_entity.single_retrieval_config + if model_config: + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + model_config.model.provider + ), + ) + case _: + # TODO: Handle default case or unknown node types + pass + except Exception as e: + logger.exception("Error extracting node dependency", exc_info=e) + + return dependencies + + @classmethod + def _extract_dependencies_from_model_config(cls, model_config: Mapping) -> list[str]: + """ + Extract dependencies from model config + :param model_config: model config dict + :return: dependencies list format like ["langgenius/google"] + """ + dependencies = [] + + try: + # completion model + model_dict = model_config.get("model", {}) + if model_dict: + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", "")) + ) + + # reranking model + dataset_configs = model_config.get("dataset_configs", {}) + if dataset_configs: + for dataset_config in dataset_configs.get("datasets", {}).get("datasets", []): + if dataset_config.get("reranking_model"): + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + dataset_config.get("reranking_model", {}) + .get("reranking_provider_name", {}) + .get("provider") + ) + ) + + # tools + agent_configs = model_config.get("agent_mode", {}) + if agent_configs: + for agent_config in agent_configs.get("tools", []): + dependencies.append( + DependenciesAnalysisService.analyze_tool_dependency(agent_config.get("provider_id")) + ) + + except Exception as e: + logger.exception("Error extracting model config dependency", exc_info=e) + + return dependencies + + @classmethod + def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]: + """ + Returns the leaked dependencies in current workspace + """ + dependencies = [PluginDependency(**dep) for dep in dsl_dependencies] + if not dependencies: + return [] + + return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies) + + @staticmethod + def _generate_aes_key(tenant_id: str) -> bytes: + """Generate AES key based on tenant_id""" + return hashlib.sha256(tenant_id.encode()).digest() + + @classmethod + def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str: + """Encrypt dataset_id using AES-CBC mode""" + key = cls._generate_aes_key(tenant_id) + iv = key[:16] + cipher = AES.new(key, AES.MODE_CBC, iv) + ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size)) + return base64.b64encode(ct_bytes).decode() + + @classmethod + def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None: + """AES decryption""" + try: + key = cls._generate_aes_key(tenant_id) + iv = key[:16] + cipher = AES.new(key, AES.MODE_CBC, iv) + pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size) + return pt.decode() + except Exception: + return None diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 3ccd14415d..daf3773309 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -5,6 +5,7 @@ from pathlib import Path from sqlalchemy.orm import Session from configs import dify_config +from core.datasource.entities.api_entities import DatasourceProviderApiEntity from 
core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import GenericProviderID, ToolProviderID @@ -16,7 +17,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ProviderConfigEncrypter from extensions.ext_database import db -from models.tools import BuiltinToolProvider +from models.tools import BuiltinDatasourceProvider, BuiltinToolProvider from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) @@ -286,6 +287,67 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) + @staticmethod + def list_rag_pipeline_datasources(tenant_id: str) -> list[DatasourceProviderApiEntity]: + """ + list rag pipeline datasources + """ + # get all builtin providers + datasource_provider_controllers = ToolManager.list_datasource_providers(tenant_id) + + with db.session.no_autoflush: + # get all user added providers + db_providers: list[BuiltinDatasourceProvider] = ( + db.session.query(BuiltinDatasourceProvider) + .filter(BuiltinDatasourceProvider.tenant_id == tenant_id) + .all() + or [] + ) + + # find provider + def find_provider(provider): + return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) + + result: list[DatasourceProviderApiEntity] = [] + + for provider_controller in datasource_provider_controllers: + try: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore + data=provider_controller, + name_func=lambda x: x.identity.name, + ): + continue + + # convert provider controller to user provider + user_builtin_provider = ToolTransformService.builtin_datasource_provider_to_user_provider( + provider_controller=provider_controller, + db_provider=find_provider(provider_controller.entity.identity.name), + decrypt_credentials=True, + ) + + # add icon + ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) + + datasources = provider_controller.get_datasources() + for datasource in datasources or []: + user_builtin_provider.datasources.append( + ToolTransformService.convert_datasource_entity_to_api_entity( + tenant_id=tenant_id, + datasource=datasource, + credentials=user_builtin_provider.original_credentials, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) + + result.append(user_builtin_provider) + except Exception as e: + raise e + + return result + @staticmethod def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: try: diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 367121125b..e0c1ce7217 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,6 +5,11 @@ from typing import Optional, Union, cast from yarl import URL from configs import dify_config +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.api_entities import DatasourceApiEntity, DatasourceProviderApiEntity +from core.datasource.entities.datasource_entities import DatasourceProviderType from 
core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -21,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool -from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider +from models.tools import ApiToolProvider, BuiltinDatasourceProvider, BuiltinToolProvider, WorkflowToolProvider logger = logging.getLogger(__name__) @@ -140,6 +145,64 @@ class ToolTransformService: return result + @classmethod + def builtin_datasource_provider_to_user_provider( + cls, + provider_controller: DatasourcePluginProviderController, + db_provider: Optional[BuiltinDatasourceProvider], + decrypt_credentials: bool = True, + ) -> DatasourceProviderApiEntity: + """ + convert provider controller to user provider + """ + result = DatasourceProviderApiEntity( + id=provider_controller.entity.identity.name, + author=provider_controller.entity.identity.author, + name=provider_controller.entity.identity.name, + description=provider_controller.entity.identity.description, + icon=provider_controller.entity.identity.icon, + label=provider_controller.entity.identity.label, + type=DatasourceProviderType.RAG_PIPELINE, + masked_credentials={}, + is_team_authorization=False, + plugin_id=provider_controller.plugin_id, + plugin_unique_identifier=provider_controller.plugin_unique_identifier, + datasources=[], + ) + + # get credentials schema + schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} + + for name, value in schema.items(): + if result.masked_credentials: + result.masked_credentials[name] = "" + + # check if the provider need credentials + if not provider_controller.need_credentials: + result.is_team_authorization = True + result.allow_delete = False + elif db_provider: + result.is_team_authorization = True + + if decrypt_credentials: + credentials = db_provider.credentials + + # init tool configuration + tool_configuration = ProviderConfigEncrypter( + tenant_id=db_provider.tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) + # decrypt the credentials and mask the credentials + decrypted_credentials = tool_configuration.decrypt(data=credentials) + masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) + + result.masked_credentials = masked_credentials + result.original_credentials = decrypted_credentials + + return result + @staticmethod def api_provider_to_controller( db_provider: ApiToolProvider, @@ -304,3 +367,48 @@ class ToolTransformService: parameters=tool.parameters, labels=labels or [], ) + + @staticmethod + def convert_datasource_entity_to_api_entity( + datasource: DatasourcePlugin, + tenant_id: str, + credentials: dict | None = None, + labels: list[str] | None = None, + ) -> DatasourceApiEntity: + """ + convert tool to user tool + """ + # fork tool runtime + datasource = datasource.fork_datasource_runtime( + runtime=DatasourceRuntime( + credentials=credentials or {}, + tenant_id=tenant_id, + ) + ) + + # get datasource parameters + parameters = datasource.entity.parameters or [] + # get 
datasource runtime parameters + runtime_parameters = datasource.get_runtime_parameters() + # override parameters + current_parameters = parameters.copy() + for runtime_parameter in runtime_parameters: + found = False + for index, parameter in enumerate(current_parameters): + if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: + current_parameters[index] = runtime_parameter + found = True + break + + if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + current_parameters.append(runtime_parameter) + + return DatasourceApiEntity( + author=datasource.entity.identity.author, + name=datasource.entity.identity.name, + label=datasource.entity.identity.label, + description=datasource.entity.description.human if datasource.entity.description else I18nObject(en_US=""), + output_schema=datasource.entity.output_schema, + parameters=current_parameters, + labels=labels or [], + ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 63e3791147..c0f4578474 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -203,7 +203,6 @@ class WorkflowService: type=draft_workflow.type, version=str(datetime.now(UTC).replace(tzinfo=None)), graph=draft_workflow.graph, - features=draft_workflow.features, created_by=account.id, environment_variables=draft_workflow.environment_variables, conversation_variables=draft_workflow.conversation_variables,
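A small round-trip sketch (illustrative only; the tenant and dataset ids below are made-up examples) of the dataset-id obfuscation helpers defined on RagPipelineDslService earlier in this patch, which the DSL export and import paths apply to knowledge-retrieval dataset_ids:

    # encrypt_dataset_id / decrypt_dataset_id round trip; the AES key is derived from the tenant id,
    # so a ciphertext exported from one workspace cannot be resolved in another.
    tenant_id = "11111111-2222-3333-4444-555555555555"   # example value
    dataset_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"  # example value

    token = RagPipelineDslService.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=tenant_id)
    assert RagPipelineDslService.decrypt_dataset_id(encrypted_data=token, tenant_id=tenant_id) == dataset_id

    # With a different tenant id the unpad/UTF-8 decode step normally fails and decrypt_dataset_id
    # returns None, which the import path then drops from the node's dataset_ids list.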