diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index 8be82c1e40..0442a121c0 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -6,7 +6,7 @@ on: - "main" - "deploy/dev" - "deploy/enterprise" - - "feat/rag-pipeline" + - "deploy/rag-dev" tags: - "*" diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index 47ca03c2eb..0d99c6fa58 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -4,7 +4,7 @@ on: workflow_run: workflows: ["Build and Push API & Web"] branches: - - "deploy/dev" + - "deploy/rag-dev" types: - completed @@ -12,12 +12,13 @@ jobs: deploy: runs-on: ubuntu-latest if: | - github.event.workflow_run.conclusion == 'success' + github.event.workflow_run.conclusion == 'success' && + github.event.workflow_run.head_branch == 'deploy/rag-dev' steps: - name: Deploy to server uses: appleboy/ssh-action@v0.1.8 with: - host: ${{ secrets.SSH_HOST }} + host: ${{ secrets.RAG_SSH_HOST }} username: ${{ secrets.SSH_USER }} key: ${{ secrets.SSH_PRIVATE_KEY }} script: | diff --git a/api/app.py b/api/app.py index 4f393f6c20..e0a903b10d 100644 --- a/api/app.py +++ b/api/app.py @@ -1,4 +1,3 @@ -import os import sys @@ -17,20 +16,20 @@ else: # It seems that JetBrains Python debugger does not work well with gevent, # so we need to disable gevent in debug mode. # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent. - if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}: - from gevent import monkey + # if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}: + # from gevent import monkey + # + # # gevent + # monkey.patch_all() + # + # from grpc.experimental import gevent as grpc_gevent # type: ignore + # + # # grpc gevent + # grpc_gevent.init_gevent() - # gevent - monkey.patch_all() - - from grpc.experimental import gevent as grpc_gevent # type: ignore - - # grpc gevent - grpc_gevent.init_gevent() - - import psycogreen.gevent # type: ignore - - psycogreen.gevent.patch_psycopg() + # import psycogreen.gevent # type: ignore + # + # psycogreen.gevent.patch_psycopg() from app_factory import create_app diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 18ef1ed45b..3e57f24ff5 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -222,11 +222,28 @@ class HostedFetchAppTemplateConfig(BaseSettings): ) +class HostedFetchPipelineTemplateConfig(BaseSettings): + """ + Configuration for fetching pipeline templates + """ + + HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field( + description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,", + default="database", + ) + + HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field( + description="Domain for fetching remote pipeline templates", + default="https://tmpl.dify.ai", + ) + + class HostedServiceConfig( # place the configs in alphabet order HostedAnthropicConfig, HostedAzureOpenAiConfig, HostedFetchAppTemplateConfig, + HostedFetchPipelineTemplateConfig, HostedMinmaxConfig, HostedOpenAiConfig, HostedSparkConfig, diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index ae41a2c03a..8be769e798 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -3,6 +3,7 @@ from threading import Lock from typing import TYPE_CHECKING from 
contexts.wrapper import RecyclableContextVar +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController if TYPE_CHECKING: from core.model_runtime.entities.model_entities import AIModelEntity @@ -33,3 +34,11 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar( ContextVar("plugin_model_schemas") ) + +datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( + RecyclableContextVar(ContextVar("datasource_plugin_providers")) +) + +datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( + ContextVar("datasource_plugin_providers_lock") +) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index dbdcdc46ce..312a870472 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -76,7 +76,6 @@ from .billing import billing, compliance # Import datasets controllers from .datasets import ( - data_source, datasets, datasets_document, datasets_segments, @@ -85,6 +84,14 @@ from .datasets import ( metadata, website, ) +from .datasets.rag_pipeline import ( + datasource_auth, + datasource_content_preview, + rag_pipeline, + rag_pipeline_datasets, + rag_pipeline_import, + rag_pipeline_workflow, +) # Import explore controllers from .explore import ( diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 1611214cb3..3cb32a28ba 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -283,6 +283,15 @@ class DatasetApi(Resource): location="json", help="Invalid external knowledge api id.", ) + + parser.add_argument( + "icon_info", + type=dict, + required=False, + nullable=True, + location="json", + help="Invalid icon info.", + ) args = parser.parse_args() data = request.get_json() diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index b2fcf3ce7b..35d912bfcc 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,3 +1,4 @@ +import json import logging from argparse import ArgumentTypeError from datetime import UTC, datetime @@ -51,6 +52,7 @@ from fields.document_fields import ( ) from libs.login import login_required from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile +from models.dataset import DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig @@ -661,7 +663,7 @@ class DocumentDetailApi(DocumentResource): response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} elif metadata == "without": dataset_process_rules = DatasetService.get_process_rules(dataset_id) - document_process_rules = document.dataset_process_rule.to_dict() + document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} data_source_info = document.data_source_detail_dict response = { "id": document.id, @@ -1028,6 +1030,41 @@ class WebsiteDocumentSyncApi(DocumentResource): return {"result": "success"}, 200 +class DocumentPipelineExecutionLogApi(DocumentResource): + @setup_required + @login_required + @account_initialization_required + def 
get(self, dataset_id, document_id): + dataset_id = str(dataset_id) + document_id = str(document_id) + + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + document = DocumentService.get_document(dataset.id, document_id) + if not document: + raise NotFound("Document not found.") + log = ( + db.session.query(DocumentPipelineExecutionLog) + .filter_by(document_id=document_id) + .order_by(DocumentPipelineExecutionLog.created_at.desc()) + .first() + ) + if not log: + return { + "datasource_info": None, + "datasource_type": None, + "input_data": None, + "datasource_node_id": None, + }, 200 + return { + "datasource_info": json.loads(log.datasource_info), + "datasource_type": log.datasource_type, + "input_data": log.input_data, + "datasource_node_id": log.datasource_node_id, + }, 200 + + api.add_resource(GetProcessRuleApi, "/datasets/process-rule") api.add_resource(DatasetDocumentListApi, "/datasets//documents") api.add_resource(DatasetInitApi, "/datasets/init") @@ -1050,3 +1087,6 @@ api.add_resource(DocumentRetryApi, "/datasets//retry") api.add_resource(DocumentRenameApi, "/datasets//documents//rename") api.add_resource(WebsiteDocumentSyncApi, "/datasets//documents//website-sync") +api.add_resource( + DocumentPipelineExecutionLogApi, "/datasets//documents//pipeline-execution-log" +) diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 2f00a84de6..b80c4402cd 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -101,3 +101,9 @@ class ChildChunkDeleteIndexError(BaseHTTPException): error_code = "child_chunk_delete_index_error" description = "Delete child chunk index failed: {message}" code = 500 + + +class PipelineNotFoundError(BaseHTTPException): + error_code = "pipeline_not_found" + description = "Pipeline not found." 
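# --- Illustrative sketch (not part of the patch) ----------------------------
# The DocumentPipelineExecutionLogApi registered above returns the most recent
# DocumentPipelineExecutionLog row for a document, or all-None fields when no
# pipeline run has been recorded. A minimal console client call might look like
# the snippet below. The base URL, auth header, and path parameters are
# assumptions for illustration only; the registered route in this diff has its
# URL converters stripped by formatting, so the real placeholders may differ.
import requests

CONSOLE_API = "https://cloud.example.com/console/api"  # hypothetical base URL
DATASET_ID = "dataset-uuid"                            # hypothetical
DOCUMENT_ID = "document-uuid"                          # hypothetical

resp = requests.get(
    f"{CONSOLE_API}/datasets/{DATASET_ID}/documents/{DOCUMENT_ID}/pipeline-execution-log",
    headers={"Authorization": "Bearer <console-session-token>"},  # hypothetical auth scheme
    timeout=10,
)
resp.raise_for_status()
log = resp.json()
# Fields mirror the handler's response shape: datasource_info (JSON parsed by the
# handler), datasource_type, input_data and datasource_node_id; all None if no log exists.
print(log["datasource_type"], log["datasource_node_id"])
# ----------------------------------------------------------------------------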
+ code = 404 diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py new file mode 100644 index 0000000000..124d45f513 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -0,0 +1,197 @@ +from flask import redirect, request +from flask_login import current_user # type: ignore +from flask_restful import ( # type: ignore + Resource, # type: ignore + reqparse, +) +from werkzeug.exceptions import Forbidden, NotFound + +from configs import dify_config +from controllers.console import api +from controllers.console.wraps import ( + account_initialization_required, + setup_required, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.plugin.impl.oauth import OAuthHandler +from extensions.ext_database import db +from libs.login import login_required +from models.oauth import DatasourceOauthParamConfig, DatasourceProvider +from services.datasource_provider_service import DatasourceProviderService + + +class DatasourcePluginOauthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() + provider = args["provider"] + plugin_id = args["plugin_id"] + # Check user role first + if not current_user.is_editor: + raise Forbidden() + # get all plugin oauth configs + plugin_oauth_config = ( + db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() + ) + if not plugin_oauth_config: + raise NotFound() + oauth_handler = OAuthHandler() + redirect_url = ( + f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}" + ) + system_credentials = plugin_oauth_config.system_credentials + if system_credentials: + system_credentials["redirect_url"] = redirect_url + response = oauth_handler.get_authorization_url( + current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials + ) + return response.model_dump() + + +class DatasourceOauthCallback(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() + provider = args["provider"] + plugin_id = args["plugin_id"] + oauth_handler = OAuthHandler() + plugin_oauth_config = ( + db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() + ) + if not plugin_oauth_config: + raise NotFound() + credentials = oauth_handler.get_credentials( + current_user.current_tenant.id, + current_user.id, + plugin_id, + provider, + system_credentials=plugin_oauth_config.system_credentials, + request=request, + ) + datasource_provider = DatasourceProvider( + plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials + ) + db.session.add(datasource_provider) + db.session.commit() + return redirect(f"{dify_config.CONSOLE_WEB_URL}") + + +class DatasourceAuth(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): 
+ if not current_user.is_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + datasource_provider_service = DatasourceProviderService() + + try: + datasource_provider_service.datasource_provider_credentials_validate( + tenant_id=current_user.current_tenant_id, + provider=args["provider"], + plugin_id=args["plugin_id"], + credentials=args["credentials"], + name=args["name"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"}, 201 + + @setup_required + @login_required + @account_initialization_required + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() + datasource_provider_service = DatasourceProviderService() + datasources = datasource_provider_service.get_datasource_credentials( + tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"] + ) + return {"result": datasources}, 200 + + +class DatasourceAuthUpdateDeleteApi(Resource): + @setup_required + @login_required + @account_initialization_required + def delete(self, auth_id: str): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() + if not current_user.is_editor: + raise Forbidden() + datasource_provider_service = DatasourceProviderService() + datasource_provider_service.remove_datasource_credentials( + tenant_id=current_user.current_tenant_id, + auth_id=auth_id, + provider=args["provider"], + plugin_id=args["plugin_id"], + ) + return {"result": "success"}, 200 + + @setup_required + @login_required + @account_initialization_required + def patch(self, auth_id: str): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + if not current_user.is_editor: + raise Forbidden() + try: + datasource_provider_service = DatasourceProviderService() + datasource_provider_service.update_datasource_credentials( + tenant_id=current_user.current_tenant_id, + auth_id=auth_id, + provider=args["provider"], + plugin_id=args["plugin_id"], + credentials=args["credentials"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"}, 201 + + +# Import Rag Pipeline +api.add_resource( + DatasourcePluginOauthApi, + "/oauth/plugin/datasource", +) +api.add_resource( + DatasourceOauthCallback, + "/oauth/plugin/datasource/callback", +) +api.add_resource( + DatasourceAuth, + "/auth/plugin/datasource", +) + +api.add_resource( + DatasourceAuthUpdateDeleteApi, + "/auth/plugin/datasource/", +) diff --git 
a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py new file mode 100644 index 0000000000..bb02c659b8 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -0,0 +1,55 @@ +from flask_restful import ( # type: ignore + Resource, # type: ignore + reqparse, +) +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.datasets.wraps import get_rag_pipeline +from controllers.console.wraps import account_initialization_required, setup_required +from libs.login import current_user, login_required +from models import Account +from models.dataset import Pipeline +from services.rag_pipeline.rag_pipeline import RagPipelineService + + +class DataSourceContentPreviewApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run datasource content preview + """ + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + if inputs is None: + raise ValueError("missing inputs") + datasource_type = args.get("datasource_type") + if datasource_type is None: + raise ValueError("missing datasource_type") + + rag_pipeline_service = RagPipelineService() + preview_content = rag_pipeline_service.run_datasource_node_preview( + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=datasource_type, + is_published=True, + ) + return preview_content, 200 + + +api.add_resource( + DataSourceContentPreviewApi, + "/rag/pipelines//workflows/published/datasource/nodes//preview", +) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py new file mode 100644 index 0000000000..93976bd6f5 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -0,0 +1,162 @@ +import logging + +from flask import request +from flask_restful import Resource, reqparse +from sqlalchemy.orm import Session + +from controllers.console import api +from controllers.console.wraps import ( + account_initialization_required, + enterprise_license_required, + setup_required, +) +from extensions.ext_database import db +from libs.login import login_required +from models.dataset import PipelineCustomizedTemplate +from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity +from services.rag_pipeline.rag_pipeline import RagPipelineService + +logger = logging.getLogger(__name__) + + +def _validate_name(name): + if not name or len(name) < 1 or len(name) > 40: + raise ValueError("Name must be between 1 to 40 characters.") + return name + + +def _validate_description_length(description): + if len(description) > 400: + raise ValueError("Description cannot exceed 400 characters.") + return description + + +class PipelineTemplateListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def get(self): + type = request.args.get("type", default="built-in", type=str) + language = request.args.get("language", default="en-US", type=str) + # get 
pipeline templates + pipeline_templates = RagPipelineService.get_pipeline_templates(type, language) + return pipeline_templates, 200 + + +class PipelineTemplateDetailApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def get(self, template_id: str): + type = request.args.get("type", default="built-in", type=str) + rag_pipeline_service = RagPipelineService() + pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type) + return pipeline_template, 200 + + +class CustomizedPipelineTemplateApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def patch(self, template_id: str): + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) + parser.add_argument( + "icon_info", + type=dict, + location="json", + nullable=True, + ) + args = parser.parse_args() + pipeline_template_info = PipelineTemplateInfoEntity(**args) + RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) + return 200 + + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def delete(self, template_id: str): + RagPipelineService.delete_customized_pipeline_template(template_id) + return 200 + + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def post(self, template_id: str): + with Session(db.engine) as session: + template = ( + session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + ) + if not template: + raise ValueError("Customized pipeline template not found.") + + return {"data": template.yaml_content}, 200 + + +class PublishCustomizedPipelineTemplateApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def post(self, pipeline_id: str): + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) + parser.add_argument( + "icon_info", + type=dict, + location="json", + nullable=True, + ) + args = parser.parse_args() + rag_pipeline_service = RagPipelineService() + rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args) + return {"result": "success"} + + +api.add_resource( + PipelineTemplateListApi, + "/rag/pipeline/templates", +) +api.add_resource( + PipelineTemplateDetailApi, + "/rag/pipeline/templates/", +) +api.add_resource( + CustomizedPipelineTemplateApi, + "/rag/pipeline/customized/templates/", +) +api.add_resource( + PublishCustomizedPipelineTemplateApi, + "/rag/pipelines//customized/publish", +) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py new file mode 100644 index 0000000000..f502157eda --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -0,0 +1,171 @@ +from flask_login import current_user # type: ignore # type: ignore +from flask_restful import Resource, marshal, reqparse # type: ignore +from 
werkzeug.exceptions import Forbidden + +import services +from controllers.console import api +from controllers.console.datasets.error import DatasetNameDuplicateError +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_rate_limit_check, + setup_required, +) +from fields.dataset_fields import dataset_detail_fields +from libs.login import login_required +from models.dataset import DatasetPermissionEnum +from services.dataset_service import DatasetPermissionService, DatasetService +from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity +from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService + + +def _validate_name(name): + if not name or len(name) < 1 or len(name) > 40: + raise ValueError("Name must be between 1 to 40 characters.") + return name + + +def _validate_description_length(description): + if len(description) > 400: + raise ValueError("Description cannot exceed 400 characters.") + return description + + +class CreateRagPipelineDatasetApi(Resource): + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") + def post(self): + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) + + parser.add_argument( + "icon_info", + type=dict, + nullable=True, + required=False, + default={}, + ) + + parser.add_argument( + "permission", + type=str, + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + nullable=True, + required=False, + default=DatasetPermissionEnum.ONLY_ME, + ) + + parser.add_argument( + "partial_member_list", + type=list, + nullable=True, + required=False, + default=[], + ) + + parser.add_argument( + "yaml_content", + type=str, + nullable=False, + required=True, + help="yaml_content is required.", + ) + + args = parser.parse_args() + + # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator + if not current_user.is_dataset_editor: + raise Forbidden() + rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args) + try: + import_info = RagPipelineDslService.create_rag_pipeline_dataset( + tenant_id=current_user.current_tenant_id, + rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, + ) + if rag_pipeline_dataset_create_entity.permission == "partial_members": + DatasetPermissionService.update_partial_member_list( + current_user.current_tenant_id, + import_info["dataset_id"], + rag_pipeline_dataset_create_entity.partial_member_list, + ) + except services.errors.dataset.DatasetNameDuplicateError: + raise DatasetNameDuplicateError() + + return import_info, 201 + + +class CreateEmptyRagPipelineDatasetApi(Resource): + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") + def post(self): + # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator + if not current_user.is_dataset_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. 
Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) + + parser.add_argument( + "icon_info", + type=dict, + nullable=True, + required=False, + default={}, + ) + + parser.add_argument( + "permission", + type=str, + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + nullable=True, + required=False, + default=DatasetPermissionEnum.ONLY_ME, + ) + + parser.add_argument( + "partial_member_list", + type=list, + nullable=True, + required=False, + default=[], + ) + + args = parser.parse_args() + dataset = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=current_user.current_tenant_id, + rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args), + ) + return marshal(dataset, dataset_detail_fields), 201 + + +api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset") +api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py new file mode 100644 index 0000000000..e5c211be93 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -0,0 +1,146 @@ +from typing import cast + +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.datasets.wraps import get_rag_pipeline +from controllers.console.wraps import ( + account_initialization_required, + setup_required, +) +from extensions.ext_database import db +from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields +from libs.login import login_required +from models import Account +from models.dataset import Pipeline +from services.app_dsl_service import ImportStatus +from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService + + +class RagPipelineImportApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(pipeline_import_fields) + def post(self): + # Check user role first + if not current_user.is_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("mode", type=str, required=True, location="json") + parser.add_argument("yaml_content", type=str, location="json") + parser.add_argument("yaml_url", type=str, location="json") + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") + parser.add_argument("pipeline_id", type=str, location="json") + args = parser.parse_args() + + # Create service with session + with Session(db.engine) as session: + import_service = RagPipelineDslService(session) + # Import app + account = cast(Account, current_user) + result = import_service.import_rag_pipeline( + account=account, + import_mode=args["mode"], + yaml_content=args.get("yaml_content"), + yaml_url=args.get("yaml_url"), + pipeline_id=args.get("pipeline_id"), + ) + session.commit() + + # Return appropriate status code based on result + 
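# --- Illustrative sketch (not part of the patch) ----------------------------
# RagPipelineImportApi maps the import result onto HTTP codes, as the lines just
# below show: FAILED -> 400, PENDING -> 202 (to be finalized via the /confirm
# endpoint registered further down), anything else -> 200. A minimal client for
# the yaml-content mode could look like this; the base URL, token, DSL content,
# mode name ("yaml-content", mirroring the app DSL import service) and the "id"
# field on the result are assumptions for illustration only.
import requests

CONSOLE_API = "https://cloud.example.com/console/api"  # hypothetical base URL
HEADERS = {"Authorization": "Bearer <console-session-token>"}  # hypothetical auth
payload = {
    "mode": "yaml-content",                     # assumed mode name
    "yaml_content": "kind: rag_pipeline\n...",  # hypothetical DSL document
}
resp = requests.post(f"{CONSOLE_API}/rag/pipelines/imports", json=payload, headers=HEADERS, timeout=30)
result = resp.json()
if resp.status_code == 202:
    # pending imports are confirmed through /rag/pipelines/imports/<import_id>/confirm
    import_id = result.get("id")  # field name assumed from pipeline_import_fields
    requests.post(f"{CONSOLE_API}/rag/pipelines/imports/{import_id}/confirm", headers=HEADERS, timeout=30)
# ----------------------------------------------------------------------------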
status = result.status + if status == ImportStatus.FAILED.value: + return result.model_dump(mode="json"), 400 + elif status == ImportStatus.PENDING.value: + return result.model_dump(mode="json"), 202 + return result.model_dump(mode="json"), 200 + + +class RagPipelineImportConfirmApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(pipeline_import_fields) + def post(self, import_id): + # Check user role first + if not current_user.is_editor: + raise Forbidden() + + # Create service with session + with Session(db.engine) as session: + import_service = RagPipelineDslService(session) + # Confirm import + account = cast(Account, current_user) + result = import_service.confirm_import(import_id=import_id, account=account) + session.commit() + + # Return appropriate status code based on result + if result.status == ImportStatus.FAILED.value: + return result.model_dump(mode="json"), 400 + return result.model_dump(mode="json"), 200 + + +class RagPipelineImportCheckDependenciesApi(Resource): + @setup_required + @login_required + @get_rag_pipeline + @account_initialization_required + @marshal_with(pipeline_import_check_dependencies_fields) + def get(self, pipeline: Pipeline): + if not current_user.is_editor: + raise Forbidden() + + with Session(db.engine) as session: + import_service = RagPipelineDslService(session) + result = import_service.check_dependencies(pipeline=pipeline) + + return result.model_dump(mode="json"), 200 + + +class RagPipelineExportApi(Resource): + @setup_required + @login_required + @get_rag_pipeline + @account_initialization_required + def get(self, pipeline: Pipeline): + if not current_user.is_editor: + raise Forbidden() + + # Add include_secret params + parser = reqparse.RequestParser() + parser.add_argument("include_secret", type=bool, default=False, location="args") + args = parser.parse_args() + + with Session(db.engine) as session: + export_service = RagPipelineDslService(session) + result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"]) + + return {"data": result}, 200 + + +# Import Rag Pipeline +api.add_resource( + RagPipelineImportApi, + "/rag/pipelines/imports", +) +api.add_resource( + RagPipelineImportConfirmApi, + "/rag/pipelines/imports//confirm", +) +api.add_resource( + RagPipelineImportCheckDependenciesApi, + "/rag/pipelines/imports//check-dependencies", +) +api.add_resource( + RagPipelineExportApi, + "/rag/pipelines//exports", +) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py new file mode 100644 index 0000000000..8bae9dc466 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -0,0 +1,1070 @@ +import json +import logging +from typing import cast + +from flask import abort, request +from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore +from flask_restful.inputs import int_range # type: ignore +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import services +from configs import dify_config +from controllers.console import api +from controllers.console.app.error import ( + ConversationCompletedError, + DraftWorkflowNotExist, + DraftWorkflowNotSync, +) +from controllers.console.datasets.wraps import get_rag_pipeline +from controllers.console.wraps import ( + account_initialization_required, + setup_required, +) 
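# --- Illustrative sketch (not part of the patch) ----------------------------
# RagPipelineExportApi above serialises a pipeline back to DSL through
# RagPipelineDslService.export_rag_pipeline_dsl and wraps it as {"data": ...}.
# Example client; the base URL, token, and pipeline id are assumptions for
# illustration only.
import requests

CONSOLE_API = "https://cloud.example.com/console/api"  # hypothetical base URL
PIPELINE_ID = "pipeline-uuid"                          # hypothetical
resp = requests.get(
    f"{CONSOLE_API}/rag/pipelines/{PIPELINE_ID}/exports",
    headers={"Authorization": "Bearer <console-session-token>"},  # hypothetical auth
    timeout=30,
)
resp.raise_for_status()
# include_secret defaults to False in the handler, so credentials are stripped
# from the exported DSL unless explicitly requested in the query string.
dsl_text = resp.json()["data"]
with open("rag_pipeline.yml", "w", encoding="utf-8") as f:
    f.write(dsl_text)
# ----------------------------------------------------------------------------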
+from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.pipeline.pipeline_generator import PipelineGenerator +from core.app.entities.app_invoke_entities import InvokeFrom +from core.model_runtime.utils.encoders import jsonable_encoder +from extensions.ext_database import db +from factories import variable_factory +from fields.workflow_fields import workflow_fields, workflow_pagination_fields +from fields.workflow_run_fields import ( + workflow_run_detail_fields, + workflow_run_node_execution_fields, + workflow_run_node_execution_list_fields, + workflow_run_pagination_fields, +) +from libs import helper +from libs.helper import TimestampField, uuid_value +from libs.login import current_user, login_required +from models.account import Account +from models.dataset import Pipeline +from models.model import EndUser +from services.errors.app import WorkflowHashNotEqualError +from services.errors.llm import InvokeRateLimitError +from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService +from services.rag_pipeline.rag_pipeline import RagPipelineService +from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService + +logger = logging.getLogger(__name__) + + +class DraftRagPipelineApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_fields) + def get(self, pipeline: Pipeline): + """ + Get draft rag pipeline's workflow + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + # fetch draft workflow by app_model + rag_pipeline_service = RagPipelineService() + workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + + if not workflow: + raise DraftWorkflowNotExist() + + # return workflow, if not found, return None (initiate graph by frontend) + return workflow + + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline): + """ + Sync draft workflow + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + content_type = request.headers.get("Content-Type", "") + + if "application/json" in content_type: + parser = reqparse.RequestParser() + parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") + parser.add_argument("hash", type=str, required=False, location="json") + parser.add_argument("environment_variables", type=list, required=False, location="json") + parser.add_argument("conversation_variables", type=list, required=False, location="json") + parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json") + args = parser.parse_args() + elif "text/plain" in content_type: + try: + data = json.loads(request.data.decode("utf-8")) + if "graph" not in data or "features" not in data: + raise ValueError("graph or features not found in data") + + if not isinstance(data.get("graph"), dict): + raise ValueError("graph is not a dict") + + args = { + "graph": data.get("graph"), + "features": data.get("features"), + "hash": data.get("hash"), + "environment_variables": data.get("environment_variables"), + "conversation_variables": data.get("conversation_variables"), + "rag_pipeline_variables": data.get("rag_pipeline_variables"), + } + except json.JSONDecodeError: + 
return {"message": "Invalid JSON data"}, 400 + else: + abort(415) + + if not isinstance(current_user, Account): + raise Forbidden() + + try: + environment_variables_list = args.get("environment_variables") or [] + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = args.get("conversation_variables") or [] + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + rag_pipeline_service = RagPipelineService() + workflow = rag_pipeline_service.sync_draft_workflow( + pipeline=pipeline, + graph=args["graph"], + unique_hash=args.get("hash"), + account=current_user, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + rag_pipeline_variables=args.get("rag_pipeline_variables") or [], + ) + except WorkflowHashNotEqualError: + raise DraftWorkflowNotSync() + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + +class RagPipelineDraftRunIterationNodeApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run draft workflow iteration node + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, location="json") + args = parser.parse_args() + + try: + response = PipelineGenerateService.generate_single_iteration( + pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True + ) + + return helper.compact_generate_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except ValueError as e: + raise e + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + +class RagPipelineDraftRunLoopNodeApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run draft workflow loop node + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, location="json") + args = parser.parse_args() + + try: + response = PipelineGenerateService.generate_single_loop( + pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True + ) + + return helper.compact_generate_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except ValueError as e: + raise e + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + +class DraftRagPipelineRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def 
post(self, pipeline: Pipeline): + """ + Run draft workflow + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + parser.add_argument("datasource_info_list", type=list, required=True, location="json") + parser.add_argument("start_node_id", type=str, required=True, location="json") + args = parser.parse_args() + + try: + response = PipelineGenerateService.generate( + pipeline=pipeline, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER, + streaming=True, + ) + + return helper.compact_generate_response(response) + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) + + +class PublishedRagPipelineRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline): + """ + Run published workflow + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + parser.add_argument("datasource_info_list", type=list, required=True, location="json") + parser.add_argument("start_node_id", type=str, required=True, location="json") + parser.add_argument("is_preview", type=bool, required=True, location="json", default=False) + parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming") + args = parser.parse_args() + + streaming = args["response_mode"] == "streaming" + + try: + response = PipelineGenerateService.generate( + pipeline=pipeline, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED, + streaming=streaming, + ) + + return helper.compact_generate_response(response) + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) + + +# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource): +# @setup_required +# @login_required +# @account_initialization_required +# @get_rag_pipeline +# def post(self, pipeline: Pipeline, node_id: str): +# """ +# Run rag pipeline datasource +# """ +# # The role of the current user in the ta table must be admin, owner, or editor +# if not current_user.is_editor: +# raise Forbidden() +# +# if not isinstance(current_user, Account): +# raise Forbidden() +# +# parser = reqparse.RequestParser() +# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") +# parser.add_argument("datasource_type", type=str, required=True, location="json") +# args = parser.parse_args() +# +# job_id = args.get("job_id") +# if job_id == None: +# raise ValueError("missing job_id") +# datasource_type = args.get("datasource_type") +# if datasource_type == None: +# raise ValueError("missing datasource_type") +# +# rag_pipeline_service = RagPipelineService() +# result = rag_pipeline_service.run_datasource_workflow_node_status( +# pipeline=pipeline, +# node_id=node_id, +# job_id=job_id, +# 
account=current_user, +# datasource_type=datasource_type, +# is_published=True +# ) +# +# return result + + +# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource): +# @setup_required +# @login_required +# @account_initialization_required +# @get_rag_pipeline +# def post(self, pipeline: Pipeline, node_id: str): +# """ +# Run rag pipeline datasource +# """ +# # The role of the current user in the ta table must be admin, owner, or editor +# if not current_user.is_editor: +# raise Forbidden() +# +# if not isinstance(current_user, Account): +# raise Forbidden() +# +# parser = reqparse.RequestParser() +# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") +# parser.add_argument("datasource_type", type=str, required=True, location="json") +# args = parser.parse_args() +# +# job_id = args.get("job_id") +# if job_id == None: +# raise ValueError("missing job_id") +# datasource_type = args.get("datasource_type") +# if datasource_type == None: +# raise ValueError("missing datasource_type") +# +# rag_pipeline_service = RagPipelineService() +# result = rag_pipeline_service.run_datasource_workflow_node_status( +# pipeline=pipeline, +# node_id=node_id, +# job_id=job_id, +# account=current_user, +# datasource_type=datasource_type, +# is_published=False +# ) +# +# return result +# + + +class RagPipelinePublishedDatasourceNodeRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run rag pipeline datasource + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + if inputs is None: + raise ValueError("missing inputs") + datasource_type = args.get("datasource_type") + if datasource_type is None: + raise ValueError("missing datasource_type") + + rag_pipeline_service = RagPipelineService() + return helper.compact_generate_response( + PipelineGenerator.convert_to_event_stream( + rag_pipeline_service.run_datasource_workflow_node( + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=datasource_type, + is_published=False, + ) + ) + ) + + +class RagPipelineDraftDatasourceNodeRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run rag pipeline datasource + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + if inputs is None: + raise ValueError("missing inputs") + datasource_type = args.get("datasource_type") + if datasource_type is None: + raise ValueError("missing datasource_type") + + rag_pipeline_service = RagPipelineService() + return 
helper.compact_generate_response( + PipelineGenerator.convert_to_event_stream( + rag_pipeline_service.run_datasource_workflow_node( + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=datasource_type, + is_published=False, + ) + ) + ) + + +class RagPipelinePublishedNodeRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run rag pipeline datasource + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + if inputs == None: + raise ValueError("missing inputs") + + rag_pipeline_service = RagPipelineService() + workflow_node_execution = rag_pipeline_service.run_published_workflow_node( + pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user + ) + + return workflow_node_execution + + +class RagPipelineDraftNodeRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_run_node_execution_fields) + def post(self, pipeline: Pipeline, node_id: str): + """ + Run draft workflow node + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + if inputs == None: + raise ValueError("missing inputs") + + rag_pipeline_service = RagPipelineService() + workflow_node_execution = rag_pipeline_service.run_draft_workflow_node( + pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user + ) + + return workflow_node_execution + + +class RagPipelineTaskStopApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, task_id: str): + """ + Stop workflow task + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) + + return {"result": "success"} + + +class PublishedRagPipelineApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_fields) + def get(self, pipeline: Pipeline): + """ + Get published pipeline + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + if not pipeline.is_published: + return None + # fetch published workflow by pipeline + rag_pipeline_service = RagPipelineService() + workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline) + + # return workflow, if not found, return None + return workflow + + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline): + """ + Publish workflow + """ + # The role of the current user in the ta table must be admin, owner, or editor + 
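# --- Illustrative sketch (not part of the patch) ----------------------------
# Publishing (the handler continuing below) merges the pipeline into a fresh
# session, calls RagPipelineService.publish_workflow, marks the pipeline as
# published, records the new workflow_id, and commits. From a client's point of
# view it is a single POST; the base URL, token, and pipeline id here are
# assumptions for illustration only.
import requests

CONSOLE_API = "https://cloud.example.com/console/api"  # hypothetical base URL
PIPELINE_ID = "pipeline-uuid"                          # hypothetical
resp = requests.post(
    f"{CONSOLE_API}/rag/pipelines/{PIPELINE_ID}/workflows/publish",
    headers={"Authorization": "Bearer <console-session-token>"},  # hypothetical auth
    timeout=30,
)
resp.raise_for_status()
# On success the handler returns a success flag plus the publish timestamp,
# i.e. {"result": "success", "created_at": <timestamp>}.
print(resp.json())
# ----------------------------------------------------------------------------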
if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + rag_pipeline_service = RagPipelineService() + with Session(db.engine) as session: + pipeline = session.merge(pipeline) + workflow = rag_pipeline_service.publish_workflow( + session=session, + pipeline=pipeline, + account=current_user, + ) + pipeline.is_published = True + pipeline.workflow_id = workflow.id + session.add(pipeline) + workflow_created_at = TimestampField().format(workflow.created_at) + + session.commit() + + return { + "result": "success", + "created_at": workflow_created_at, + } + + +class DefaultRagPipelineBlockConfigsApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline): + """ + Get default block config + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + # Get default block configs + rag_pipeline_service = RagPipelineService() + return rag_pipeline_service.get_default_block_configs() + + +class DefaultRagPipelineBlockConfigApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline, block_type: str): + """ + Get default block config + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("q", type=str, location="args") + args = parser.parse_args() + + q = args.get("q") + + filters = None + if q: + try: + filters = json.loads(args.get("q", "")) + except json.JSONDecodeError: + raise ValueError("Invalid filters") + + # Get default block configs + rag_pipeline_service = RagPipelineService() + return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) + + +class RagPipelineConfigApi(Resource): + """Resource for rag pipeline configuration.""" + + @setup_required + @login_required + @account_initialization_required + def get(self, pipeline_id): + return { + "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, + } + + +class PublishedAllRagPipelineApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_pagination_fields) + def get(self, pipeline: Pipeline): + """ + Get published workflows + """ + if not current_user.is_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + parser.add_argument("user_id", type=str, required=False, location="args") + parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args") + args = parser.parse_args() + page = int(args.get("page", 1)) + limit = int(args.get("limit", 10)) + user_id = args.get("user_id") + named_only = args.get("named_only", False) + + if user_id: + if user_id != current_user.id: + raise Forbidden() + user_id = cast(str, user_id) + + rag_pipeline_service = RagPipelineService() + with Session(db.engine) as session: + workflows, has_more = rag_pipeline_service.get_all_published_workflow( + session=session, + pipeline=pipeline, + page=page, + limit=limit, + 
user_id=user_id, + named_only=named_only, + ) + + return { + "items": workflows, + "page": page, + "limit": limit, + "has_more": has_more, + } + + +class RagPipelineByIdApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_fields) + def patch(self, pipeline: Pipeline, workflow_id: str): + """ + Update workflow attributes + """ + # Check permission + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("marked_name", type=str, required=False, location="json") + parser.add_argument("marked_comment", type=str, required=False, location="json") + args = parser.parse_args() + + # Validate name and comment length + if args.marked_name and len(args.marked_name) > 20: + raise ValueError("Marked name cannot exceed 20 characters") + if args.marked_comment and len(args.marked_comment) > 100: + raise ValueError("Marked comment cannot exceed 100 characters") + args = parser.parse_args() + + # Prepare update data + update_data = {} + if args.get("marked_name") is not None: + update_data["marked_name"] = args["marked_name"] + if args.get("marked_comment") is not None: + update_data["marked_comment"] = args["marked_comment"] + + if not update_data: + return {"message": "No valid fields to update"}, 400 + + rag_pipeline_service = RagPipelineService() + + # Create a session and manage the transaction + with Session(db.engine, expire_on_commit=False) as session: + workflow = rag_pipeline_service.update_workflow( + session=session, + workflow_id=workflow_id, + tenant_id=pipeline.tenant_id, + account_id=current_user.id, + data=update_data, + ) + + if not workflow: + raise NotFound("Workflow not found") + + # Commit the transaction in the controller + session.commit() + + return workflow + + +class PublishedRagPipelineSecondStepApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline): + """ + Get second step parameters of rag pipeline + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("node_id", type=str, required=True, location="args") + args = parser.parse_args() + node_id = args.get("node_id") + if not node_id: + raise ValueError("Node ID is required") + rag_pipeline_service = RagPipelineService() + variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False) + return { + "variables": variables, + } + + +class PublishedRagPipelineFirstStepApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline): + """ + Get first step parameters of rag pipeline + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("node_id", type=str, required=True, location="args") + args = parser.parse_args() + node_id = args.get("node_id") + if not node_id: + raise ValueError("Node ID is required") + rag_pipeline_service = RagPipelineService() + variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False) + return { + "variables": variables, + } + + +class 
DraftRagPipelineFirstStepApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline): + """ + Get first step parameters of rag pipeline + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("node_id", type=str, required=True, location="args") + args = parser.parse_args() + node_id = args.get("node_id") + if not node_id: + raise ValueError("Node ID is required") + rag_pipeline_service = RagPipelineService() + variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True) + return { + "variables": variables, + } + + +class DraftRagPipelineSecondStepApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline): + """ + Get second step parameters of rag pipeline + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("node_id", type=str, required=True, location="args") + args = parser.parse_args() + node_id = args.get("node_id") + if not node_id: + raise ValueError("Node ID is required") + + rag_pipeline_service = RagPipelineService() + variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True) + return { + "variables": variables, + } + + +class RagPipelineWorkflowRunListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_run_pagination_fields) + def get(self, pipeline: Pipeline): + """ + Get workflow run list + """ + parser = reqparse.RequestParser() + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + args = parser.parse_args() + + rag_pipeline_service = RagPipelineService() + result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args) + + return result + + +class RagPipelineWorkflowRunDetailApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_run_detail_fields) + def get(self, pipeline: Pipeline, run_id): + """ + Get workflow run detail + """ + run_id = str(run_id) + + rag_pipeline_service = RagPipelineService() + workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id) + + return workflow_run + + +class RagPipelineWorkflowRunNodeExecutionListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_run_node_execution_list_fields) + def get(self, pipeline: Pipeline, run_id): + """ + Get workflow run node execution list + """ + run_id = str(run_id) + + rag_pipeline_service = RagPipelineService() + user = cast("Account | EndUser", current_user) + node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions( + pipeline=pipeline, + run_id=run_id, + user=user, + ) + + return {"data": node_executions} + + +class DatasourceListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user = current_user + + tenant_id = user.current_tenant_id + + return 
jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id)) + + +api.add_resource( + DraftRagPipelineApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/draft", +) +api.add_resource( + RagPipelineConfigApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/config", +) +api.add_resource( + DraftRagPipelineRunApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run", +) +api.add_resource( + PublishedRagPipelineRunApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/published/run", +) +api.add_resource( + RagPipelineTaskStopApi, + "/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop", +) +api.add_resource( + RagPipelineDraftNodeRunApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run", +) +api.add_resource( + RagPipelinePublishedDatasourceNodeRunApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run", +) +# api.add_resource( +# RagPipelinePublishedDatasourceNodeRunStatusApi, +# "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run-status", +# ) +# api.add_resource( +# RagPipelineDraftDatasourceNodeRunStatusApi, +# "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run-status", +# ) + +api.add_resource( + RagPipelineDraftDatasourceNodeRunApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run", +) + +api.add_resource( + RagPipelineDraftRunIterationNodeApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run", +) + +api.add_resource( + RagPipelinePublishedNodeRunApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/published/nodes/<string:node_id>/run", +) + +api.add_resource( + RagPipelineDraftRunLoopNodeApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run", +) + +api.add_resource( + PublishedRagPipelineApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/publish", +) +api.add_resource( + PublishedAllRagPipelineApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows", +) +api.add_resource( + DefaultRagPipelineBlockConfigsApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs", +) +api.add_resource( + DefaultRagPipelineBlockConfigApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>", +) +api.add_resource( + RagPipelineByIdApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>", +) +api.add_resource( + RagPipelineWorkflowRunListApi, + "/rag/pipelines/<uuid:pipeline_id>/workflow-runs", +) +api.add_resource( + RagPipelineWorkflowRunDetailApi, + "/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>", +) +api.add_resource( + RagPipelineWorkflowRunNodeExecutionListApi, + "/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions", +) +api.add_resource( + DatasourceListApi, + "/rag/pipelines/datasource-plugins", +) +api.add_resource( + PublishedRagPipelineSecondStepApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters", +) +api.add_resource( + PublishedRagPipelineFirstStepApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters", +) +api.add_resource( + DraftRagPipelineSecondStepApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters", +) +api.add_resource( + DraftRagPipelineFirstStepApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters", +) diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py new file mode 100644 index 0000000000..32fd47fd36 --- /dev/null +++ b/api/controllers/console/datasets/wraps.py @@ -0,0 +1,43 @@ +from collections.abc import Callable +from functools import wraps +from typing import Optional + +from controllers.console.datasets.error import PipelineNotFoundError +from extensions.ext_database import db +from libs.login import current_user +from models.dataset import Pipeline + + +def get_rag_pipeline( + view: Optional[Callable] = None, +): + def decorator(view_func): + @wraps(view_func) + def decorated_view(*args, **kwargs): + if not 
kwargs.get("pipeline_id"): + raise ValueError("missing pipeline_id in path parameters") + + pipeline_id = kwargs.get("pipeline_id") + pipeline_id = str(pipeline_id) + + del kwargs["pipeline_id"] + + pipeline = ( + db.session.query(Pipeline) + .filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id) + .first() + ) + + if not pipeline: + raise PipelineNotFoundError() + + kwargs["pipeline"] = pipeline + + return view_func(*args, **kwargs) + + return decorated_view + + if view is None: + return decorator + else: + return decorator(view) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 75bd2f677a..cbb382beb3 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -113,9 +113,9 @@ class VariableEntity(BaseModel): hide: bool = False max_length: Optional[int] = None options: Sequence[str] = Field(default_factory=list) - allowed_file_types: Sequence[FileType] = Field(default_factory=list) - allowed_file_extensions: Sequence[str] = Field(default_factory=list) - allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list) + allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list) + allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list) @field_validator("description", mode="before") @classmethod @@ -128,6 +128,16 @@ class VariableEntity(BaseModel): return v or [] +class RagPipelineVariableEntity(VariableEntity): + """ + Rag Pipeline Variable Entity. + """ + + tooltips: Optional[str] = None + placeholder: Optional[str] = None + belong_to_node_id: str + + class ExternalDataVariableEntity(BaseModel): """ External Data Variable Entity. 
@@ -285,7 +295,7 @@ class AppConfig(BaseModel): tenant_id: str app_id: str app_mode: AppMode - additional_features: AppAdditionalFeatures + additional_features: Optional[AppAdditionalFeatures] = None variables: list[VariableEntity] = [] sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index 2f1da38082..1c63874ee3 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,4 +1,4 @@ -from core.app.app_config.entities import VariableEntity +from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity from models.workflow import Workflow @@ -20,3 +20,19 @@ class WorkflowVariablesConfigManager: variables.append(VariableEntity.model_validate(variable)) return variables + + @classmethod + def convert_rag_pipeline_variable(cls, workflow: Workflow) -> list[RagPipelineVariableEntity]: + """ + Convert workflow start variables to variables + + :param workflow: workflow instance + """ + variables = [] + + user_input_form = workflow.rag_pipeline_user_input_form() + # variables + for variable in user_input_form: + variables.append(RagPipelineVariableEntity.model_validate(variable)) + + return variables diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 34a1da2227..fe7abddf87 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -43,11 +43,13 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.file import FILE_MODEL_IDENTITY, File +from core.plugin.impl.datasource import PluginDatasourceManager from core.tools.tool_manager import ToolManager from core.variables.segments import ArrayFileSegment, FileSegment, Segment from core.workflow.entities.workflow_execution import WorkflowExecution from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus from core.workflow.nodes import NodeType +from core.workflow.nodes.datasource.entities import DatasourceNodeData from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( @@ -183,6 +185,14 @@ class WorkflowResponseConverter: provider_type=node_data.provider_type, provider_id=node_data.provider_id, ) + elif event.node_type == NodeType.DATASOURCE: + node_data = cast(DatasourceNodeData, event.node_data) + manager = PluginDatasourceManager() + provider_entity = manager.fetch_datasource_provider( + self._application_generate_entity.app_config.tenant_id, + f"{node_data.plugin_id}/{node_data.provider_name}", + ) + response.data.extras["icon"] = provider_entity.declaration.identity.icon return response diff --git a/api/core/app/apps/pipeline/__init__.py b/api/core/app/apps/pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/app/apps/pipeline/generate_response_converter.py b/api/core/app/apps/pipeline/generate_response_converter.py new file mode 100644 index 0000000000..10ec73a7d2 --- /dev/null +++ b/api/core/app/apps/pipeline/generate_response_converter.py @@ -0,0 +1,95 @@ +from collections.abc import Generator +from typing import cast + +from core.app.apps.base_app_generate_response_converter import 
AppGenerateResponseConverter +from core.app.entities.task_entities import ( + AppStreamResponse, + ErrorStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, +) + + +class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): + _blocking_response_type = WorkflowAppBlockingResponse + + @classmethod + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] + """ + Convert blocking full response. + :param blocking_response: blocking response + :return: + """ + return dict(blocking_response.to_dict()) + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] + """ + Convert blocking simple response. + :param blocking_response: blocking response + :return: + """ + return cls.convert_blocking_full_response(blocking_response) + + @classmethod + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, None, None]: + """ + Convert stream full response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(WorkflowAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield "ping" + continue + + response_chunk = { + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, + } + + if isinstance(sub_stream_response, ErrorStreamResponse): + data = cls._error_to_stream_response(sub_stream_response.err) + response_chunk.update(data) + else: + response_chunk.update(sub_stream_response.to_dict()) + yield response_chunk + + @classmethod + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, None, None]: + """ + Convert stream simple response. 
+ :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(WorkflowAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield "ping" + continue + + response_chunk = { + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, + } + + if isinstance(sub_stream_response, ErrorStreamResponse): + data = cls._error_to_stream_response(sub_stream_response.err) + response_chunk.update(data) + elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): + response_chunk.update(sub_stream_response.to_ignore_detail_dict()) + else: + response_chunk.update(sub_stream_response.to_dict()) + yield response_chunk diff --git a/api/core/app/apps/pipeline/pipeline_config_manager.py b/api/core/app/apps/pipeline/pipeline_config_manager.py new file mode 100644 index 0000000000..b83fc1800f --- /dev/null +++ b/api/core/app/apps/pipeline/pipeline_config_manager.py @@ -0,0 +1,64 @@ +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager +from models.dataset import Pipeline +from models.model import AppMode +from models.workflow import Workflow + + +class PipelineConfig(WorkflowUIBasedAppConfig): + """ + Pipeline Config Entity. 
+ """ + + rag_pipeline_variables: list[RagPipelineVariableEntity] = [] + pass + + +class PipelineConfigManager(BaseAppConfigManager): + @classmethod + def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig: + pipeline_config = PipelineConfig( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + app_mode=AppMode.RAG_PIPELINE, + workflow_id=workflow.id, + rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow), + ) + + return pipeline_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + """ + Validate for pipeline config + + :param tenant_id: tenant id + :param config: app model config args + :param only_structure_validate: only validate the structure of the config + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate + ) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py new file mode 100644 index 0000000000..13acc4ef38 --- /dev/null +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -0,0 +1,621 @@ +import contextvars +import datetime +import json +import logging +import secrets +import threading +import time +import uuid +from collections.abc import Generator, Mapping +from typing import Any, Literal, Optional, Union, overload + +from flask import Flask, current_app +from pydantic import ValidationError +from sqlalchemy.orm import sessionmaker + +import contexts +from configs import dify_config +from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager +from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager +from core.app.apps.pipeline.pipeline_runner import PipelineRunner +from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter +from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline +from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity +from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.entities.knowledge_entities import PipelineDataset, PipelineDocument +from core.model_runtime.errors.invoke import InvokeAuthorizationError +from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from 
core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from extensions.ext_database import db +from libs.flask_utils import preserve_flask_contexts +from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom +from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline +from models.enums import WorkflowRunTriggeredFrom +from models.model import AppMode +from services.dataset_service import DocumentService + +logger = logging.getLogger(__name__) + + +class PipelineGenerator(BaseAppGenerator): + @overload + def generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + call_depth: int, + workflow_thread_pool_id: Optional[str], + ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ... + + @overload + def generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + call_depth: int, + workflow_thread_pool_id: Optional[str], + ) -> Mapping[str, Any]: ... + + @overload + def generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool, + call_depth: int, + workflow_thread_pool_id: Optional[str], + ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... + + def generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: + # convert to app config + pipeline_config = PipelineConfigManager.get_pipeline_config( + pipeline=pipeline, + workflow=workflow, + ) + # Add null check for dataset + dataset = pipeline.dataset + if not dataset: + raise ValueError("Pipeline dataset is required") + inputs: Mapping[str, Any] = args["inputs"] + start_node_id: str = args["start_node_id"] + datasource_type: str = args["datasource_type"] + datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] + batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000) + documents = [] + if invoke_from == InvokeFrom.PUBLISHED: + for datasource_info in datasource_info_list: + position = DocumentService.get_documents_position(dataset.id) + document = self._build_document( + tenant_id=pipeline.tenant_id, + dataset_id=dataset.id, + built_in_field_enabled=dataset.built_in_field_enabled, + datasource_type=datasource_type, + datasource_info=datasource_info, + created_from="rag-pipeline", + position=position, + account=user, + batch=batch, + document_form=dataset.chunk_structure, + ) + db.session.add(document) + documents.append(document) + db.session.commit() + + # run in child thread + for i, datasource_info in enumerate(datasource_info_list): + workflow_run_id = str(uuid.uuid4()) + document_id = None + if invoke_from == InvokeFrom.PUBLISHED: + document_id = documents[i].id + document_pipeline_execution_log = DocumentPipelineExecutionLog( + document_id=document_id, + datasource_type=datasource_type, + datasource_info=json.dumps(datasource_info), + 
datasource_node_id=start_node_id, + input_data=inputs, + pipeline_id=pipeline.id, + created_by=user.id, + ) + db.session.add(document_pipeline_execution_log) + db.session.commit() + application_generate_entity = RagPipelineGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=pipeline_config, + pipeline_config=pipeline_config, + datasource_type=datasource_type, + datasource_info=datasource_info, + dataset_id=dataset.id, + start_node_id=start_node_id, + batch=batch, + document_id=document_id, + inputs=self._prepare_user_inputs( + user_inputs=inputs, + variables=pipeline_config.rag_pipeline_variables, + tenant_id=pipeline.tenant_id, + strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + ), + files=[], + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + call_depth=call_depth, + workflow_execution_id=workflow_run_id, + ) + + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + if invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, + ) + if invoke_from == InvokeFrom.DEBUGGER: + return self._generate( + flask_app=current_app._get_current_object(), # type: ignore + context=contextvars.copy_context(), + pipeline=pipeline, + workflow_id=workflow.id, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + else: + # run in child thread + context = contextvars.copy_context() + + worker_thread = threading.Thread( + target=self._generate, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "context": context, + "pipeline": pipeline, + "workflow_id": workflow.id, + "user": user, + "application_generate_entity": application_generate_entity, + "invoke_from": invoke_from, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "streaming": streaming, + "workflow_thread_pool_id": workflow_thread_pool_id, + }, + ) + + worker_thread.start() + # return batch, dataset, documents + return { + "batch": batch, + "dataset": PipelineDataset( + id=dataset.id, + name=dataset.name, + description=dataset.description, + chunk_structure=dataset.chunk_structure, + ).model_dump(), + "documents": [ + PipelineDocument( + id=document.id, + position=document.position, + data_source_type=document.data_source_type, + data_source_info=json.loads(document.data_source_info) if document.data_source_info else None, + name=document.name, + indexing_status=document.indexing_status, + error=document.error, + enabled=document.enabled, + ).model_dump() + 
for document in documents + ], + } + + def _generate( + self, + *, + flask_app: Flask, + context: contextvars.Context, + pipeline: Pipeline, + workflow_id: str, + user: Union[Account, EndUser], + application_generate_entity: RagPipelineGenerateEntity, + invoke_from: InvokeFrom, + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, + streaming: bool = True, + workflow_thread_pool_id: Optional[str] = None, + ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: + """ + Generate App response. + + :param pipeline: Pipeline + :param workflow: Workflow + :param user: account or end user + :param application_generate_entity: application generate entity + :param invoke_from: invoke from source + :param workflow_execution_repository: repository for workflow execution + :param workflow_node_execution_repository: repository for workflow node execution + :param streaming: is stream + :param workflow_thread_pool_id: workflow thread pool id + """ + with preserve_flask_contexts(flask_app, context_vars=context): + # init queue manager + workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first() + if not workflow: + raise ValueError(f"Workflow not found: {workflow_id}") + queue_manager = PipelineQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + app_mode=AppMode.RAG_PIPELINE, + ) + context = contextvars.copy_context() + + # new thread + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "context": context, + "queue_manager": queue_manager, + "application_generate_entity": application_generate_entity, + "workflow_thread_pool_id": workflow_thread_pool_id, + }, + ) + + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + stream=streaming, + ) + + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + + def single_iteration_generate( + self, + pipeline: Pipeline, + workflow: Workflow, + node_id: str, + user: Account | EndUser, + args: Mapping[str, Any], + streaming: bool = True, + ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: + """ + Generate App response. 
+ + :param pipeline: Pipeline + :param workflow: Workflow + :param node_id: the node id + :param user: account or end user + :param args: request args + :param streaming: is streamed + """ + if not node_id: + raise ValueError("node_id is required") + + if args.get("inputs") is None: + raise ValueError("inputs is required") + + # convert to app config + pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) + + dataset = pipeline.dataset + if not dataset: + raise ValueError("Pipeline dataset is required") + + # init application generate entity - use RagPipelineGenerateEntity instead + application_generate_entity = RagPipelineGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=pipeline_config, + pipeline_config=pipeline_config, + datasource_type=args.get("datasource_type", ""), + datasource_info=args.get("datasource_info", {}), + dataset_id=dataset.id, + batch=args.get("batch", ""), + document_id=args.get("document_id"), + inputs={}, + files=[], + user_id=user.id, + stream=streaming, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + workflow_execution_id=str(uuid.uuid4()), + ) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING, + ) + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + ) + + return self._generate( + flask_app=current_app._get_current_object(), # type: ignore + context=contextvars.copy_context(), + pipeline=pipeline, + workflow_id=workflow.id, + user=user, + invoke_from=InvokeFrom.DEBUGGER, + application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + ) + + def single_loop_generate( + self, + pipeline: Pipeline, + workflow: Workflow, + node_id: str, + user: Account | EndUser, + args: Mapping[str, Any], + streaming: bool = True, + ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: + """ + Generate App response. 
+ + :param pipeline: Pipeline + :param workflow: Workflow + :param node_id: the node id + :param user: account or end user + :param args: request args + :param streaming: is streamed + """ + if not node_id: + raise ValueError("node_id is required") + + if args.get("inputs") is None: + raise ValueError("inputs is required") + + dataset = pipeline.dataset + if not dataset: + raise ValueError("Pipeline dataset is required") + + # convert to app config + pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) + + # init application generate entity + application_generate_entity = RagPipelineGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=pipeline_config, + pipeline_config=pipeline_config, + datasource_type=args.get("datasource_type", ""), + datasource_info=args.get("datasource_info", {}), + batch=args.get("batch", ""), + document_id=args.get("document_id"), + dataset_id=dataset.id, + inputs={}, + files=[], + user_id=user.id, + stream=streaming, + invoke_from=InvokeFrom.DEBUGGER, + extras={"auto_generate_conversation_name": False}, + single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), + workflow_execution_id=str(uuid.uuid4()), + ) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING, + ) + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + ) + + return self._generate( + flask_app=current_app._get_current_object(), # type: ignore + context=contextvars.copy_context(), + pipeline=pipeline, + workflow_id=workflow.id, + user=user, + invoke_from=InvokeFrom.DEBUGGER, + application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + ) + + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: RagPipelineGenerateEntity, + queue_manager: AppQueueManager, + context: contextvars.Context, + workflow_thread_pool_id: Optional[str] = None, + ) -> None: + """ + Generate worker in a new thread. 
+ :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param workflow_thread_pool_id: workflow thread pool id + :return: + """ + + with preserve_flask_contexts(flask_app, context_vars=context): + try: + # workflow app + runner = PipelineRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + + runner.run() + except GenerateTaskStoppedError: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except ValueError as e: + if dify_config.DEBUG: + logger.exception("Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.close() + + def _handle_response( + self, + application_generate_entity: RagPipelineGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + workflow_execution_repository: WorkflowExecutionRepository, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, + stream: bool = False, + ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: + """ + Handle response. + :param application_generate_entity: application generate entity + :param workflow: workflow + :param queue_manager: queue manager + :param user: account or end user + :param stream: is stream + :param workflow_node_execution_repository: optional repository for workflow node execution + :return: + """ + # init generate task pipeline + generate_task_pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + stream=stream, + workflow_node_execution_repository=workflow_node_execution_repository, + workflow_execution_repository=workflow_execution_repository, + ) + + try: + return generate_task_pipeline.process() + except ValueError as e: + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error + raise GenerateTaskStoppedError() + else: + logger.exception( + f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}" + ) + raise e + + def _build_document( + self, + tenant_id: str, + dataset_id: str, + built_in_field_enabled: bool, + datasource_type: str, + datasource_info: Mapping[str, Any], + created_from: str, + position: int, + account: Union[Account, EndUser], + batch: str, + document_form: str, + ): + if datasource_type == "local_file": + name = datasource_info["name"] + elif datasource_type == "online_document": + name = datasource_info["page"]["page_name"] + elif datasource_type == "website_crawl": + name = datasource_info["title"] + else: + raise ValueError(f"Unsupported datasource type: {datasource_type}") + + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=position, + data_source_type=datasource_type, + data_source_info=json.dumps(datasource_info), + batch=batch, + name=name, + created_from=created_from, + created_by=account.id, + doc_form=document_form, + ) + doc_metadata = {} 
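+ # Built-in metadata (document name, uploader, upload/update dates, source) is only attached when the dataset has built-in fields enabled.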
+ if built_in_field_enabled: + doc_metadata = { + BuiltInField.document_name: name, + BuiltInField.uploader: account.name, + BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), + BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), + BuiltInField.source: datasource_type, + } + if doc_metadata: + document.doc_metadata = doc_metadata + return document diff --git a/api/core/app/apps/pipeline/pipeline_queue_manager.py b/api/core/app/apps/pipeline/pipeline_queue_manager.py new file mode 100644 index 0000000000..d0aeac8a9c --- /dev/null +++ b/api/core/app/apps/pipeline/pipeline_queue_manager.py @@ -0,0 +1,44 @@ +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueErrorEvent, + QueueMessageEndEvent, + QueueStopEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowSucceededEvent, + WorkflowQueueMessage, +) + + +class PipelineQueueManager(AppQueueManager): + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: + super().__init__(task_id, user_id, invoke_from) + + self._app_mode = app_mode + + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ + message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event) + + self._q.put(message) + + if isinstance( + event, + QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent + | QueueWorkflowPartialSuccessEvent, + ): + self.stop_listen() + + if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py new file mode 100644 index 0000000000..52afb78ee5 --- /dev/null +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -0,0 +1,221 @@ +import logging +from collections.abc import Mapping +from typing import Any, Optional, cast + +from configs import dify_config +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import ( + InvokeFrom, + RagPipelineGenerateEntity, +) +from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput +from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db +from models.dataset import Pipeline +from models.enums import UserFrom +from models.model import EndUser +from models.workflow import Workflow, WorkflowType + +logger = logging.getLogger(__name__) + + +class PipelineRunner(WorkflowBasedAppRunner): + """ + Pipeline Application Runner + """ + + def __init__( + self, + application_generate_entity: RagPipelineGenerateEntity, + queue_manager: AppQueueManager, + workflow_thread_pool_id: Optional[str] = None, + ) -> None: + """ + :param 
application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param workflow_thread_pool_id: workflow thread pool id + """ + self.application_generate_entity = application_generate_entity + self.queue_manager = queue_manager + self.workflow_thread_pool_id = workflow_thread_pool_id + + def _get_app_id(self) -> str: + return self.application_generate_entity.app_config.app_id + + def run(self) -> None: + """ + Run application + """ + app_config = self.application_generate_entity.app_config + app_config = cast(PipelineConfig, app_config) + + user_id = None + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = self.application_generate_entity.user_id + + pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first() + if not pipeline: + raise ValueError("Pipeline not found") + + workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") + + db.session.close() + + workflow_callbacks: list[WorkflowCallback] = [] + if dify_config.DEBUG: + workflow_callbacks.append(WorkflowLoggingCallback()) + + # if only single iteration run is requested + if self.application_generate_entity.single_iteration_run: + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs, + ) + elif self.application_generate_entity.single_loop_run: + # if only single loop run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( + workflow=workflow, + node_id=self.application_generate_entity.single_loop_run.node_id, + user_inputs=self.application_generate_entity.single_loop_run.inputs, + ) + else: + inputs = self.application_generate_entity.inputs + files = self.application_generate_entity.files + + # Create a variable pool. 
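+ # System variables expose the pipeline run context (document, batch, dataset, datasource type/info, invoke source) to every node in the graph.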
+ system_inputs = { + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.APP_ID: app_config.app_id, + SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, + SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id, + SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id, + SystemVariableKey.BATCH: self.application_generate_entity.batch, + SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id, + SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type, + SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info, + SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value, + } + rag_pipeline_variables = [] + if workflow.rag_pipeline_variables: + for v in workflow.rag_pipeline_variables: + rag_pipeline_variable = RAGPipelineVariable(**v) + if ( + rag_pipeline_variable.belong_to_node_id + in (self.application_generate_entity.start_node_id, "shared") + ) and rag_pipeline_variable.variable in inputs: + rag_pipeline_variables.append( + RAGPipelineVariableInput( + variable=rag_pipeline_variable, + value=inputs[rag_pipeline_variable.variable], + ) + ) + + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=[], + rag_pipeline_variables=rag_pipeline_variables, + ) + + # init graph + graph = self._init_rag_pipeline_graph( + graph_config=workflow.graph_dict, + start_node_id=self.application_generate_entity.start_node_id, + ) + + # RUN WORKFLOW + workflow_entry = WorkflowEntry( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), + graph=graph, + graph_config=workflow.graph_dict, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, + variable_pool=variable_pool, + thread_pool_id=self.workflow_thread_pool_id, + ) + + generator = workflow_entry.run(callbacks=workflow_callbacks) + + for event in generator: + self._handle_event(workflow_entry, event) + + def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id + ) + .first() + ) + + # return workflow + return workflow + + def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph: + """ + Init pipeline graph + """ + if "nodes" not in graph_config or "edges" not in graph_config: + raise ValueError("nodes or edges not found in workflow graph") + + if not isinstance(graph_config.get("nodes"), list): + raise ValueError("nodes in workflow graph must be a list") + + if not isinstance(graph_config.get("edges"), list): + raise ValueError("edges in workflow graph must be a list") + nodes = graph_config.get("nodes", []) + edges = graph_config.get("edges", []) + real_run_nodes = [] + real_edges = [] + exclude_node_ids = [] + for node in nodes: + node_id = node.get("id") + node_type = node.get("data", 
{}).get("type", "") + if node_type == "datasource": + if start_node_id != node_id: + exclude_node_ids.append(node_id) + continue + real_run_nodes.append(node) + for edge in edges: + if edge.get("source") in exclude_node_ids: + continue + real_edges.append(edge) + graph_config = dict(graph_config) + graph_config["nodes"] = real_run_nodes + graph_config["edges"] = real_edges + # init graph + graph = Graph.init(graph_config=graph_config) + + if not graph: + raise ValueError("graph not found in workflow") + + return graph diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 65ed267959..4947861ef0 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -36,6 +36,7 @@ class InvokeFrom(Enum): # DEBUGGER indicates that this invocation is from # the workflow (or chatflow) edit page. DEBUGGER = "debugger" + PUBLISHED = "published" @classmethod def value_of(cls, value: str): @@ -240,3 +241,38 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): inputs: dict single_loop_run: Optional[SingleLoopRunEntity] = None + + +class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): + """ + RAG Pipeline Application Generate Entity. + """ + + # pipeline config + pipeline_config: WorkflowUIBasedAppConfig + datasource_type: str + datasource_info: Mapping[str, Any] + dataset_id: str + batch: str + document_id: Optional[str] = None + start_node_id: Optional[str] = None + + class SingleIterationRunEntity(BaseModel): + """ + Single Iteration Run Entity. + """ + + node_id: str + inputs: dict + + single_iteration_run: Optional[SingleIterationRunEntity] = None + + class SingleLoopRunEntity(BaseModel): + """ + Single Loop Run Entity. + """ + + node_id: str + inputs: dict + + single_loop_run: Optional[SingleLoopRunEntity] = None diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 65d899a002..1063e66c59 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -105,6 +105,14 @@ class DifyAgentCallbackHandler(BaseModel): self.current_loop += 1 + def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None: + """Run on datasource start.""" + if dify_config.DEBUG: + print_text( + "\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + str(datasource_inputs) + "\n", + color=self.color, + ) + @property def ignore_agent(self) -> bool: """Whether to ignore agent callbacks.""" diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py new file mode 100644 index 0000000000..5a13d17843 --- /dev/null +++ b/api/core/datasource/__base/datasource_plugin.py @@ -0,0 +1,33 @@ +from abc import ABC, abstractmethod + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, +) + + +class DatasourcePlugin(ABC): + entity: DatasourceEntity + runtime: DatasourceRuntime + + def __init__( + self, + entity: DatasourceEntity, + runtime: DatasourceRuntime, + ) -> None: + self.entity = entity + self.runtime = runtime + + @abstractmethod + def datasource_provider_type(self) -> str: + """ + returns the type of the datasource provider + """ + return DatasourceProviderType.LOCAL_FILE + + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> 
"DatasourcePlugin": + return self.__class__( + entity=self.entity.model_copy(), + runtime=runtime, + ) diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py new file mode 100644 index 0000000000..bae39dc8c7 --- /dev/null +++ b/api/core/datasource/__base/datasource_provider.py @@ -0,0 +1,118 @@ +from abc import ABC, abstractmethod +from typing import Any + +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.entities.provider_entities import ProviderConfig +from core.plugin.impl.tool import PluginToolManager +from core.tools.errors import ToolProviderCredentialValidationError + + +class DatasourcePluginProviderController(ABC): + entity: DatasourceProviderEntityWithPlugin + tenant_id: str + + def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None: + self.entity = entity + self.tenant_id = tenant_id + + @property + def need_credentials(self) -> bool: + """ + returns whether the provider needs credentials + + :return: whether the provider needs credentials + """ + return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0 + + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + """ + validate the credentials of the provider + """ + manager = PluginToolManager() + if not manager.validate_datasource_credentials( + tenant_id=self.tenant_id, + user_id=user_id, + provider=self.entity.identity.name, + credentials=credentials, + ): + raise ToolProviderCredentialValidationError("Invalid credentials") + + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.LOCAL_FILE + + @abstractmethod + def get_datasource(self, datasource_name: str) -> DatasourcePlugin: + """ + return datasource with given name + """ + pass + + def validate_credentials_format(self, credentials: dict[str, Any]) -> None: + """ + validate the format of the credentials of the provider and set the default value if needed + + :param credentials: the credentials of the tool + """ + credentials_schema = dict[str, ProviderConfig]() + if credentials_schema is None: + return + + for credential in self.entity.credentials_schema: + credentials_schema[credential.name] = credential + + credentials_need_to_validate: dict[str, ProviderConfig] = {} + for credential_name in credentials_schema: + credentials_need_to_validate[credential_name] = credentials_schema[credential_name] + + for credential_name in credentials: + if credential_name not in credentials_need_to_validate: + raise ToolProviderCredentialValidationError( + f"credential {credential_name} not found in provider {self.entity.identity.name}" + ) + + # check type + credential_schema = credentials_need_to_validate[credential_name] + if not credential_schema.required and credentials[credential_name] is None: + continue + + if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}: + if not isinstance(credentials[credential_name], str): + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + + elif credential_schema.type == ProviderConfig.Type.SELECT: + if not isinstance(credentials[credential_name], str): + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + + options = 
credential_schema.options + if not isinstance(options, list): + raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list") + + if credentials[credential_name] not in [x.value for x in options]: + raise ToolProviderCredentialValidationError( + f"credential {credential_name} should be one of {options}" + ) + + credentials_need_to_validate.pop(credential_name) + + for credential_name in credentials_need_to_validate: + credential_schema = credentials_need_to_validate[credential_name] + if credential_schema.required: + raise ToolProviderCredentialValidationError(f"credential {credential_name} is required") + + # the credential is not set currently, set the default value if needed + if credential_schema.default is not None: + default_value = credential_schema.default + # parse default value into the correct type + if credential_schema.type in { + ProviderConfig.Type.SECRET_INPUT, + ProviderConfig.Type.TEXT_INPUT, + ProviderConfig.Type.SELECT, + }: + default_value = str(default_value) + + credentials[credential_name] = default_value diff --git a/api/core/datasource/__base/datasource_runtime.py b/api/core/datasource/__base/datasource_runtime.py new file mode 100644 index 0000000000..264145d261 --- /dev/null +++ b/api/core/datasource/__base/datasource_runtime.py @@ -0,0 +1,36 @@ +from typing import Any, Optional + +from openai import BaseModel +from pydantic import Field + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.datasource.entities.datasource_entities import DatasourceInvokeFrom + + +class DatasourceRuntime(BaseModel): + """ + Meta data of a datasource call processing + """ + + tenant_id: str + datasource_id: Optional[str] = None + invoke_from: Optional[InvokeFrom] = None + datasource_invoke_from: Optional[DatasourceInvokeFrom] = None + credentials: dict[str, Any] = Field(default_factory=dict) + runtime_parameters: dict[str, Any] = Field(default_factory=dict) + + +class FakeDatasourceRuntime(DatasourceRuntime): + """ + Fake datasource runtime for testing + """ + + def __init__(self): + super().__init__( + tenant_id="fake_tenant_id", + datasource_id="fake_datasource_id", + invoke_from=InvokeFrom.DEBUGGER, + datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE, + credentials={}, + runtime_parameters={}, + ) diff --git a/api/core/datasource/__init__.py b/api/core/datasource/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py new file mode 100644 index 0000000000..51296b64d2 --- /dev/null +++ b/api/core/datasource/datasource_file_manager.py @@ -0,0 +1,244 @@ +import base64 +import hashlib +import hmac +import logging +import os +import time +from mimetypes import guess_extension, guess_type +from typing import Optional, Union +from uuid import uuid4 + +import httpx + +from configs import dify_config +from core.helper import ssrf_proxy +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.enums import CreatorUserRole +from models.model import MessageFile, UploadFile +from models.tools import ToolFile + +logger = logging.getLogger(__name__) + + +class DatasourceFileManager: + @staticmethod + def sign_file(datasource_file_id: str, extension: str) -> str: + """ + sign file to get a temporary url + """ + base_url = dify_config.FILES_URL + file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}" + + timestamp = str(int(time.time())) + nonce = 
os.urandom(16).hex() + data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + @staticmethod + def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + """ + verify signature + """ + data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + @staticmethod + def create_file_by_raw( + *, + user_id: str, + tenant_id: str, + conversation_id: Optional[str], + file_binary: bytes, + mimetype: str, + filename: Optional[str] = None, + ) -> UploadFile: + extension = guess_extension(mimetype) or ".bin" + unique_name = uuid4().hex + unique_filename = f"{unique_name}{extension}" + # default just as before + present_filename = unique_filename + if filename is not None: + has_extension = len(filename.split(".")) > 1 + # Add extension flexibly + present_filename = filename if has_extension else f"{filename}{extension}" + filepath = f"datasources/{tenant_id}/{unique_filename}" + storage.save(filepath, file_binary) + + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=filepath, + name=present_filename, + size=len(file_binary), + extension=extension, + mime_type=mimetype, + created_by_role=CreatorUserRole.ACCOUNT, + created_by=user_id, + used=False, + hash=hashlib.sha3_256(file_binary).hexdigest(), + source_url="", + ) + + db.session.add(upload_file) + db.session.commit() + db.session.refresh(upload_file) + + return upload_file + + @staticmethod + def create_file_by_url( + user_id: str, + tenant_id: str, + file_url: str, + conversation_id: Optional[str] = None, + ) -> UploadFile: + # try to download image + try: + response = ssrf_proxy.get(file_url) + response.raise_for_status() + blob = response.content + except httpx.TimeoutException: + raise ValueError(f"timeout when downloading file from {file_url}") + + mimetype = ( + guess_type(file_url)[0] + or response.headers.get("Content-Type", "").split(";")[0].strip() + or "application/octet-stream" + ) + extension = guess_extension(mimetype) or ".bin" + unique_name = uuid4().hex + filename = f"{unique_name}{extension}" + filepath = f"tools/{tenant_id}/{filename}" + storage.save(filepath, blob) + + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=filepath, + name=filename, + size=len(blob), + extension=extension, + mime_type=mimetype, + created_by_role=CreatorUserRole.ACCOUNT, + created_by=user_id, + used=False, + hash=hashlib.sha3_256(blob).hexdigest(), + source_url=file_url, + ) + + db.session.add(upload_file) + db.session.commit() + + return upload_file + + @staticmethod + def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: + """ + get file binary + + :param id: the id of the file + + :return: the binary of the file, mime type + """ + upload_file: 
UploadFile | None = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == id, + ) + .first() + ) + + if not upload_file: + return None + + blob = storage.load_once(upload_file.key) + + return blob, upload_file.mime_type + + @staticmethod + def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]: + """ + get file binary + + :param id: the id of the file + + :return: the binary of the file, mime type + """ + message_file: MessageFile | None = ( + db.session.query(MessageFile) + .filter( + MessageFile.id == id, + ) + .first() + ) + + # Check if message_file is not None + if message_file is not None: + # get tool file id + if message_file.url is not None: + tool_file_id = message_file.url.split("/")[-1] + # trim extension + tool_file_id = tool_file_id.split(".")[0] + else: + tool_file_id = None + else: + tool_file_id = None + + tool_file: ToolFile | None = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() + ) + + if not tool_file: + return None + + blob = storage.load_once(tool_file.file_key) + + return blob, tool_file.mimetype + + @staticmethod + def get_file_generator_by_upload_file_id(upload_file_id: str): + """ + get file binary + + :param tool_file_id: the id of the tool file + + :return: the binary of the file, mime type + """ + upload_file: UploadFile | None = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == upload_file_id, + ) + .first() + ) + + if not upload_file: + return None, None + + stream = storage.load_stream(upload_file.key) + + return stream, upload_file.mime_type + + +# init tool_file_parser +# from core.file.datasource_file_parser import datasource_file_manager +# +# datasource_file_manager["manager"] = DatasourceFileManager diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py new file mode 100644 index 0000000000..838fee5b96 --- /dev/null +++ b/api/core/datasource/datasource_manager.py @@ -0,0 +1,100 @@ +import logging +from threading import Lock +from typing import Union + +import contexts +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.entities.common_entities import I18nObject +from core.datasource.entities.datasource_entities import DatasourceProviderType +from core.datasource.errors import DatasourceProviderNotFoundError +from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController +from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController +from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController +from core.plugin.impl.datasource import PluginDatasourceManager + +logger = logging.getLogger(__name__) + + +class DatasourceManager: + _builtin_provider_lock = Lock() + _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {} + _builtin_providers_loaded = False + _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} + + @classmethod + def get_datasource_plugin_provider( + cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType + ) -> DatasourcePluginProviderController: + """ + get the datasource plugin provider + """ + # check if context is set + try: + contexts.datasource_plugin_providers.get() + except LookupError: + contexts.datasource_plugin_providers.set({}) + 
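+ # The cache set above is paired with a lock so the lookup below can populate it safely;
+ # both live in context variables exposed by the contexts module.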
contexts.datasource_plugin_providers_lock.set(Lock()) + + with contexts.datasource_plugin_providers_lock.get(): + datasource_plugin_providers = contexts.datasource_plugin_providers.get() + if provider_id in datasource_plugin_providers: + return datasource_plugin_providers[provider_id] + + manager = PluginDatasourceManager() + provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id) + if not provider_entity: + raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found") + + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + controller = OnlineDocumentDatasourcePluginProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + case DatasourceProviderType.WEBSITE_CRAWL: + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + case DatasourceProviderType.LOCAL_FILE: + controller = LocalFileDatasourcePluginProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + case _: + raise ValueError(f"Unsupported datasource type: {datasource_type}") + + datasource_plugin_providers[provider_id] = controller + + return controller + + @classmethod + def get_datasource_runtime( + cls, + provider_id: str, + datasource_name: str, + tenant_id: str, + datasource_type: DatasourceProviderType, + ) -> DatasourcePlugin: + """ + get the datasource runtime + + :param provider_type: the type of the provider + :param provider_id: the id of the provider + :param datasource_name: the name of the datasource + :param tenant_id: the tenant id + + :return: the datasource plugin + """ + return cls.get_datasource_plugin_provider( + provider_id, + tenant_id, + datasource_type, + ).get_datasource(datasource_name) diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py new file mode 100644 index 0000000000..81771719ea --- /dev/null +++ b/api/core/datasource/entities/api_entities.py @@ -0,0 +1,71 @@ +from typing import Literal, Optional + +from pydantic import BaseModel, Field, field_validator + +from core.datasource.entities.datasource_entities import DatasourceParameter +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.entities.common_entities import I18nObject + + +class DatasourceApiEntity(BaseModel): + author: str + name: str # identifier + label: I18nObject # label + description: I18nObject + parameters: Optional[list[DatasourceParameter]] = None + labels: list[str] = Field(default_factory=list) + output_schema: Optional[dict] = None + + +ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] + + +class DatasourceProviderApiEntity(BaseModel): + id: str + author: str + name: str # identifier + description: I18nObject + icon: str | dict + label: I18nObject # label + type: str + masked_credentials: Optional[dict] = None + original_credentials: Optional[dict] = None + is_team_authorization: bool = False + allow_delete: bool = True + plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource") + plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the 
datasource") + datasources: list[DatasourceApiEntity] = Field(default_factory=list) + labels: list[str] = Field(default_factory=list) + + @field_validator("datasources", mode="before") + @classmethod + def convert_none_to_empty_list(cls, v): + return v if v is not None else [] + + def to_dict(self) -> dict: + # ------------- + # overwrite datasource parameter types for temp fix + datasources = jsonable_encoder(self.datasources) + for datasource in datasources: + if datasource.get("parameters"): + for parameter in datasource.get("parameters"): + if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value: + parameter["type"] = "files" + # ------------- + + return { + "id": self.id, + "author": self.author, + "name": self.name, + "plugin_id": self.plugin_id, + "plugin_unique_identifier": self.plugin_unique_identifier, + "description": self.description.to_dict(), + "icon": self.icon, + "label": self.label.to_dict(), + "type": self.type.value, + "team_credentials": self.masked_credentials, + "is_team_authorization": self.is_team_authorization, + "allow_delete": self.allow_delete, + "datasources": datasources, + "labels": self.labels, + } diff --git a/api/core/datasource/entities/common_entities.py b/api/core/datasource/entities/common_entities.py new file mode 100644 index 0000000000..924e6fc0cf --- /dev/null +++ b/api/core/datasource/entities/common_entities.py @@ -0,0 +1,23 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class I18nObject(BaseModel): + """ + Model class for i18n object. + """ + + en_US: str + zh_Hans: Optional[str] = Field(default=None) + pt_BR: Optional[str] = Field(default=None) + ja_JP: Optional[str] = Field(default=None) + + def __init__(self, **data): + super().__init__(**data) + self.zh_Hans = self.zh_Hans or self.en_US + self.pt_BR = self.pt_BR or self.en_US + self.ja_JP = self.ja_JP or self.en_US + + def to_dict(self) -> dict: + return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py new file mode 100644 index 0000000000..2c3de1e5d7 --- /dev/null +++ b/api/core/datasource/entities/datasource_entities.py @@ -0,0 +1,361 @@ +import enum +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel, Field, ValidationInfo, field_validator + +from core.entities.provider_entities import ProviderConfig +from core.plugin.entities.oauth import OAuthSchema +from core.plugin.entities.parameters import ( + PluginParameter, + PluginParameterOption, + PluginParameterType, + as_normal_type, + cast_parameter_value, + init_frontend_parameter, +) +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum + + +class DatasourceProviderType(enum.StrEnum): + """ + Enum class for datasource provider + """ + + ONLINE_DOCUMENT = "online_document" + LOCAL_FILE = "local_file" + WEBSITE_CRAWL = "website_crawl" + ONLINE_DRIVE = "online_drive" + + @classmethod + def value_of(cls, value: str) -> "DatasourceProviderType": + """ + Get value of given mode. 
+ + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f"invalid mode value {value}") + + +class DatasourceParameter(PluginParameter): + """ + Overrides type + """ + + class DatasourceParameterType(enum.StrEnum): + """ + removes TOOLS_SELECTOR from PluginParameterType + """ + + STRING = PluginParameterType.STRING.value + NUMBER = PluginParameterType.NUMBER.value + BOOLEAN = PluginParameterType.BOOLEAN.value + SELECT = PluginParameterType.SELECT.value + SECRET_INPUT = PluginParameterType.SECRET_INPUT.value + FILE = PluginParameterType.FILE.value + FILES = PluginParameterType.FILES.value + + # deprecated, should not use. + SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value + + def as_normal_type(self): + return as_normal_type(self) + + def cast_value(self, value: Any): + return cast_parameter_value(self, value) + + type: DatasourceParameterType = Field(..., description="The type of the parameter") + description: I18nObject = Field(..., description="The description of the parameter") + + @classmethod + def get_simple_instance( + cls, + name: str, + typ: DatasourceParameterType, + required: bool, + options: Optional[list[str]] = None, + ) -> "DatasourceParameter": + """ + get a simple datasource parameter + + :param name: the name of the parameter + :param llm_description: the description presented to the LLM + :param typ: the type of the parameter + :param required: if the parameter is required + :param options: the options of the parameter + """ + # convert options to ToolParameterOption + # FIXME fix the type error + if options: + option_objs = [ + PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + for option in options + ] + else: + option_objs = [] + + return cls( + name=name, + label=I18nObject(en_US="", zh_Hans=""), + placeholder=None, + type=typ, + required=required, + options=option_objs, + description=I18nObject(en_US="", zh_Hans=""), + ) + + def init_frontend_parameter(self, value: Any): + return init_frontend_parameter(self, self.type, value) + + +class DatasourceIdentity(BaseModel): + author: str = Field(..., description="The author of the datasource") + name: str = Field(..., description="The name of the datasource") + label: I18nObject = Field(..., description="The label of the datasource") + provider: str = Field(..., description="The provider of the datasource") + icon: Optional[str] = None + + +class DatasourceEntity(BaseModel): + identity: DatasourceIdentity + parameters: list[DatasourceParameter] = Field(default_factory=list) + description: I18nObject = Field(..., description="The label of the datasource") + + @field_validator("parameters", mode="before") + @classmethod + def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]: + return v or [] + + +class DatasourceProviderIdentity(BaseModel): + author: str = Field(..., description="The author of the tool") + name: str = Field(..., description="The name of the tool") + description: I18nObject = Field(..., description="The description of the tool") + icon: str = Field(..., description="The icon of the tool") + label: I18nObject = Field(..., description="The label of the tool") + tags: Optional[list[ToolLabelEnum]] = Field( + default=[], + description="The tags of the tool", + ) + + +class DatasourceProviderEntity(BaseModel): + """ + Datasource provider entity + """ + + identity: DatasourceProviderIdentity + credentials_schema: list[ProviderConfig] = Field(default_factory=list) + 
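+ # OAuth support is optional: providers that only use static credentials rely on
+ # credentials_schema above and can leave oauth_schema unset.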
oauth_schema: Optional[OAuthSchema] = None + provider_type: DatasourceProviderType + + +class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity): + datasources: list[DatasourceEntity] = Field(default_factory=list) + + +class DatasourceInvokeMeta(BaseModel): + """ + Datasource invoke meta + """ + + time_cost: float = Field(..., description="The time cost of the tool invoke") + error: Optional[str] = None + tool_config: Optional[dict] = None + + @classmethod + def empty(cls) -> "DatasourceInvokeMeta": + """ + Get an empty instance of DatasourceInvokeMeta + """ + return cls(time_cost=0.0, error=None, tool_config={}) + + @classmethod + def error_instance(cls, error: str) -> "DatasourceInvokeMeta": + """ + Get an instance of DatasourceInvokeMeta with error + """ + return cls(time_cost=0.0, error=error, tool_config={}) + + def to_dict(self) -> dict: + return { + "time_cost": self.time_cost, + "error": self.error, + "tool_config": self.tool_config, + } + + +class DatasourceLabel(BaseModel): + """ + Datasource label + """ + + name: str = Field(..., description="The name of the tool") + label: I18nObject = Field(..., description="The label of the tool") + icon: str = Field(..., description="The icon of the tool") + + +class DatasourceInvokeFrom(Enum): + """ + Enum class for datasource invoke + """ + + RAG_PIPELINE = "rag_pipeline" + + +class OnlineDocumentPage(BaseModel): + """ + Online document page + """ + + page_id: str = Field(..., description="The page id") + page_name: str = Field(..., description="The page title") + page_icon: Optional[dict] = Field(None, description="The page icon") + type: str = Field(..., description="The type of the page") + last_edited_time: str = Field(..., description="The last edited time") + parent_id: Optional[str] = Field(None, description="The parent page id") + + +class OnlineDocumentInfo(BaseModel): + """ + Online document info + """ + + workspace_id: str = Field(..., description="The workspace id") + workspace_name: str = Field(..., description="The workspace name") + workspace_icon: str = Field(..., description="The workspace icon") + total: int = Field(..., description="The total number of documents") + pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") + + +class OnlineDocumentPagesMessage(BaseModel): + """ + Get online document pages response + """ + + result: list[OnlineDocumentInfo] + + +class GetOnlineDocumentPageContentRequest(BaseModel): + """ + Get online document page content request + """ + + workspace_id: str = Field(..., description="The workspace id") + page_id: str = Field(..., description="The page id") + type: str = Field(..., description="The type of the page") + + +class OnlineDocumentPageContent(BaseModel): + """ + Online document page content + """ + + workspace_id: str = Field(..., description="The workspace id") + page_id: str = Field(..., description="The page id") + content: str = Field(..., description="The content of the page") + + +class GetOnlineDocumentPageContentResponse(BaseModel): + """ + Get online document page content response + """ + + result: OnlineDocumentPageContent + + +class GetWebsiteCrawlRequest(BaseModel): + """ + Get website crawl request + """ + + crawl_parameters: dict = Field(..., description="The crawl parameters") + + +class WebSiteInfoDetail(BaseModel): + source_url: str = Field(..., description="The url of the website") + content: str = Field(..., description="The content of the website") + title: str = Field(..., description="The title of the website") + 
description: str = Field(..., description="The description of the website") + + +class WebSiteInfo(BaseModel): + """ + Website info + """ + + status: Optional[str] = Field(..., description="crawl job status") + web_info_list: Optional[list[WebSiteInfoDetail]] = [] + total: Optional[int] = Field(default=0, description="The total number of websites") + completed: Optional[int] = Field(default=0, description="The number of completed websites") + + +class WebsiteCrawlMessage(BaseModel): + """ + Get website crawl response + """ + + result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0) + + +class DatasourceMessage(ToolInvokeMessage): + pass + + +######################### +# Online driver file +######################### + + +class OnlineDriveFile(BaseModel): + """ + Online driver file + """ + + key: str = Field(..., description="The key of the file") + size: int = Field(..., description="The size of the file") + + +class OnlineDriveFileBucket(BaseModel): + """ + Online driver file bucket + """ + + bucket: Optional[str] = Field(None, description="The bucket of the file") + files: list[OnlineDriveFile] = Field(..., description="The files of the bucket") + is_truncated: bool = Field(False, description="Whether the bucket has more files") + + +class OnlineDriveBrowseFilesRequest(BaseModel): + """ + Get online driver file list request + """ + + prefix: Optional[str] = Field(None, description="File path prefix for filtering eg: 'docs/dify/'") + bucket: Optional[str] = Field(None, description="Storage bucket name") + max_keys: int = Field(20, description="Maximum number of files to return") + start_after: Optional[str] = Field( + None, description="Pagination token for continuing from a specific file eg: 'docs/dify/1.txt'" + ) + + +class OnlineDriveBrowseFilesResponse(BaseModel): + """ + Get online driver file list response + """ + + result: list[OnlineDriveFileBucket] = Field(..., description="The bucket of the files") + + +class OnlineDriveDownloadFileRequest(BaseModel): + """ + Get online driver file + """ + + key: str = Field(..., description="The name of the file") + bucket: Optional[str] = Field(None, description="The name of the bucket") diff --git a/api/core/datasource/errors.py b/api/core/datasource/errors.py new file mode 100644 index 0000000000..c7fc2f85b9 --- /dev/null +++ b/api/core/datasource/errors.py @@ -0,0 +1,37 @@ +from core.datasource.entities.datasource_entities import DatasourceInvokeMeta + + +class DatasourceProviderNotFoundError(ValueError): + pass + + +class DatasourceNotFoundError(ValueError): + pass + + +class DatasourceParameterValidationError(ValueError): + pass + + +class DatasourceProviderCredentialValidationError(ValueError): + pass + + +class DatasourceNotSupportedError(ValueError): + pass + + +class DatasourceInvokeError(ValueError): + pass + + +class DatasourceApiSchemaError(ValueError): + pass + + +class DatasourceEngineInvokeError(Exception): + meta: DatasourceInvokeMeta + + def __init__(self, meta, **kwargs): + self.meta = meta + super().__init__(**kwargs) diff --git a/api/core/datasource/local_file/local_file_plugin.py b/api/core/datasource/local_file/local_file_plugin.py new file mode 100644 index 0000000000..82da10d663 --- /dev/null +++ b/api/core/datasource/local_file/local_file_plugin.py @@ -0,0 +1,28 @@ +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + 
DatasourceProviderType, +) + + +class LocalFileDatasourcePlugin(DatasourcePlugin): + tenant_id: str + icon: str + plugin_unique_identifier: str + + def __init__( + self, + entity: DatasourceEntity, + runtime: DatasourceRuntime, + tenant_id: str, + icon: str, + plugin_unique_identifier: str, + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier + + def datasource_provider_type(self) -> str: + return DatasourceProviderType.LOCAL_FILE diff --git a/api/core/datasource/local_file/local_file_provider.py b/api/core/datasource/local_file/local_file_provider.py new file mode 100644 index 0000000000..b2b6f51dd3 --- /dev/null +++ b/api/core/datasource/local_file/local_file_provider.py @@ -0,0 +1,56 @@ +from typing import Any + +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin + + +class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController): + entity: DatasourceProviderEntityWithPlugin + plugin_id: str + plugin_unique_identifier: str + + def __init__( + self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + ) -> None: + super().__init__(entity, tenant_id) + self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.LOCAL_FILE + + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + """ + validate the credentials of the provider + """ + pass + + def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore + """ + return datasource with given name + """ + datasource_entity = next( + ( + datasource_entity + for datasource_entity in self.entity.datasources + if datasource_entity.identity.name == datasource_name + ), + None, + ) + + if not datasource_entity: + raise ValueError(f"Datasource with name {datasource_name} not found") + + return LocalFileDatasourcePlugin( + entity=datasource_entity, + runtime=DatasourceRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py new file mode 100644 index 0000000000..c1e015fd3a --- /dev/null +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -0,0 +1,73 @@ +from collections.abc import Generator, Mapping +from typing import Any + +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceMessage, + DatasourceProviderType, + GetOnlineDocumentPageContentRequest, + OnlineDocumentPagesMessage, +) +from core.plugin.impl.datasource import PluginDatasourceManager + + +class OnlineDocumentDatasourcePlugin(DatasourcePlugin): + tenant_id: str + icon: str + plugin_unique_identifier: str + entity: DatasourceEntity + runtime: 
DatasourceRuntime + + def __init__( + self, + entity: DatasourceEntity, + runtime: DatasourceRuntime, + tenant_id: str, + icon: str, + plugin_unique_identifier: str, + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier + + def get_online_document_pages( + self, + user_id: str, + datasource_parameters: Mapping[str, Any], + provider_type: str, + ) -> Generator[OnlineDocumentPagesMessage, None, None]: + manager = PluginDatasourceManager() + + return manager.get_online_document_pages( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def get_online_document_page_content( + self, + user_id: str, + datasource_parameters: GetOnlineDocumentPageContentRequest, + provider_type: str, + ) -> Generator[DatasourceMessage, None, None]: + manager = PluginDatasourceManager() + + return manager.get_online_document_page_content( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def datasource_provider_type(self) -> str: + return DatasourceProviderType.ONLINE_DOCUMENT diff --git a/api/core/datasource/online_document/online_document_provider.py b/api/core/datasource/online_document/online_document_provider.py new file mode 100644 index 0000000000..a128b479f4 --- /dev/null +++ b/api/core/datasource/online_document/online_document_provider.py @@ -0,0 +1,48 @@ +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin + + +class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController): + entity: DatasourceProviderEntityWithPlugin + plugin_id: str + plugin_unique_identifier: str + + def __init__( + self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + ) -> None: + super().__init__(entity, tenant_id) + self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.ONLINE_DOCUMENT + + def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore + """ + return datasource with given name + """ + datasource_entity = next( + ( + datasource_entity + for datasource_entity in self.entity.datasources + if datasource_entity.identity.name == datasource_name + ), + None, + ) + + if not datasource_entity: + raise ValueError(f"Datasource with name {datasource_name} not found") + + return OnlineDocumentDatasourcePlugin( + entity=datasource_entity, + runtime=DatasourceRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/datasource/online_drive/online_drive_plugin.py 
b/api/core/datasource/online_drive/online_drive_plugin.py new file mode 100644 index 0000000000..f0e3cb38f9 --- /dev/null +++ b/api/core/datasource/online_drive/online_drive_plugin.py @@ -0,0 +1,73 @@ +from collections.abc import Generator + +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceMessage, + DatasourceProviderType, + OnlineDriveBrowseFilesRequest, + OnlineDriveBrowseFilesResponse, + OnlineDriveDownloadFileRequest, +) +from core.plugin.impl.datasource import PluginDatasourceManager + + +class OnlineDriveDatasourcePlugin(DatasourcePlugin): + tenant_id: str + icon: str + plugin_unique_identifier: str + entity: DatasourceEntity + runtime: DatasourceRuntime + + def __init__( + self, + entity: DatasourceEntity, + runtime: DatasourceRuntime, + tenant_id: str, + icon: str, + plugin_unique_identifier: str, + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier + + def online_drive_browse_files( + self, + user_id: str, + request: OnlineDriveBrowseFilesRequest, + provider_type: str, + ) -> Generator[OnlineDriveBrowseFilesResponse, None, None]: + manager = PluginDatasourceManager() + + return manager.online_drive_browse_files( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + request=request, + provider_type=provider_type, + ) + + def online_drive_download_file( + self, + user_id: str, + request: OnlineDriveDownloadFileRequest, + provider_type: str, + ) -> Generator[DatasourceMessage, None, None]: + manager = PluginDatasourceManager() + + return manager.online_drive_download_file( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + request=request, + provider_type=provider_type, + ) + + def datasource_provider_type(self) -> str: + return DatasourceProviderType.ONLINE_DRIVE diff --git a/api/core/datasource/online_drive/online_drive_provider.py b/api/core/datasource/online_drive/online_drive_provider.py new file mode 100644 index 0000000000..d0923ed807 --- /dev/null +++ b/api/core/datasource/online_drive/online_drive_provider.py @@ -0,0 +1,48 @@ +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin + + +class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController): + entity: DatasourceProviderEntityWithPlugin + plugin_id: str + plugin_unique_identifier: str + + def __init__( + self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + ) -> None: + super().__init__(entity, tenant_id) + self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.ONLINE_DRIVE + + def get_datasource(self, 
datasource_name: str) -> OnlineDriveDatasourcePlugin: # type: ignore + """ + return datasource with given name + """ + datasource_entity = next( + ( + datasource_entity + for datasource_entity in self.entity.datasources + if datasource_entity.identity.name == datasource_name + ), + None, + ) + + if not datasource_entity: + raise ValueError(f"Datasource with name {datasource_name} not found") + + return OnlineDriveDatasourcePlugin( + entity=datasource_entity, + runtime=DatasourceRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/datasource/utils/__init__.py b/api/core/datasource/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/datasource/utils/configuration.py b/api/core/datasource/utils/configuration.py new file mode 100644 index 0000000000..6a5fba65bd --- /dev/null +++ b/api/core/datasource/utils/configuration.py @@ -0,0 +1,265 @@ +from copy import deepcopy +from typing import Any + +from pydantic import BaseModel + +from core.entities.provider_entities import BasicProviderConfig +from core.helper import encrypter +from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType +from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ( + ToolParameter, + ToolProviderType, +) + + +class ProviderConfigEncrypter(BaseModel): + tenant_id: str + config: list[BasicProviderConfig] + provider_type: str + provider_identity: str + + def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: + """ + deep copy data + """ + return deepcopy(data) + + def encrypt(self, data: dict[str, str]) -> dict[str, str]: + """ + encrypt tool credentials with tenant id + + return a deep copy of credentials with encrypted values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") + data[field_name] = encrypted + + return data + + def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: + """ + mask tool credentials + + return a deep copy of credentials with masked values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + if len(data[field_name]) > 6: + data[field_name] = ( + data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] + ) + else: + data[field_name] = "*" * len(data[field_name]) + + return data + + def decrypt(self, data: dict[str, str]) -> dict[str, str]: + """ + decrypt tool credentials with tenant id + + return a deep copy of credentials with decrypted values + """ + cache = ToolProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=f"{self.provider_type}.{self.provider_identity}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, + ) + cached_credentials = cache.get() + if cached_credentials: + return 
cached_credentials + data = self._deep_copy(data) + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + try: + # if the value is None or empty string, skip decrypt + if not data[field_name]: + continue + + data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) + except Exception: + pass + + cache.set(data) + return data + + def delete_tool_credentials_cache(self): + cache = ToolProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=f"{self.provider_type}.{self.provider_identity}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, + ) + cache.delete() + + +class ToolParameterConfigurationManager: + """ + Tool parameter configuration manager + """ + + tenant_id: str + tool_runtime: Tool + provider_name: str + provider_type: ToolProviderType + identity_id: str + + def __init__( + self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str + ) -> None: + self.tenant_id = tenant_id + self.tool_runtime = tool_runtime + self.provider_name = provider_name + self.provider_type = provider_type + self.identity_id = identity_id + + def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + deep copy parameters + """ + return deepcopy(parameters) + + def _merge_parameters(self) -> list[ToolParameter]: + """ + merge parameters + """ + # get tool parameters + tool_parameters = self.tool_runtime.entity.parameters or [] + # get tool runtime parameters + runtime_parameters = self.tool_runtime.get_runtime_parameters() + # override parameters + current_parameters = tool_parameters.copy() + for runtime_parameter in runtime_parameters: + found = False + for index, parameter in enumerate(current_parameters): + if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: + current_parameters[index] = runtime_parameter + found = True + break + + if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + current_parameters.append(runtime_parameter) + + return current_parameters + + def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + mask tool parameters + + return a deep copy of parameters with masked values + """ + parameters = self._deep_copy(parameters) + + # override parameters + current_parameters = self._merge_parameters() + + for parameter in current_parameters: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): + if parameter.name in parameters: + if len(parameters[parameter.name]) > 6: + parameters[parameter.name] = ( + parameters[parameter.name][:2] + + "*" * (len(parameters[parameter.name]) - 4) + + parameters[parameter.name][-2:] + ) + else: + parameters[parameter.name] = "*" * len(parameters[parameter.name]) + + return parameters + + def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + encrypt tool parameters with tenant id + + return a deep copy of parameters with encrypted values + """ + # override parameters + current_parameters = self._merge_parameters() + + parameters = self._deep_copy(parameters) + + for parameter in current_parameters: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == 
ToolParameter.ToolParameterType.SECRET_INPUT + ): + if parameter.name in parameters: + encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) + parameters[parameter.name] = encrypted + + return parameters + + def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + decrypt tool parameters with tenant id + + return a deep copy of parameters with decrypted values + """ + + cache = ToolParameterCache( + tenant_id=self.tenant_id, + provider=f"{self.provider_type.value}.{self.provider_name}", + tool_name=self.tool_runtime.entity.identity.name, + cache_type=ToolParameterCacheType.PARAMETER, + identity_id=self.identity_id, + ) + cached_parameters = cache.get() + if cached_parameters: + return cached_parameters + + # override parameters + current_parameters = self._merge_parameters() + has_secret_input = False + + for parameter in current_parameters: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): + if parameter.name in parameters: + try: + has_secret_input = True + parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) + except Exception: + pass + + if has_secret_input: + cache.set(parameters) + + return parameters + + def delete_tool_parameters_cache(self): + cache = ToolParameterCache( + tenant_id=self.tenant_id, + provider=f"{self.provider_type.value}.{self.provider_name}", + tool_name=self.tool_runtime.entity.identity.name, + cache_type=ToolParameterCacheType.PARAMETER, + identity_id=self.identity_id, + ) + cache.delete() diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py new file mode 100644 index 0000000000..9bc57235d8 --- /dev/null +++ b/api/core/datasource/utils/message_transformer.py @@ -0,0 +1,121 @@ +import logging +from collections.abc import Generator +from mimetypes import guess_extension +from typing import Optional + +from core.datasource.datasource_file_manager import DatasourceFileManager +from core.datasource.entities.datasource_entities import DatasourceMessage +from core.file import File, FileTransferMethod, FileType + +logger = logging.getLogger(__name__) + + +class DatasourceFileMessageTransformer: + @classmethod + def transform_datasource_invoke_messages( + cls, + messages: Generator[DatasourceMessage, None, None], + user_id: str, + tenant_id: str, + conversation_id: Optional[str] = None, + ) -> Generator[DatasourceMessage, None, None]: + """ + Transform datasource message and handle file download + """ + for message in messages: + if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}: + yield message + elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance( + message.message, DatasourceMessage.TextMessage + ): + # try to download image + try: + assert isinstance(message.message, DatasourceMessage.TextMessage) + + file = DatasourceFileManager.create_file_by_url( + user_id=user_id, + tenant_id=tenant_id, + file_url=message.message.text, + conversation_id=conversation_id, + ) + + url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}" + + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text=url), + meta=message.meta.copy() if message.meta is not None else {}, + ) + except Exception as e: + yield DatasourceMessage( + type=DatasourceMessage.MessageType.TEXT, + 
message=DatasourceMessage.TextMessage( + text=f"Failed to download image: {message.message.text}: {e}" + ), + meta=message.meta.copy() if message.meta is not None else {}, + ) + elif message.type == DatasourceMessage.MessageType.BLOB: + # get mime type and save blob to storage + meta = message.meta or {} + + mimetype = meta.get("mime_type", "application/octet-stream") + # get filename from meta + filename = meta.get("file_name", None) + # if message is str, encode it to bytes + + if not isinstance(message.message, DatasourceMessage.BlobMessage): + raise ValueError("unexpected message type") + + # FIXME: should do a type check here. + assert isinstance(message.message.blob, bytes) + file = DatasourceFileManager.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_binary=message.message.blob, + mimetype=mimetype, + filename=filename, + ) + + url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type)) + + # check if file is image + if "image" in mimetype: + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + else: + yield DatasourceMessage( + type=DatasourceMessage.MessageType.BINARY_LINK, + message=DatasourceMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + elif message.type == DatasourceMessage.MessageType.FILE: + meta = message.meta or {} + file = meta.get("file", None) + if isinstance(file, File): + if file.transfer_method == FileTransferMethod.TOOL_FILE: + assert file.related_id is not None + url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension) + if file.type == FileType.IMAGE: + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + else: + yield DatasourceMessage( + type=DatasourceMessage.MessageType.LINK, + message=DatasourceMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + else: + yield message + else: + yield message + + @classmethod + def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str: + return f"/files/datasources/{datasource_file_id}{extension or '.bin'}" diff --git a/api/core/datasource/utils/parser.py b/api/core/datasource/utils/parser.py new file mode 100644 index 0000000000..f72291783a --- /dev/null +++ b/api/core/datasource/utils/parser.py @@ -0,0 +1,389 @@ +import re +import uuid +from json import dumps as json_dumps +from json import loads as json_loads +from json.decoder import JSONDecodeError +from typing import Optional + +from flask import request +from requests import get +from yaml import YAMLError, safe_load # type: ignore + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter +from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError + + +class ApiBasedToolSchemaParser: + @staticmethod + def parse_openapi_to_tool_bundle( + openapi: dict, extra_info: dict | None = None, warning: dict | None = None + ) -> list[ApiToolBundle]: + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + # set description to extra_info + 
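+ # Copy the spec's top-level info.description into extra_info, then resolve the server
+ # URL below: the first server is the default, overridden by a server whose "env"
+ # matches the X-Request-Env request header when one is present.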
extra_info["description"] = openapi["info"].get("description", "") + + if len(openapi["servers"]) == 0: + raise ToolProviderNotFoundError("No server found in the openapi yaml.") + + server_url = openapi["servers"][0]["url"] + request_env = request.headers.get("X-Request-Env") + if request_env: + matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env] + server_url = matched_servers[0] if matched_servers else server_url + + # list all interfaces + interfaces = [] + for path, path_item in openapi["paths"].items(): + methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"] + for method in methods: + if method in path_item: + interfaces.append( + { + "path": path, + "method": method, + "operation": path_item[method], + } + ) + + # get all parameters + bundles = [] + for interface in interfaces: + # convert parameters + parameters = [] + if "parameters" in interface["operation"]: + for parameter in interface["operation"]["parameters"]: + tool_parameter = ToolParameter( + name=parameter["name"], + label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]), + human_description=I18nObject( + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") + ), + type=ToolParameter.ToolParameterType.STRING, + required=parameter.get("required", False), + form=ToolParameter.ToolParameterForm.LLM, + llm_description=parameter.get("description"), + default=parameter["schema"]["default"] + if "schema" in parameter and "default" in parameter["schema"] + else None, + placeholder=I18nObject( + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") + ), + ) + + # check if there is a type + typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter) + if typ: + tool_parameter.type = typ + + parameters.append(tool_parameter) + # create tool bundle + # check if there is a request body + if "requestBody" in interface["operation"]: + request_body = interface["operation"]["requestBody"] + if "content" in request_body: + for content_type, content in request_body["content"].items(): + # if there is a reference, get the reference and overwrite the content + if "schema" not in content: + continue + + if "$ref" in content["schema"]: + # get the reference + root = openapi + reference = content["schema"]["$ref"].split("/")[1:] + for ref in reference: + root = root[ref] + # overwrite the content + interface["operation"]["requestBody"]["content"][content_type]["schema"] = root + + # parse body parameters + if "schema" in interface["operation"]["requestBody"]["content"][content_type]: + body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] + required = body_schema.get("required", []) + properties = body_schema.get("properties", {}) + for name, property in properties.items(): + tool = ToolParameter( + name=name, + label=I18nObject(en_US=name, zh_Hans=name), + human_description=I18nObject( + en_US=property.get("description", ""), zh_Hans=property.get("description", "") + ), + type=ToolParameter.ToolParameterType.STRING, + required=name in required, + form=ToolParameter.ToolParameterForm.LLM, + llm_description=property.get("description", ""), + default=property.get("default", None), + placeholder=I18nObject( + en_US=property.get("description", ""), zh_Hans=property.get("description", "") + ), + ) + + # check if there is a type + typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) + if typ: + tool.type = typ + + parameters.append(tool) + + # check if 
parameters is duplicated + parameters_count = {} + for parameter in parameters: + if parameter.name not in parameters_count: + parameters_count[parameter.name] = 0 + parameters_count[parameter.name] += 1 + for name, count in parameters_count.items(): + if count > 1: + warning["duplicated_parameter"] = f"Parameter {name} is duplicated." + + # check if there is a operation id, use $path_$method as operation id if not + if "operationId" not in interface["operation"]: + # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ + path = interface["path"] + if interface["path"].startswith("/"): + path = interface["path"][1:] + # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ + path = re.sub(r"[^a-zA-Z0-9_-]", "", path) + if not path: + path = str(uuid.uuid4()) + + interface["operation"]["operationId"] = f"{path}_{interface['method']}" + + bundles.append( + ApiToolBundle( + server_url=server_url + interface["path"], + method=interface["method"], + summary=interface["operation"]["description"] + if "description" in interface["operation"] + else interface["operation"].get("summary", None), + operation_id=interface["operation"]["operationId"], + parameters=parameters, + author="", + icon=None, + openapi=interface["operation"], + ) + ) + + return bundles + + @staticmethod + def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]: + parameter = parameter or {} + typ: Optional[str] = None + if parameter.get("format") == "binary": + return ToolParameter.ToolParameterType.FILE + + if "type" in parameter: + typ = parameter["type"] + elif "schema" in parameter and "type" in parameter["schema"]: + typ = parameter["schema"]["type"] + + if typ in {"integer", "number"}: + return ToolParameter.ToolParameterType.NUMBER + elif typ == "boolean": + return ToolParameter.ToolParameterType.BOOLEAN + elif typ == "string": + return ToolParameter.ToolParameterType.STRING + elif typ == "array": + items = parameter.get("items") or parameter.get("schema", {}).get("items") + return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None + else: + return None + + @staticmethod + def parse_openapi_yaml_to_tool_bundle( + yaml: str, extra_info: dict | None = None, warning: dict | None = None + ) -> list[ApiToolBundle]: + """ + parse openapi yaml to tool bundle + + :param yaml: the yaml string + :param extra_info: the extra info + :param warning: the warning message + :return: the tool bundle + """ + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + openapi: dict = safe_load(yaml) + if openapi is None: + raise ToolApiSchemaError("Invalid openapi yaml.") + return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) + + @staticmethod + def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict: + warning = warning or {} + """ + parse swagger to openapi + + :param swagger: the swagger dict + :return: the openapi dict + """ + # convert swagger to openapi + info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"}) + + servers = swagger.get("servers", []) + + if len(servers) == 0: + raise ToolApiSchemaError("No server found in the swagger yaml.") + + openapi = { + "openapi": "3.0.0", + "info": { + "title": info.get("title", "Swagger"), + "description": info.get("description", "Swagger"), 
+ "version": info.get("version", "1.0.0"), + }, + "servers": swagger["servers"], + "paths": {}, + "components": {"schemas": {}}, + } + + # check paths + if "paths" not in swagger or len(swagger["paths"]) == 0: + raise ToolApiSchemaError("No paths found in the swagger yaml.") + + # convert paths + for path, path_item in swagger["paths"].items(): + openapi["paths"][path] = {} + for method, operation in path_item.items(): + if "operationId" not in operation: + raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.") + + if ("summary" not in operation or len(operation["summary"]) == 0) and ( + "description" not in operation or len(operation["description"]) == 0 + ): + if warning is not None: + warning["missing_summary"] = f"No summary or description found in operation {method} {path}." + + openapi["paths"][path][method] = { + "operationId": operation["operationId"], + "summary": operation.get("summary", ""), + "description": operation.get("description", ""), + "parameters": operation.get("parameters", []), + "responses": operation.get("responses", {}), + } + + if "requestBody" in operation: + openapi["paths"][path][method]["requestBody"] = operation["requestBody"] + + # convert definitions + for name, definition in swagger["definitions"].items(): + openapi["components"]["schemas"][name] = definition + + return openapi + + @staticmethod + def parse_openai_plugin_json_to_tool_bundle( + json: str, extra_info: dict | None = None, warning: dict | None = None + ) -> list[ApiToolBundle]: + """ + parse openapi plugin yaml to tool bundle + + :param json: the json string + :param extra_info: the extra info + :param warning: the warning message + :return: the tool bundle + """ + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + try: + openai_plugin = json_loads(json) + api = openai_plugin["api"] + api_url = api["url"] + api_type = api["type"] + except JSONDecodeError: + raise ToolProviderNotFoundError("Invalid openai plugin json.") + + if api_type != "openapi": + raise ToolNotSupportedError("Only openapi is supported now.") + + # get openapi yaml + response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5) + + if response.status_code != 200: + raise ToolProviderNotFoundError("cannot get openapi yaml from url.") + + return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( + response.text, extra_info=extra_info, warning=warning + ) + + @staticmethod + def auto_parse_to_tool_bundle( + content: str, extra_info: dict | None = None, warning: dict | None = None + ) -> tuple[list[ApiToolBundle], str]: + """ + auto parse to tool bundle + + :param content: the content + :param extra_info: the extra info + :param warning: the warning message + :return: tools bundle, schema_type + """ + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + content = content.strip() + loaded_content = None + json_error = None + yaml_error = None + + try: + loaded_content = json_loads(content) + except JSONDecodeError as e: + json_error = e + + if loaded_content is None: + try: + loaded_content = safe_load(content) + except YAMLError as e: + yaml_error = e + if loaded_content is None: + raise ToolApiSchemaError( + f"Invalid api schema, schema is neither json nor yaml. 
json error: {str(json_error)}," + f" yaml error: {str(yaml_error)}" + ) + + swagger_error = None + openapi_error = None + openapi_plugin_error = None + schema_type = None + + try: + openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + loaded_content, extra_info=extra_info, warning=warning + ) + schema_type = ApiProviderSchemaType.OPENAPI.value + return openapi, schema_type + except ToolApiSchemaError as e: + openapi_error = e + + # openai parse error, fallback to swagger + try: + converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( + loaded_content, extra_info=extra_info, warning=warning + ) + schema_type = ApiProviderSchemaType.SWAGGER.value + return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + converted_swagger, extra_info=extra_info, warning=warning + ), schema_type + except ToolApiSchemaError as e: + swagger_error = e + + # swagger parse error, fallback to openai plugin + try: + openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + json_dumps(loaded_content), extra_info=extra_info, warning=warning + ) + return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value + except ToolNotSupportedError as e: + # maybe it's not plugin at all + openapi_plugin_error = e + + raise ToolApiSchemaError( + f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}," + f" openapi plugin error: {str(openapi_plugin_error)}" + ) diff --git a/api/core/datasource/utils/text_processing_utils.py b/api/core/datasource/utils/text_processing_utils.py new file mode 100644 index 0000000000..105823f896 --- /dev/null +++ b/api/core/datasource/utils/text_processing_utils.py @@ -0,0 +1,17 @@ +import re + + +def remove_leading_symbols(text: str) -> str: + """ + Remove leading punctuation or symbols from the given text. + + Args: + text (str): The input text to process. + + Returns: + str: The text with leading punctuation or symbols removed. 
+ """ + # Match Unicode ranges for punctuation and symbols + # FIXME this pattern is confused quick fix for #11868 maybe refactor it later + pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+" + return re.sub(pattern, "", text) diff --git a/api/core/datasource/utils/uuid_utils.py b/api/core/datasource/utils/uuid_utils.py new file mode 100644 index 0000000000..3046c08c89 --- /dev/null +++ b/api/core/datasource/utils/uuid_utils.py @@ -0,0 +1,9 @@ +import uuid + + +def is_valid_uuid(uuid_str: str) -> bool: + try: + uuid.UUID(uuid_str) + return True + except Exception: + return False diff --git a/api/core/datasource/utils/workflow_configuration_sync.py b/api/core/datasource/utils/workflow_configuration_sync.py new file mode 100644 index 0000000000..d16d6fc576 --- /dev/null +++ b/api/core/datasource/utils/workflow_configuration_sync.py @@ -0,0 +1,43 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from core.app.app_config.entities import VariableEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration + + +class WorkflowToolConfigurationUtils: + @classmethod + def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]): + for configuration in configurations: + WorkflowToolParameterConfiguration.model_validate(configuration) + + @classmethod + def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: + """ + get workflow graph variables + """ + nodes = graph.get("nodes", []) + start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None) + + if not start_node: + return [] + + return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])] + + @classmethod + def check_is_synced( + cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] + ): + """ + check is synced + + raise ValueError if not synced + """ + variable_names = [variable.variable for variable in variables] + + if len(tool_configurations) != len(variables): + raise ValueError("parameter configuration mismatch, please republish the tool to update") + + for parameter in tool_configurations: + if parameter.name not in variable_names: + raise ValueError("parameter configuration mismatch, please republish the tool to update") diff --git a/api/core/datasource/utils/yaml_utils.py b/api/core/datasource/utils/yaml_utils.py new file mode 100644 index 0000000000..ee7ca11e05 --- /dev/null +++ b/api/core/datasource/utils/yaml_utils.py @@ -0,0 +1,35 @@ +import logging +from pathlib import Path +from typing import Any + +import yaml # type: ignore +from yaml import YAMLError + +logger = logging.getLogger(__name__) + + +def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any: + """ + Safe loading a YAML file + :param file_path: the path of the YAML file + :param ignore_error: + if True, return default_value if error occurs and the error will be logged in debug level + if False, raise error if error occurs + :param default_value: the value returned when errors ignored + :return: an object of the YAML content + """ + if not file_path or not Path(file_path).exists(): + if ignore_error: + return default_value + else: + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, encoding="utf-8") as yaml_file: + try: + yaml_content = yaml.safe_load(yaml_file) + return yaml_content or default_value + except Exception as e: + if 
ignore_error: + return default_value + else: + raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py new file mode 100644 index 0000000000..d0e442f31a --- /dev/null +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -0,0 +1,53 @@ +from collections.abc import Generator, Mapping +from typing import Any + +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, + WebsiteCrawlMessage, +) +from core.plugin.impl.datasource import PluginDatasourceManager + + +class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): + tenant_id: str + icon: str + plugin_unique_identifier: str + entity: DatasourceEntity + runtime: DatasourceRuntime + + def __init__( + self, + entity: DatasourceEntity, + runtime: DatasourceRuntime, + tenant_id: str, + icon: str, + plugin_unique_identifier: str, + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier + + def get_website_crawl( + self, + user_id: str, + datasource_parameters: Mapping[str, Any], + provider_type: str, + ) -> Generator[WebsiteCrawlMessage, None, None]: + manager = PluginDatasourceManager() + + return manager.get_website_crawl( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def datasource_provider_type(self) -> str: + return DatasourceProviderType.WEBSITE_CRAWL diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py new file mode 100644 index 0000000000..8c0f20ce2d --- /dev/null +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -0,0 +1,52 @@ +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin + + +class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController): + entity: DatasourceProviderEntityWithPlugin + plugin_id: str + plugin_unique_identifier: str + + def __init__( + self, + entity: DatasourceProviderEntityWithPlugin, + plugin_id: str, + plugin_unique_identifier: str, + tenant_id: str, + ) -> None: + super().__init__(entity, tenant_id) + self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.WEBSITE_CRAWL + + def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore + """ + return datasource with given name + """ + datasource_entity = next( + ( + datasource_entity + for datasource_entity in self.entity.datasources + if datasource_entity.identity.name == datasource_name + ), + None, + ) + + if not datasource_entity: + raise 
ValueError(f"Datasource with name {datasource_name} not found") + + return WebsiteCrawlDatasourcePlugin( + entity=datasource_entity, + runtime=DatasourceRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index 90c9879733..63fce06005 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -17,3 +17,27 @@ class IndexingEstimate(BaseModel): total_segments: int preview: list[PreviewDetail] qa_preview: Optional[list[QAPreviewDetail]] = None + + +class PipelineDataset(BaseModel): + id: str + name: str + description: str + chunk_structure: str + + +class PipelineDocument(BaseModel): + id: str + position: int + data_source_type: str + data_source_info: Optional[dict] = None + name: str + indexing_status: str + error: Optional[str] = None + enabled: bool + + +class PipelineGenerateResponse(BaseModel): + batch: str + dataset: PipelineDataset + documents: list[PipelineDocument] diff --git a/api/core/file/datasource_file_parser.py b/api/core/file/datasource_file_parser.py new file mode 100644 index 0000000000..52687951ac --- /dev/null +++ b/api/core/file/datasource_file_parser.py @@ -0,0 +1,15 @@ +from typing import TYPE_CHECKING, Any, cast + +from core.datasource import datasource_file_manager +from core.datasource.datasource_file_manager import DatasourceFileManager + +if TYPE_CHECKING: + from core.datasource.datasource_file_manager import DatasourceFileManager + +tool_file_manager: dict[str, Any] = {"manager": None} + + +class DatasourceFileParser: + @staticmethod + def get_datasource_file_manager() -> "DatasourceFileManager": + return cast("DatasourceFileManager", datasource_file_manager["manager"]) diff --git a/api/core/file/enums.py b/api/core/file/enums.py index a50a651dd3..170eb4fc23 100644 --- a/api/core/file/enums.py +++ b/api/core/file/enums.py @@ -20,6 +20,7 @@ class FileTransferMethod(StrEnum): REMOTE_URL = "remote_url" LOCAL_FILE = "local_file" TOOL_FILE = "tool_file" + DATASOURCE_FILE = "datasource_file" @staticmethod def value_of(value): diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 151fa2aaf4..3c0fcb1310 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -135,3 +135,4 @@ class TraceTaskName(StrEnum): DATASET_RETRIEVAL_TRACE = "dataset_retrieval" TOOL_TRACE = "tool" GENERATE_NAME_TRACE = "generate_conversation_name" + DATASOURCE_TRACE = "datasource" diff --git a/api/core/plugin/entities/oauth.py b/api/core/plugin/entities/oauth.py new file mode 100644 index 0000000000..d284b82728 --- /dev/null +++ b/api/core/plugin/entities/oauth.py @@ -0,0 +1,21 @@ +from collections.abc import Sequence + +from pydantic import BaseModel, Field + +from core.entities.provider_entities import ProviderConfig + + +class OAuthSchema(BaseModel): + """ + OAuth schema + """ + + client_schema: Sequence[ProviderConfig] = Field( + default_factory=list, + description="client schema like client_id, client_secret, etc.", + ) + + credentials_schema: Sequence[ProviderConfig] = Field( + default_factory=list, + description="credentials schema like access_token, refresh_token, etc.", + ) diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index bdf7d5ce1f..e2ea9669fa 100644 --- a/api/core/plugin/entities/plugin.py +++ 
b/api/core/plugin/entities/plugin.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, model_validator from werkzeug.exceptions import NotFound from core.agent.plugin_entities import AgentStrategyProviderEntity +from core.datasource.entities.datasource_entities import DatasourceProviderEntity from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.base import BasePluginEntity from core.plugin.entities.endpoint import EndpointProviderDeclaration @@ -62,6 +63,7 @@ class PluginCategory(enum.StrEnum): Model = "model" Extension = "extension" AgentStrategy = "agent-strategy" + Datasource = "datasource" class PluginDeclaration(BaseModel): @@ -69,6 +71,7 @@ class PluginDeclaration(BaseModel): tools: Optional[list[str]] = Field(default_factory=list[str]) models: Optional[list[str]] = Field(default_factory=list[str]) endpoints: Optional[list[str]] = Field(default_factory=list[str]) + datasources: Optional[list[str]] = Field(default_factory=list[str]) class Meta(BaseModel): minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$") @@ -90,6 +93,7 @@ class PluginDeclaration(BaseModel): model: Optional[ProviderEntity] = None endpoint: Optional[EndpointProviderDeclaration] = None agent_strategy: Optional[AgentStrategyProviderEntity] = None + datasource: Optional[DatasourceProviderEntity] = None meta: Meta @model_validator(mode="before") @@ -100,6 +104,8 @@ class PluginDeclaration(BaseModel): values["category"] = PluginCategory.Tool elif values.get("model"): values["category"] = PluginCategory.Model + elif values.get("datasource"): + values["category"] = PluginCategory.Datasource elif values.get("agent_strategy"): values["category"] = PluginCategory.AgentStrategy else: @@ -193,6 +199,11 @@ class ToolProviderID(GenericProviderID): self.plugin_name = f"{self.provider_name}_tool" +class DatasourceProviderID(GenericProviderID): + def __init__(self, value: str, is_hardcoded: bool = False) -> None: + super().__init__(value, is_hardcoded) + + class PluginDependency(BaseModel): class Type(enum.StrEnum): Github = PluginInstallationSource.Github.value diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 592b42c0da..1db3b6c429 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -6,6 +6,7 @@ from typing import Any, Generic, Optional, TypeVar from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.base import BasePluginEntity @@ -48,6 +49,14 @@ class PluginToolProviderEntity(BaseModel): declaration: ToolProviderEntityWithPlugin +class PluginDatasourceProviderEntity(BaseModel): + provider: str + plugin_unique_identifier: str + plugin_id: str + is_authorized: bool = False + declaration: DatasourceProviderEntityWithPlugin + + class PluginAgentProviderEntity(BaseModel): provider: str plugin_unique_identifier: str diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py new file mode 100644 index 0000000000..a69e88baf4 --- /dev/null +++ b/api/core/plugin/impl/datasource.py @@ -0,0 +1,329 @@ +from collections.abc import Generator, Mapping +from typing import Any + +from 
core.datasource.entities.datasource_entities import ( + DatasourceMessage, + GetOnlineDocumentPageContentRequest, + OnlineDocumentPagesMessage, + OnlineDriveBrowseFilesRequest, + OnlineDriveBrowseFilesResponse, + OnlineDriveDownloadFileRequest, + WebsiteCrawlMessage, +) +from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID +from core.plugin.entities.plugin_daemon import ( + PluginBasicBooleanResponse, + PluginDatasourceProviderEntity, +) +from core.plugin.impl.base import BasePluginClient +from services.tools.tools_transform_service import ToolTransformService + + +class PluginDatasourceManager(BasePluginClient): + def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]: + """ + Fetch datasource providers for the given tenant. + """ + + def transformer(json_response: dict[str, Any]) -> dict: + if json_response.get("data"): + for provider in json_response.get("data", []): + declaration = provider.get("declaration", {}) or {} + provider_name = declaration.get("identity", {}).get("name") + for datasource in declaration.get("datasources", []): + datasource["identity"]["provider"] = provider_name + + return json_response + + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/datasources", + list[PluginDatasourceProviderEntity], + params={"page": 1, "page_size": 256}, + transformer=transformer, + ) + local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) + + for provider in response: + ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider) + all_response = [local_file_datasource_provider] + response + + for provider in all_response: + provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" + + # override the provider name for each tool to plugin_id/provider_name + for tool in provider.declaration.datasources: + tool.identity.provider = provider.declaration.identity.name + + return all_response + + def fetch_datasource_provider(self, tenant_id: str, provider_id: str) -> PluginDatasourceProviderEntity: + """ + Fetch datasource provider for the given tenant and plugin. 
+ """ + if provider_id == "langgenius/file/file": + return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) + + tool_provider_id = DatasourceProviderID(provider_id) + + def transformer(json_response: dict[str, Any]) -> dict: + data = json_response.get("data") + if data: + for datasource in data.get("declaration", {}).get("datasources", []): + datasource["identity"]["provider"] = tool_provider_id.provider_name + + return json_response + + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/datasource", + PluginDatasourceProviderEntity, + params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id}, + transformer=transformer, + ) + + response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" + + # override the provider name for each tool to plugin_id/provider_name + for datasource in response.declaration.datasources: + datasource.identity.provider = response.declaration.identity.name + + return response + + def get_website_crawl( + self, + tenant_id: str, + user_id: str, + datasource_provider: str, + datasource_name: str, + credentials: dict[str, Any], + datasource_parameters: Mapping[str, Any], + provider_type: str, + ) -> Generator[WebsiteCrawlMessage, None, None]: + """ + Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. + """ + + datasource_provider_id = GenericProviderID(datasource_provider) + + return self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl", + WebsiteCrawlMessage, + data={ + "user_id": user_id, + "data": { + "provider": datasource_provider_id.provider_name, + "datasource": datasource_name, + "credentials": credentials, + "datasource_parameters": datasource_parameters, + }, + }, + headers={ + "X-Plugin-ID": datasource_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + def get_online_document_pages( + self, + tenant_id: str, + user_id: str, + datasource_provider: str, + datasource_name: str, + credentials: dict[str, Any], + datasource_parameters: Mapping[str, Any], + provider_type: str, + ) -> Generator[OnlineDocumentPagesMessage, None, None]: + """ + Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. + """ + + datasource_provider_id = GenericProviderID(datasource_provider) + + return self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages", + OnlineDocumentPagesMessage, + data={ + "user_id": user_id, + "data": { + "provider": datasource_provider_id.provider_name, + "datasource": datasource_name, + "credentials": credentials, + "datasource_parameters": datasource_parameters, + }, + }, + headers={ + "X-Plugin-ID": datasource_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + def get_online_document_page_content( + self, + tenant_id: str, + user_id: str, + datasource_provider: str, + datasource_name: str, + credentials: dict[str, Any], + datasource_parameters: GetOnlineDocumentPageContentRequest, + provider_type: str, + ) -> Generator[DatasourceMessage, None, None]: + """ + Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
+ """ + + datasource_provider_id = GenericProviderID(datasource_provider) + + return self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content", + DatasourceMessage, + data={ + "user_id": user_id, + "data": { + "provider": datasource_provider_id.provider_name, + "datasource": datasource_name, + "credentials": credentials, + "page": datasource_parameters.model_dump(), + }, + }, + headers={ + "X-Plugin-ID": datasource_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + def online_drive_browse_files( + self, + tenant_id: str, + user_id: str, + datasource_provider: str, + datasource_name: str, + credentials: dict[str, Any], + request: OnlineDriveBrowseFilesRequest, + provider_type: str, + ) -> Generator[OnlineDriveBrowseFilesResponse, None, None]: + """ + Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. + """ + + datasource_provider_id = GenericProviderID(datasource_provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/online_drive_browse_files", + OnlineDriveBrowseFilesResponse, + data={ + "user_id": user_id, + "data": { + "provider": datasource_provider_id.provider_name, + "datasource": datasource_name, + "credentials": credentials, + "request": request.model_dump(), + }, + }, + headers={ + "X-Plugin-ID": datasource_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + yield from response + + def online_drive_download_file( + self, + tenant_id: str, + user_id: str, + datasource_provider: str, + datasource_name: str, + credentials: dict[str, Any], + request: OnlineDriveDownloadFileRequest, + provider_type: str, + ) -> Generator[DatasourceMessage, None, None]: + """ + Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
+ """ + + datasource_provider_id = GenericProviderID(datasource_provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/online_drive_download_file", + DatasourceMessage, + data={ + "user_id": user_id, + "data": { + "provider": datasource_provider_id.provider_name, + "datasource": datasource_name, + "credentials": credentials, + "request": request.model_dump(), + }, + }, + headers={ + "X-Plugin-ID": datasource_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + yield from response + + def validate_provider_credentials( + self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any] + ) -> bool: + """ + validate the credentials of the provider + """ + # datasource_provider_id = GenericProviderID(provider_id) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/validate_credentials", + PluginBasicBooleanResponse, + data={ + "user_id": user_id, + "data": { + "provider": provider, + "credentials": credentials, + }, + }, + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp.result + + return False + + def _get_local_file_datasource_provider(self) -> dict[str, Any]: + return { + "id": "langgenius/file/file", + "plugin_id": "langgenius/file", + "provider": "file", + "plugin_unique_identifier": "langgenius/file:0.0.1@dify", + "declaration": { + "identity": { + "author": "langgenius", + "name": "file", + "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, + "icon": "https://assets.dify.ai/images/File%20Upload.svg", + "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, + }, + "credentials_schema": [], + "provider_type": "local_file", + "datasources": [ + { + "identity": { + "author": "langgenius", + "name": "upload-file", + "provider": "file", + "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, + }, + "parameters": [], + "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, + } + ], + }, + } diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 19b26c8fe3..bb9c00005c 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -4,7 +4,10 @@ from typing import Any, Optional from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID -from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity +from core.plugin.entities.plugin_daemon import ( + PluginBasicBooleanResponse, + PluginToolProviderEntity, +) from core.plugin.impl.base import BasePluginClient from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter @@ -197,6 +200,36 @@ class PluginToolManager(BasePluginClient): return False + def validate_datasource_credentials( + self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] + ) -> bool: + """ + validate the credentials of the datasource + """ + tool_provider_id = GenericProviderID(provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/validate_credentials", + PluginBasicBooleanResponse, + data={ + "user_id": user_id, + "data": { + "provider": tool_provider_id.provider_name, + "credentials": credentials, + }, + }, + headers={ + "X-Plugin-ID": tool_provider_id.plugin_id, + 
"Content-Type": "application/json", + }, + ) + + for resp in response: + return resp.result + + return False + def get_runtime_parameters( self, tenant_id: str, diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index d6d0bd88b2..be1765feee 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -28,10 +28,12 @@ class Jieba(BaseKeyword): with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() + keyword_number = ( + self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + ) + for text in texts: - keywords = keyword_table_handler.extract_keywords( - text.page_content, self._config.max_keywords_per_chunk - ) + keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number) if text.metadata is not None: self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) keyword_table = self._add_text_to_keyword_table( @@ -49,18 +51,17 @@ class Jieba(BaseKeyword): keyword_table = self._get_dataset_keyword_table() keywords_list = kwargs.get("keywords_list") + keyword_number = ( + self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + ) for i in range(len(texts)): text = texts[i] if keywords_list: keywords = keywords_list[i] if not keywords: - keywords = keyword_table_handler.extract_keywords( - text.page_content, self._config.max_keywords_per_chunk - ) + keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number) else: - keywords = keyword_table_handler.extract_keywords( - text.page_content, self._config.max_keywords_per_chunk - ) + keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number) if text.metadata is not None: self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) keyword_table = self._add_text_to_keyword_table( @@ -239,7 +240,11 @@ class Jieba(BaseKeyword): keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"] ) else: - keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk) + keyword_number = ( + self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + ) + + keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number) segment.keywords = list(keywords) keyword_table = self._add_text_to_keyword_table( keyword_table or {}, segment.index_node_id, list(keywords) diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py new file mode 100644 index 0000000000..a36e32fc9c --- /dev/null +++ b/api/core/rag/entities/event.py @@ -0,0 +1,38 @@ +from collections.abc import Mapping +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel, Field + + +class DatasourceStreamEvent(Enum): + """ + Datasource Stream event + """ + + PROCESSING = "datasource_processing" + COMPLETED = "datasource_completed" + ERROR = "datasource_error" + + +class BaseDatasourceEvent(BaseModel): + pass + + +class DatasourceErrorEvent(BaseDatasourceEvent): + event: str = DatasourceStreamEvent.ERROR.value + error: str = Field(..., description="error message") + + +class DatasourceCompletedEvent(BaseDatasourceEvent): + event: str = DatasourceStreamEvent.COMPLETED.value + data: Mapping[str, Any] | list = Field(..., 
description="result") + total: Optional[int] = Field(default=0, description="total") + completed: Optional[int] = Field(default=0, description="completed") + time_consuming: Optional[float] = Field(default=0.0, description="time consuming") + + +class DatasourceProcessingEvent(BaseDatasourceEvent): + event: str = DatasourceStreamEvent.PROCESSING.value + total: Optional[int] = Field(..., description="total") + completed: Optional[int] = Field(..., description="completed") diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py index c8ad53e3dd..05fbf9003b 100644 --- a/api/core/rag/index_processor/constant/built_in_field.py +++ b/api/core/rag/index_processor/constant/built_in_field.py @@ -13,3 +13,5 @@ class MetadataDataSource(Enum): upload_file = "file_upload" website_crawl = "website" notion_import = "notion" + local_file = "file_upload" + online_document = "online_document" diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 2bcd1c79bb..ff6f843a28 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,7 +1,8 @@ """Abstract interface for document loader implementations.""" from abc import ABC, abstractmethod -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from configs import dify_config from core.model_manager import ModelInstance @@ -13,6 +14,7 @@ from core.rag.splitter.fixed_text_splitter import ( ) from core.rag.splitter.text_splitter import TextSplitter from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Document as DatasetDocument class BaseIndexProcessor(ABC): @@ -33,6 +35,14 @@ class BaseIndexProcessor(ABC): def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): raise NotImplementedError + @abstractmethod + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): + raise NotImplementedError + + @abstractmethod + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + raise NotImplementedError + @abstractmethod def retrieve( self, diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 9b90bd2bb3..6a114b6bb2 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,19 +1,22 @@ """Paragraph index processor.""" import uuid -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document +from core.rag.models.document import Document, GeneralStructureChunk from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.dataset import Dataset, 
DatasetProcessRule +from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule @@ -127,3 +130,34 @@ class ParagraphIndexProcessor(BaseIndexProcessor): doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) return docs + + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): + paragraph = GeneralStructureChunk(**chunks) + documents = [] + for content in paragraph.general_chunks: + metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(content), + } + doc = Document(page_content=content, metadata=metadata) + documents.append(doc) + if documents: + # save node to document segment + doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) + # add document segments + doc_store.add_documents(docs=documents, save_child=False) + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + vector.create(documents) + elif dataset.indexing_technique == "economy": + keyword = Keyword(dataset) + keyword.add_texts(documents) + + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + paragraph = GeneralStructureChunk(**chunks) + preview = [] + for content in paragraph.general_chunks: + preview.append({"content": content}) + return {"preview": preview, "total_segments": len(paragraph.general_chunks)} diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 1cde5e1c8f..158fc819ee 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -1,20 +1,23 @@ """Paragraph index processor.""" import uuid -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from configs import dify_config from core.model_manager import ModelInstance from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk from extensions.ext_database import db from libs import helper from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -202,3 +205,40 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_document.page_content = child_page_content child_nodes.append(child_document) return child_nodes + + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): + parent_childs = ParentChildStructureChunk(**chunks) + documents = [] + for parent_child in parent_childs.parent_child_chunks: + metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(parent_child.parent_content), 
+ } + child_documents = [] + for child in parent_child.child_contents: + child_metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(child), + } + child_documents.append(ChildDocument(page_content=child, metadata=child_metadata)) + doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents) + documents.append(doc) + if documents: + # save node to document segment + doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) + # add document segments + doc_store.add_documents(docs=documents, save_child=True) + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + vector.create(documents) + + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + parent_childs = ParentChildStructureChunk(**chunks) + preview = [] + for parent_child in parent_childs.parent_child_chunks: + preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) + return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)} diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 75f3153697..88253b409d 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -4,7 +4,8 @@ import logging import re import threading import uuid -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional import pandas as pd from flask import Flask, current_app @@ -14,13 +15,15 @@ from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document +from core.rag.models.document import Document, QAStructureChunk from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.dataset import Dataset +from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule @@ -161,6 +164,36 @@ class QAIndexProcessor(BaseIndexProcessor): docs.append(doc) return docs + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): + qa_chunks = QAStructureChunk(**chunks) + documents = [] + for qa_chunk in qa_chunks.qa_chunks: + metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(qa_chunk.question), + "answer": qa_chunk.answer, + } + doc = Document(page_content=qa_chunk.question, metadata=metadata) + documents.append(doc) + if documents: + # save node to document segment + doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) + doc_store.add_documents(docs=documents, save_child=False) + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + vector.create(documents) + else: + raise 
ValueError("Indexing technique must be high quality.") + + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + qa_chunks = QAStructureChunk(**chunks) + preview = [] + for qa_chunk in qa_chunks.qa_chunks: + preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) + return {"qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks)} + def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): format_documents = [] if document_node.page_content is None or not document_node.page_content.strip(): diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 04a3428ad8..e382ff6b54 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -35,6 +35,48 @@ class Document(BaseModel): children: Optional[list[ChildDocument]] = None +class GeneralStructureChunk(BaseModel): + """ + General Structure Chunk. + """ + + general_chunks: list[str] + + +class ParentChildChunk(BaseModel): + """ + Parent Child Chunk. + """ + + parent_content: str + child_contents: list[str] + + +class ParentChildStructureChunk(BaseModel): + """ + Parent Child Structure Chunk. + """ + + parent_child_chunks: list[ParentChildChunk] + + +class QAChunk(BaseModel): + """ + QA Chunk. + """ + + question: str + answer: str + + +class QAStructureChunk(BaseModel): + """ + QAStructureChunk. + """ + + qa_chunks: list[QAChunk] + + class BaseDocumentTransformer(ABC): """Abstract base class for document transformation systems. diff --git a/api/core/rag/retrieval/retrieval_methods.py b/api/core/rag/retrieval/retrieval_methods.py index eaa00bca88..c7c6e60c8d 100644 --- a/api/core/rag/retrieval/retrieval_methods.py +++ b/api/core/rag/retrieval/retrieval_methods.py @@ -5,6 +5,7 @@ class RetrievalMethod(Enum): SEMANTIC_SEARCH = "semantic_search" FULL_TEXT_SEARCH = "full_text_search" HYBRID_SEARCH = "hybrid_search" + KEYWORD_SEARCH = "keyword_search" @staticmethod def is_support_semantic_search(retrieval_method: str) -> bool: diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 797cce9354..46ff9e63a4 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -262,6 +262,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) self, workflow_run_id: str, order_config: Optional[OrderConfig] = None, + triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecutionModel]: """ Retrieve all WorkflowNodeExecution database models for a specific workflow run. 
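
Note (reviewer sketch, not part of the patch): the repository methods below gain an optional `triggered_from` filter that defaults to `WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN`, so existing callers keep their current behaviour while pipeline code can query node executions recorded under a different trigger source. A minimal Python usage sketch, assuming a repository instance `repo`, a run id `run_id`, and an `OrderConfig` with `order_by`/`order_direction` fields as referenced elsewhere in this module (names other than those shown in the diff are assumptions):

    # Default: same result set as before this change (workflow-run executions only).
    order = OrderConfig(order_by=["index"], order_direction="asc")  # assumed field names
    models = repo.get_db_models_by_workflow_run(workflow_run_id=run_id, order_config=order)

    # Pass another WorkflowNodeExecutionTriggeredFrom member to fetch executions
    # recorded for that trigger source instead of the default.
    models = repo.get_db_models_by_workflow_run(
        workflow_run_id=run_id,
        order_config=order,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
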
@@ -283,7 +284,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) stmt = select(WorkflowNodeExecutionModel).where( WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, WorkflowNodeExecutionModel.tenant_id == self._tenant_id, - WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + WorkflowNodeExecutionModel.triggered_from == triggered_from, ) if self._app_id: @@ -317,6 +318,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) self, workflow_run_id: str, order_config: Optional[OrderConfig] = None, + triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all NodeExecution instances for a specific workflow run. @@ -334,7 +336,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) A list of NodeExecution instances """ # Get the database models using the new method - db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config) + db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from) # Convert database models to domain models domain_models = [] diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index b650b1682e..e5dc226571 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,8 +1,8 @@ from collections.abc import Sequence -from typing import cast +from typing import Any, cast from uuid import uuid4 -from pydantic import Field +from pydantic import BaseModel, Field from core.helper import encrypter @@ -93,3 +93,32 @@ class FileVariable(FileSegment, Variable): class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass + + +class RAGPipelineVariable(BaseModel): + belong_to_node_id: str = Field(description="belong to which node id, shared means public") + type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list") + label: str = Field(description="label") + description: str | None = Field(description="description", default="") + variable: str = Field(description="variable key", default="") + max_length: int | None = Field( + description="max length, applicable to text-input, paragraph, and file-list", default=0 + ) + default_value: Any = Field(description="default value", default="") + placeholder: str | None = Field(description="placeholder", default="") + unit: str | None = Field(description="unit, applicable to Number", default="") + tooltips: str | None = Field(description="helpful text", default="") + allowed_file_types: list[str] | None = Field( + description="image, document, audio, video, custom.", default_factory=list + ) + allowed_file_extensions: list[str] | None = Field(description="e.g. 
['.jpg', '.mp3']", default_factory=list) + allowed_file_upload_methods: list[str] | None = Field( + description="remote_url, local_file, tool_file.", default_factory=list + ) + required: bool = Field(description="optional, default false", default=False) + options: list[str] | None = Field(default_factory=list) + + +class RAGPipelineVariableInput(BaseModel): + variable: RAGPipelineVariable + value: Any diff --git a/api/core/workflow/constants.py b/api/core/workflow/constants.py index e3fe17c284..7664be0983 100644 --- a/api/core/workflow/constants.py +++ b/api/core/workflow/constants.py @@ -1,3 +1,4 @@ SYSTEM_VARIABLE_NODE_ID = "sys" ENVIRONMENT_VARIABLE_NODE_ID = "env" CONVERSATION_VARIABLE_NODE_ID = "conversation" +RAG_PIPELINE_VARIABLE_NODE_ID = "rag" diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 80dda2632d..3a68f45f61 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -9,7 +9,13 @@ from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.segments import FileSegment, NoneSegment -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.variables.variables import RAGPipelineVariableInput +from core.workflow.constants import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) from core.workflow.enums import SystemVariableKey from factories import variable_factory @@ -44,6 +50,10 @@ class VariablePool(BaseModel): description="Conversation variables.", default_factory=list, ) + rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( + description="RAG pipeline variables.", + default_factory=list, + ) def model_post_init(self, context: Any, /) -> None: for key, value in self.system_variables.items(): @@ -54,6 +64,9 @@ class VariablePool(BaseModel): # Add conversation variables to the variable pool for var in self.conversation_variables: self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) + # Add rag pipeline variables to the variable pool + for var in self.rag_pipeline_variables: + self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var.variable.belong_to_node_id, var.variable.variable), var.value) def add(self, selector: Sequence[str], value: Any, /) -> None: """ diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py index 781be4b3c6..9d70dd0ab6 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/core/workflow/entities/workflow_execution.py @@ -20,6 +20,7 @@ class WorkflowType(StrEnum): WORKFLOW = "workflow" CHAT = "chat" + RAG_PIPELINE = "rag-pipeline" class WorkflowExecutionStatus(StrEnum): diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/core/workflow/entities/workflow_node_execution.py index 09a408f4d7..bc08c9df85 100644 --- a/api/core/workflow/entities/workflow_node_execution.py +++ b/api/core/workflow/entities/workflow_node_execution.py @@ -28,6 +28,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): AGENT_LOG = "agent_log" ITERATION_ID = "iteration_id" ITERATION_INDEX = "iteration_index" + DATASOURCE_INFO = "datasource_info" LOOP_ID = "loop_id" LOOP_INDEX = "loop_index" PARALLEL_ID = "parallel_id" diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 
b52a2b0e6e..d3e14d33ba 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -14,3 +14,10 @@ class SystemVariableKey(StrEnum): APP_ID = "app_id" WORKFLOW_ID = "workflow_id" WORKFLOW_EXECUTION_ID = "workflow_run_id" + # RAG Pipeline + DOCUMENT_ID = "document_id" + BATCH = "batch" + DATASET_ID = "dataset_id" + DATASOURCE_TYPE = "datasource_type" + DATASOURCE_INFO = "datasource_info" + INVOKE_FROM = "invoke_from" diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 8e5b1e7142..16bf847189 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -121,6 +121,7 @@ class Graph(BaseModel): # fetch nodes that have no predecessor node root_node_configs = [] all_node_id_config_mapping: dict[str, dict] = {} + for node_config in node_configs: node_id = node_config.get("id") if not node_id: @@ -141,6 +142,7 @@ class Graph(BaseModel): node_config.get("id") for node_config in root_node_configs if node_config.get("data", {}).get("type", "") == NodeType.START.value + or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value ), None, ) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 61a7a26652..eaa558b02c 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -175,7 +175,7 @@ class GraphEngine: ) return elif isinstance(item, NodeRunSucceededEvent): - if item.node_type == NodeType.END: + if item.node_type in (NodeType.END, NodeType.KNOWLEDGE_INDEX): self.graph_runtime_state.outputs = ( dict(item.route_node_state.node_run_result.outputs) if item.route_node_state.node_run_result @@ -320,10 +320,10 @@ class GraphEngine: raise e # It may not be necessary, but it is necessary. 
:) - if ( - self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() - == NodeType.END.value - ): + if self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() in [ + NodeType.END.value, + NodeType.KNOWLEDGE_INDEX.value, + ]: break previous_route_node_state = route_node_state diff --git a/api/core/workflow/nodes/datasource/__init__.py b/api/core/workflow/nodes/datasource/__init__.py new file mode 100644 index 0000000000..f6ec44cb77 --- /dev/null +++ b/api/core/workflow/nodes/datasource/__init__.py @@ -0,0 +1,3 @@ +from .datasource_node import DatasourceNode + +__all__ = ["DatasourceNode"] diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py new file mode 100644 index 0000000000..01f6f51648 --- /dev/null +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -0,0 +1,468 @@ +from collections.abc import Generator, Mapping, Sequence +from typing import Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.datasource.entities.datasource_entities import ( + DatasourceMessage, + DatasourceParameter, + DatasourceProviderType, + GetOnlineDocumentPageContentRequest, + OnlineDriveDownloadFileRequest, +) +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin +from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer +from core.file import File +from core.file.enums import FileTransferMethod, FileType +from core.plugin.impl.exc import PluginDaemonClientSideError +from core.variables.segments import ArrayAnySegment +from core.variables.variables import ArrayAnyVariable +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool, VariableValue +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.enums import SystemVariableKey +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.tool.exc import ToolFileError +from core.workflow.utils.variable_template_parser import VariableTemplateParser +from extensions.ext_database import db +from factories import file_factory +from models.model import UploadFile +from services.datasource_provider_service import DatasourceProviderService + +from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from .entities import DatasourceNodeData +from .exc import DatasourceNodeError, DatasourceParameterError + + +class DatasourceNode(BaseNode[DatasourceNodeData]): + """ + Datasource Node + """ + + _node_data_cls = DatasourceNodeData + _node_type = NodeType.DATASOURCE + + def _run(self) -> Generator: + """ + Run the datasource node + """ + + node_data = cast(DatasourceNodeData, self.node_data) + variable_pool = self.graph_runtime_state.variable_pool + datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value]) + if not datasource_type: + raise DatasourceNodeError("Datasource type is not set") + datasource_type = datasource_type.value + datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value]) + if not datasource_info: + raise DatasourceNodeError("Datasource info is not set") + datasource_info = 
datasource_info.value + # get datasource runtime + try: + from core.datasource.datasource_manager import DatasourceManager + + if datasource_type is None: + raise DatasourceNodeError("Datasource type is not set") + + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=f"{node_data.plugin_id}/{node_data.provider_name}", + datasource_name=node_data.datasource_name or "", + tenant_id=self.tenant_id, + datasource_type=DatasourceProviderType.value_of(datasource_type), + ) + except DatasourceNodeError as e: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to get datasource runtime: {str(e)}", + error_type=type(e).__name__, + ) + ) + + # get parameters + datasource_parameters = datasource_runtime.entity.parameters + parameters = self._generate_parameters( + datasource_parameters=datasource_parameters, + variable_pool=variable_pool, + node_data=self.node_data, + ) + parameters_for_log = self._generate_parameters( + datasource_parameters=datasource_parameters, + variable_pool=variable_pool, + node_data=self.node_data, + for_log=True, + ) + + try: + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=self.tenant_id, + provider=node_data.provider_name, + plugin_id=node_data.plugin_id, + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + online_document_result: Generator[DatasourceMessage, None, None] = ( + datasource_runtime.get_online_document_page_content( + user_id=self.user_id, + datasource_parameters=GetOnlineDocumentPageContentRequest( + workspace_id=datasource_info.get("workspace_id"), + page_id=datasource_info.get("page").get("page_id"), + type=datasource_info.get("page").get("type"), + ), + provider_type=datasource_type, + ) + ) + yield from self._transform_message( + messages=online_document_result, + parameters_for_log=parameters_for_log, + datasource_info=datasource_info, + ) + case DatasourceProviderType.ONLINE_DRIVE: + datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=self.tenant_id, + provider=node_data.provider_name, + plugin_id=node_data.plugin_id, + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + online_drive_result: Generator[DatasourceMessage, None, None] = ( + datasource_runtime.online_drive_download_file( + user_id=self.user_id, + request=OnlineDriveDownloadFileRequest( + key=datasource_info.get("key"), + bucket=datasource_info.get("bucket"), + ), + provider_type=datasource_type, + ) + ) + yield from self._transform_message( + messages=online_drive_result, + parameters_for_log=parameters_for_log, + datasource_info=datasource_info, + ) + case DatasourceProviderType.WEBSITE_CRAWL: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + **datasource_info, + "datasource_type": datasource_type, + }, + ) + ) + case 
DatasourceProviderType.LOCAL_FILE: + related_id = datasource_info.get("related_id") + if not related_id: + raise DatasourceNodeError("File is not exist") + upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first() + if not upload_file: + raise ValueError("Invalid upload file Info") + + file_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=self.tenant_id, + type=FileType.CUSTOM, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + ) + variable_pool.add([self.node_id, "file"], [file_info]) + for key, value in datasource_info.items(): + # construct new key list + new_key_list = ["file", key] + self._append_variables_recursively( + variable_pool=variable_pool, + node_id=self.node_id, + variable_key_list=new_key_list, + variable_value=value, + ) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "file_info": datasource_info, + "datasource_type": datasource_type, + }, + ) + ) + case _: + raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}") + except PluginDaemonClientSideError as e: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to transform datasource message: {str(e)}", + error_type=type(e).__name__, + ) + ) + except DatasourceNodeError as e: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to invoke datasource: {str(e)}", + error_type=type(e).__name__, + ) + ) + + def _generate_parameters( + self, + *, + datasource_parameters: Sequence[DatasourceParameter], + variable_pool: VariablePool, + node_data: DatasourceNodeData, + for_log: bool = False, + ) -> dict[str, Any]: + """ + Generate parameters based on the given tool parameters, variable pool, and node data. + + Args: + tool_parameters (Sequence[ToolParameter]): The list of tool parameters. + variable_pool (VariablePool): The variable pool containing the variables. + node_data (ToolNodeData): The data associated with the tool node. + + Returns: + Mapping[str, Any]: A dictionary containing the generated parameters. 
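+
+ Example (illustrative values only): an input declared as
+ {"type": "constant", "value": "notion"} resolves to the literal text "notion",
+ {"type": "variable", "value": ["start", "query"]} resolves to that selector's
+ value in the variable pool, and a "mixed" template such as
+ "prefix {{#start.query#}}" is rendered via variable_pool.convert_template.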
+ + """ + datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters} + + result: dict[str, Any] = {} + if node_data.datasource_parameters: + for parameter_name in node_data.datasource_parameters: + parameter = datasource_parameters_dictionary.get(parameter_name) + if not parameter: + result[parameter_name] = None + continue + datasource_input = node_data.datasource_parameters[parameter_name] + if datasource_input.type == "variable": + variable = variable_pool.get(datasource_input.value) + if variable is None: + raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist") + parameter_value = variable.value + elif datasource_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(datasource_input.value)) + parameter_value = segment_group.log if for_log else segment_group.text + else: + raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'") + result[parameter_name] = parameter_value + + return result + + def _fetch_files(self, variable_pool: VariablePool) -> list[File]: + variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) + assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) + return list(variable.value) if variable else [] + + def _append_variables_recursively( + self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue + ): + """ + Append variables recursively + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + variable_pool.add([node_id] + variable_key_list, variable_value) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, dict): + for key, value in variable_value.items(): + # construct new key list + new_key_list = variable_key_list + [key] + self._append_variables_recursively( + variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: DatasourceNodeData, + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + result = {} + if node_data.datasource_parameters: + for parameter_name in node_data.datasource_parameters: + input = node_data.datasource_parameters[parameter_name] + if input.type == "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + elif input.type == "variable": + result[parameter_name] = input.value + elif input.type == "constant": + pass + + result = {node_id + "." 
+ key: value for key, value in result.items()} + + return result + + def _transform_message( + self, + messages: Generator[DatasourceMessage, None, None], + parameters_for_log: dict[str, Any], + datasource_info: dict[str, Any], + ) -> Generator: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=messages, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json: list[dict] = [] + + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + DatasourceMessage.MessageType.IMAGE_LINK, + DatasourceMessage.MessageType.BINARY_LINK, + DatasourceMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, DatasourceMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.DATASOURCE_FILE) + else: + transfer_method = FileTransferMethod.DATASOURCE_FILE + + datasource_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(UploadFile).where(UploadFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"Tool file {datasource_file_id} does not exist") + + mapping = { + "datasource_file_id": datasource_file_id, + "type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + files.append(file) + elif message.type == DatasourceMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, DatasourceMessage.TextMessage) + assert message.meta + + datasource_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(UploadFile).where(UploadFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"datasource file {datasource_file_id} not exists") + + mapping = { + "datasource_file_id": datasource_file_id, + "transfer_method": FileTransferMethod.DATASOURCE_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + ) + elif message.type == DatasourceMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceMessage.TextMessage) + text += message.message.text + yield RunStreamChunkEvent( + chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] + ) + elif message.type == DatasourceMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceMessage.JsonMessage) + if self.node_type == NodeType.AGENT: + msg_metadata = message.message.json_object.pop("execution_metadata", {}) + agent_execution_metadata = { + key: value + for key, value in msg_metadata.items() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() + } + json.append(message.message.json_object) + elif message.type == DatasourceMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) + elif message.type == DatasourceMessage.MessageType.VARIABLE: + 
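+ # Variable messages either stream string chunks (accumulated into the named
+ # variable and re-emitted as RunStreamChunkEvent) or carry a complete value
+ # that is stored on the variable as-is.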
assert isinstance(message.message, DatasourceMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield RunStreamChunkEvent( + chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] + ) + else: + variables[variable_name] = variable_value + elif message.type == DatasourceMessage.MessageType.FILE: + assert message.meta is not None + files.append(message.meta["file"]) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"json": json, "files": files, **variables, "text": text}, + metadata={ + WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info, + }, + inputs=parameters_for_log, + ) + ) + + @classmethod + def version(cls) -> str: + return "1" diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py new file mode 100644 index 0000000000..b182928baa --- /dev/null +++ b/api/core/workflow/nodes/datasource/entities.py @@ -0,0 +1,41 @@ +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel, field_validator +from pydantic_core.core_schema import ValidationInfo + +from core.workflow.nodes.base.entities import BaseNodeData + + +class DatasourceEntity(BaseModel): + plugin_id: str + provider_name: str # redundancy + provider_type: str + datasource_name: Optional[str] = "local_file" + datasource_configurations: dict[str, Any] | None = None + plugin_unique_identifier: str | None = None # redundancy + + +class DatasourceNodeData(BaseNodeData, DatasourceEntity): + class DatasourceInput(BaseModel): + # TODO: check this type + value: Union[Any, list[str]] + type: Optional[Literal["mixed", "variable", "constant"]] = None + + @field_validator("type", mode="before") + @classmethod + def check_type(cls, value, validation_info: ValidationInfo): + typ = value + value = validation_info.data.get("value") + if typ == "mixed" and not isinstance(value, str): + raise ValueError("value must be a string") + elif typ == "variable": + if not isinstance(value, list): + raise ValueError("value must be a list") + for val in value: + if not isinstance(val, str): + raise ValueError("value must be a list of strings") + elif typ == "constant" and not isinstance(value, str | int | float | bool): + raise ValueError("value must be a string, int, float, or bool") + return typ + + datasource_parameters: dict[str, DatasourceInput] | None = None diff --git a/api/core/workflow/nodes/datasource/exc.py b/api/core/workflow/nodes/datasource/exc.py new file mode 100644 index 0000000000..89980e6f45 --- /dev/null +++ b/api/core/workflow/nodes/datasource/exc.py @@ -0,0 +1,16 @@ +class DatasourceNodeError(ValueError): + """Base exception for datasource node errors.""" + + pass + + +class DatasourceParameterError(DatasourceNodeError): + """Exception raised for errors in datasource parameters.""" + + pass + + +class DatasourceFileError(DatasourceNodeError): + """Exception raised for errors related to datasource files.""" + + pass diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 73b43eeaf7..7edc73b6ba 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -7,12 +7,14 
@@ class NodeType(StrEnum): ANSWER = "answer" LLM = "llm" KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + KNOWLEDGE_INDEX = "knowledge-index" IF_ELSE = "if-else" CODE = "code" TEMPLATE_TRANSFORM = "template-transform" QUESTION_CLASSIFIER = "question-classifier" HTTP_REQUEST = "http-request" TOOL = "tool" + DATASOURCE = "datasource" VARIABLE_AGGREGATOR = "variable-aggregator" LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. LOOP = "loop" diff --git a/api/core/workflow/nodes/knowledge_index/__init__.py b/api/core/workflow/nodes/knowledge_index/__init__.py new file mode 100644 index 0000000000..23897a1e42 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/__init__.py @@ -0,0 +1,3 @@ +from .knowledge_index_node import KnowledgeIndexNode + +__all__ = ["KnowledgeIndexNode"] diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py new file mode 100644 index 0000000000..18a4f93970 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -0,0 +1,159 @@ +from typing import Literal, Optional, Union + +from pydantic import BaseModel + +from core.workflow.nodes.base import BaseNodeData + + +class RerankingModelConfig(BaseModel): + """ + Reranking Model Config. + """ + + reranking_provider_name: str + reranking_model_name: str + + +class VectorSetting(BaseModel): + """ + Vector Setting. + """ + + vector_weight: float + embedding_provider_name: str + embedding_model_name: str + + +class KeywordSetting(BaseModel): + """ + Keyword Setting. + """ + + keyword_weight: float + + +class WeightedScoreConfig(BaseModel): + """ + Weighted score Config. + """ + + vector_setting: VectorSetting + keyword_setting: KeywordSetting + + +class EmbeddingSetting(BaseModel): + """ + Embedding Setting. + """ + + embedding_provider_name: str + embedding_model_name: str + + +class EconomySetting(BaseModel): + """ + Economy Setting. + """ + + keyword_number: int + + +class RetrievalSetting(BaseModel): + """ + Retrieval Setting. + """ + + search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"] + top_k: int + score_threshold: Optional[float] = 0.5 + score_threshold_enabled: bool = False + reranking_mode: str = "reranking_model" + reranking_enable: bool = True + reranking_model: Optional[RerankingModelConfig] = None + weights: Optional[WeightedScoreConfig] = None + + +class IndexMethod(BaseModel): + """ + Knowledge Index Setting. + """ + + indexing_technique: Literal["high_quality", "economy"] + embedding_setting: EmbeddingSetting + economy_setting: EconomySetting + + +class FileInfo(BaseModel): + """ + File Info. + """ + + file_id: str + + +class OnlineDocumentIcon(BaseModel): + """ + Document Icon. + """ + + icon_url: str + icon_type: str + icon_emoji: str + + +class OnlineDocumentInfo(BaseModel): + """ + Online document info. + """ + + provider: str + workspace_id: str + page_id: str + page_type: str + icon: OnlineDocumentIcon + + +class WebsiteInfo(BaseModel): + """ + website import info. + """ + + provider: str + url: str + + +class GeneralStructureChunk(BaseModel): + """ + General Structure Chunk. + """ + + general_chunks: list[str] + data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo] + + +class ParentChildChunk(BaseModel): + """ + Parent Child Chunk. + """ + + parent_content: str + child_contents: list[str] + + +class ParentChildStructureChunk(BaseModel): + """ + Parent Child Structure Chunk. 
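+ A list of parent chunks, each with its child chunks, plus the originating data source info.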
+ """ + + parent_child_chunks: list[ParentChildChunk] + data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo] + + +class KnowledgeIndexNodeData(BaseNodeData): + """ + Knowledge index Node Data. + """ + + type: str = "knowledge-index" + chunk_structure: str + index_chunk_variable_selector: list[str] diff --git a/api/core/workflow/nodes/knowledge_index/exc.py b/api/core/workflow/nodes/knowledge_index/exc.py new file mode 100644 index 0000000000..afdde9c0c5 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/exc.py @@ -0,0 +1,22 @@ +class KnowledgeIndexNodeError(ValueError): + """Base class for KnowledgeIndexNode errors.""" + + +class ModelNotExistError(KnowledgeIndexNodeError): + """Raised when the model does not exist.""" + + +class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError): + """Raised when the model credentials are not initialized.""" + + +class ModelNotSupportedError(KnowledgeIndexNodeError): + """Raised when the model is not supported.""" + + +class ModelQuotaExceededError(KnowledgeIndexNodeError): + """Raised when the model provider quota is exceeded.""" + + +class InvalidModelTypeError(KnowledgeIndexNodeError): + """Raised when the model is not a Large Language Model.""" diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py new file mode 100644 index 0000000000..ad89a7ad08 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -0,0 +1,165 @@ +import datetime +import logging +import time +from collections.abc import Mapping +from typing import Any, cast + +from sqlalchemy import func + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.enums import SystemVariableKey +from core.workflow.nodes.enums import NodeType +from extensions.ext_database import db +from models.dataset import Dataset, Document, DocumentSegment + +from ..base import BaseNode +from .entities import KnowledgeIndexNodeData +from .exc import ( + KnowledgeIndexNodeError, +) + +logger = logging.getLogger(__name__) + +default_retrieval_model = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, +} + + +class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): + _node_data_cls = KnowledgeIndexNodeData # type: ignore + _node_type = NodeType.KNOWLEDGE_INDEX + + def _run(self) -> NodeRunResult: # type: ignore + node_data = cast(KnowledgeIndexNodeData, self.node_data) + variable_pool = self.graph_runtime_state.variable_pool + dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) + if not dataset_id: + raise KnowledgeIndexNodeError("Dataset ID is required.") + dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first() + if not dataset: + raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.") + + # extract variables + variable = variable_pool.get(node_data.index_chunk_variable_selector) + if not variable: + raise KnowledgeIndexNodeError("Index chunk variable is 
required.") + invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + if invoke_from: + is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value + else: + is_preview = False + chunks = variable.value + variables = {"chunks": chunks} + if not chunks: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required." + ) + + # index knowledge + try: + if is_preview: + outputs = self._get_preview_output(node_data.chunk_structure, chunks) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + process_data=None, + outputs=outputs, + ) + results = self._invoke_knowledge_index( + dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool + ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results + ) + + except KnowledgeIndexNodeError as e: + logger.warning("Error when running knowledge index node") + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + error_type=type(e).__name__, + ) + # Temporary handle all exceptions from DatasetRetrieval class here. + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + error_type=type(e).__name__, + ) + + def _invoke_knowledge_index( + self, + dataset: Dataset, + node_data: KnowledgeIndexNodeData, + chunks: Mapping[str, Any], + variable_pool: VariablePool, + ) -> Any: + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if not document_id: + raise KnowledgeIndexNodeError("Document ID is required.") + batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) + if not batch: + raise KnowledgeIndexNodeError("Batch is required.") + document = db.session.query(Document).filter_by(id=document_id.value).first() + if not document: + raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.") + # chunk nodes by chunk size + indexing_start_at = time.perf_counter() + index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor() + index_processor.index(dataset, document, chunks) + indexing_end_at = time.perf_counter() + document.indexing_latency = indexing_end_at - indexing_start_at + # update document status + document.indexing_status = "completed" + document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.word_count = ( + db.session.query(func.sum(DocumentSegment.word_count)) + .filter( + DocumentSegment.document_id == document.id, + DocumentSegment.dataset_id == dataset.id, + ) + .scalar() + ) + db.session.add(document) + # update document segment status + db.session.query(DocumentSegment).filter( + DocumentSegment.document_id == document.id, + DocumentSegment.dataset_id == dataset.id, + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + } + ) + + db.session.commit() + + return { + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "batch": batch.value, + "document_id": document.id, + "document_name": document.name, + "created_at": document.created_at.timestamp(), + "display_status": document.indexing_status, + } + + def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() + return 
index_processor.format_preview(chunks) + + @classmethod + def version(cls) -> str: + return "1" diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 19bdee4fe2..cb2c191518 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -57,10 +57,6 @@ class MultipleRetrievalConfig(BaseModel): class ModelConfig(BaseModel): - """ - Model Config. - """ - provider: str name: str mode: str diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index 67cc884f20..f7ec8fe737 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -4,12 +4,14 @@ from core.workflow.nodes.agent.agent_node import AgentNode from core.workflow.nodes.answer import AnswerNode from core.workflow.nodes.base import BaseNode from core.workflow.nodes.code import CodeNode +from core.workflow.nodes.datasource.datasource_node import DatasourceNode from core.workflow.nodes.document_extractor import DocumentExtractorNode from core.workflow.nodes.end import EndNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.http_request import HttpRequestNode from core.workflow.nodes.if_else import IfElseNode from core.workflow.nodes.iteration import IterationNode, IterationStartNode +from core.workflow.nodes.knowledge_index import KnowledgeIndexNode from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode from core.workflow.nodes.list_operator import ListOperatorNode from core.workflow.nodes.llm import LLMNode @@ -124,4 +126,12 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { LATEST_VERSION: AgentNode, "1": AgentNode, }, + NodeType.DATASOURCE: { + LATEST_VERSION: DatasourceNode, + "1": DatasourceNode, + }, + NodeType.KNOWLEDGE_INDEX: { + LATEST_VERSION: KnowledgeIndexNode, + "1": KnowledgeIndexNode, + }, } diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 25d1390492..4f25cc64b0 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -61,6 +61,7 @@ def build_from_mapping( FileTransferMethod.LOCAL_FILE: _build_from_local_file, FileTransferMethod.REMOTE_URL: _build_from_remote_url, FileTransferMethod.TOOL_FILE: _build_from_tool_file, + FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, } build_func = build_functions.get(transfer_method) @@ -305,6 +306,53 @@ def _build_from_tool_file( ) +def _build_from_datasource_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, +) -> File: + datasource_file = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == mapping.get("datasource_file_id"), + UploadFile.tenant_id == tenant_id, + ) + .first() + ) + + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + + detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + + specified_type = mapping.get("type") + + if strict_type_validation and specified_type and detected_file_type.value != specified_type: + raise ValueError("Detected file type does not match the specified type. 
Please verify the file.") + + file_type = ( + FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type + ) + + return File( + id=mapping.get("id"), + tenant_id=tenant_id, + filename=datasource_file.name, + type=file_type, + transfer_method=transfer_method, + remote_url=datasource_file.source_url, + related_id=datasource_file.id, + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + ) + + def _is_file_valid_with_config( *, input_file_type: str, diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 250ee4695e..8dd4ec2e4a 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -36,7 +36,10 @@ from core.variables.variables import ( StringVariable, Variable, ) -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from core.workflow.constants import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) class UnsupportedSegmentTypeError(Exception): @@ -75,6 +78,12 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]]) +def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: + if not mapping.get("variable"): + raise VariableError("missing variable") + return mapping["variable"] + + def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: """ This factory function is used to create the environment variable or the conversation variable, diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 32a88cc5db..79a4f1c6de 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -56,6 +56,13 @@ external_knowledge_info_fields = { doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String} +icon_info_fields = { + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": fields.String, +} + dataset_detail_fields = { "id": fields.String, "name": fields.String, @@ -81,6 +88,13 @@ dataset_detail_fields = { "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), "doc_metadata": fields.List(fields.Nested(doc_metadata_fields)), "built_in_field_enabled": fields.Boolean, + "pipeline_id": fields.String, + "runtime_mode": fields.String, + "chunk_structure": fields.String, + "icon_info": fields.Nested(icon_info_fields), + "is_published": fields.Boolean, + "total_documents": fields.Integer, + "total_available_documents": fields.Integer, } dataset_query_detail_fields = { diff --git a/api/fields/rag_pipeline_fields.py b/api/fields/rag_pipeline_fields.py new file mode 100644 index 0000000000..cedc13ed0d --- /dev/null +++ b/api/fields/rag_pipeline_fields.py @@ -0,0 +1,164 @@ +from flask_restful import fields # type: ignore + +from fields.workflow_fields import workflow_partial_fields +from libs.helper import AppIconUrlField, TimestampField + +pipeline_detail_kernel_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, +} + +related_app_list = { + "data": fields.List(fields.Nested(pipeline_detail_kernel_fields)), + "total": fields.Integer, +} + 
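# Illustrative sketch (not part of the patch): exercising the DATASOURCE_FILE
# transfer method registered in file_factory.build_from_mapping above. The
# upload-file id and tenant id below are hypothetical placeholders.
from core.file.enums import FileTransferMethod
from factories import file_factory

file = file_factory.build_from_mapping(
    mapping={
        "datasource_file_id": "00000000-0000-0000-0000-000000000000",  # hypothetical UploadFile id
        "transfer_method": FileTransferMethod.DATASOURCE_FILE,
    },
    tenant_id="tenant-id",  # hypothetical tenant
)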
+app_detail_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon": fields.String, + "icon_background": fields.String, + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "tracing": fields.Raw, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, +} + + +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} + +app_partial_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String(attribute="desc_or_prompt"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, + "tags": fields.List(fields.Nested(tag_fields)), +} + + +app_pagination_fields = { + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(app_partial_fields), attribute="items"), +} + +template_fields = { + "name": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "mode": fields.String, +} + +template_list_fields = { + "data": fields.List(fields.Nested(template_fields)), +} + +site_fields = { + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "description": fields.String, + "default_language": fields.String, + "chat_color_theme": fields.String, + "chat_color_theme_inverted": fields.Boolean, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "app_base_url": fields.String, + "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, +} + +deleted_tool_fields = { + "type": fields.String, + "tool_name": fields.String, + "provider_id": fields.String, +} + +app_detail_fields_with_site = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "enable_site": fields.Boolean, + "enable_api": fields.Boolean, + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "site": fields.Nested(site_fields), + "api_base_url": fields.String, + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, +} + + +app_site_fields = { + "app_id": fields.String, + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "default_language": fields.String, + "customize_domain": fields.String, + 
"copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, +} + +leaked_dependency_fields = {"type": fields.String, "value": fields.Raw, "current_identifier": fields.String} + +pipeline_import_fields = { + "id": fields.String, + "status": fields.String, + "pipeline_id": fields.String, + "dataset_id": fields.String, + "current_dsl_version": fields.String, + "imported_dsl_version": fields.String, + "error": fields.String, +} + +pipeline_import_check_dependencies_fields = { + "leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)), +} diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 9f1bef3b36..36249d2ae9 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -40,6 +40,23 @@ conversation_variable_fields = { "description": fields.String, } +pipeline_variable_fields = { + "label": fields.String, + "variable": fields.String, + "type": fields.String, + "belong_to_node_id": fields.String, + "max_length": fields.Integer, + "required": fields.Boolean, + "unit": fields.String, + "default_value": fields.Raw, + "options": fields.List(fields.String), + "placeholder": fields.String, + "tooltips": fields.String, + "allowed_file_types": fields.List(fields.String), + "allow_file_extension": fields.List(fields.String), + "allow_file_upload_methods": fields.List(fields.String), +} + workflow_fields = { "id": fields.String, "graph": fields.Raw(attribute="graph_dict"), @@ -55,6 +72,7 @@ workflow_fields = { "tool_published": fields.Boolean, "environment_variables": fields.List(EnvironmentVariableField()), "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)), + "rag_pipeline_variables": fields.List(fields.Nested(pipeline_variable_fields)), } workflow_partial_fields = { diff --git a/api/installed_plugins.jsonl b/api/installed_plugins.jsonl new file mode 100644 index 0000000000..463e24ae64 --- /dev/null +++ b/api/installed_plugins.jsonl @@ -0,0 +1 @@ +{"not_installed": [], "plugin_install_failed": []} \ No newline at end of file diff --git a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py new file mode 100644 index 0000000000..961589a87e --- /dev/null +++ b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py @@ -0,0 +1,113 @@ +"""add_pipeline_info + +Revision ID: b35c3db83d09 +Revises: d28f2004b072 +Create Date: 2025-05-15 15:58:05.179877 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'b35c3db83d09' +down_revision = '0ab65e1cc7fa' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('pipeline_built_in_templates', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('pipeline_id', models.types.StringUUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('icon', sa.JSON(), nullable=False), + sa.Column('copyright', sa.String(length=255), nullable=False), + sa.Column('privacy_policy', sa.String(length=255), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('install_count', sa.Integer(), nullable=False), + sa.Column('language', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey') + ) + op.create_table('pipeline_customized_templates', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('pipeline_id', models.types.StringUUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('icon', sa.JSON(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('install_count', sa.Integer(), nullable=False), + sa.Column('language', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey') + ) + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False) + + op.create_table('pipelines', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), server_default=sa.text("''::character varying"), nullable=False), + sa.Column('mode', sa.String(length=255), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), nullable=True), + sa.Column('is_public', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('is_published', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_by', models.types.StringUUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='pipeline_pkey') + ) + op.create_table('tool_builtin_datasource_providers', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=True), + sa.Column('user_id', models.types.StringUUID(), nullable=False), + sa.Column('provider', sa.String(length=256), nullable=False), + 
sa.Column('encrypted_credentials', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_builtin_datasource_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_datasource_provider') + ) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True)) + batch_op.add_column(sa.Column('icon_info', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + batch_op.add_column(sa.Column('runtime_mode', sa.String(length=255), server_default=sa.text("'general'::character varying"), nullable=True)) + batch_op.add_column(sa.Column('pipeline_id', models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=True)) + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_column('rag_pipeline_variables') + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('chunk_structure') + batch_op.drop_column('pipeline_id') + batch_op.drop_column('runtime_mode') + batch_op.drop_column('icon_info') + batch_op.drop_column('keyword_number') + + op.drop_table('tool_builtin_datasource_providers') + op.drop_table('pipelines') + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.drop_index('pipeline_customized_template_tenant_idx') + + op.drop_table('pipeline_customized_templates') + op.drop_table('pipeline_built_in_templates') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py b/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py new file mode 100644 index 0000000000..ae8e832d26 --- /dev/null +++ b/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py @@ -0,0 +1,33 @@ +"""add_pipeline_info_2 + +Revision ID: abb18a379e62 +Revises: b35c3db83d09 +Create Date: 2025-05-16 16:59:16.423127 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'abb18a379e62' +down_revision = 'b35c3db83d09' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('pipelines', schema=None) as batch_op: + batch_op.drop_column('mode') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('pipelines', schema=None) as batch_op: + batch_op.add_column(sa.Column('mode', sa.VARCHAR(length=255), autoincrement=False, nullable=False)) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_05_30_0033-c459994abfa8_add_pipeline_info_3.py b/api/migrations/versions/2025_05_30_0033-c459994abfa8_add_pipeline_info_3.py new file mode 100644 index 0000000000..0b010d535d --- /dev/null +++ b/api/migrations/versions/2025_05_30_0033-c459994abfa8_add_pipeline_info_3.py @@ -0,0 +1,70 @@ +"""add_pipeline_info_3 + +Revision ID: c459994abfa8 +Revises: abb18a379e62 +Create Date: 2025-05-30 00:33:14.068312 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'c459994abfa8' +down_revision = 'abb18a379e62' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('datasource_oauth_params', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('plugin_id', models.types.StringUUID(), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('system_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'), + sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx') + ) + op.create_table('datasource_providers', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('plugin_id', models.types.StringUUID(), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('auth_type', sa.String(length=255), nullable=False), + sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'), + sa.UniqueConstraint('plugin_id', 'provider', name='datasource_provider_plugin_id_provider_idx') + ) + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=False)) + batch_op.add_column(sa.Column('yaml_content', sa.Text(), nullable=False)) + batch_op.drop_column('pipeline_id') + + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=False)) + batch_op.add_column(sa.Column('yaml_content', sa.Text(), nullable=False)) + batch_op.drop_column('pipeline_id') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('pipeline_id', sa.UUID(), autoincrement=False, nullable=False)) + batch_op.drop_column('yaml_content') + batch_op.drop_column('chunk_structure') + + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('pipeline_id', sa.UUID(), autoincrement=False, nullable=False)) + batch_op.drop_column('yaml_content') + batch_op.drop_column('chunk_structure') + + op.drop_table('datasource_providers') + op.drop_table('datasource_oauth_params') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_05_30_0052-e4fb49a4fe86_add_pipeline_info_4.py b/api/migrations/versions/2025_05_30_0052-e4fb49a4fe86_add_pipeline_info_4.py new file mode 100644 index 0000000000..5c10608c1b --- /dev/null +++ b/api/migrations/versions/2025_05_30_0052-e4fb49a4fe86_add_pipeline_info_4.py @@ -0,0 +1,37 @@ +"""add_pipeline_info_4 + +Revision ID: e4fb49a4fe86 +Revises: c459994abfa8 +Create Date: 2025-05-30 00:52:49.222558 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'e4fb49a4fe86' +down_revision = 'c459994abfa8' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.UUID(), + type_=sa.TEXT(), + existing_nullable=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.TEXT(), + type_=sa.UUID(), + existing_nullable=False) + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_05_1356-d466c551816f_add_pipeline_info_5.py b/api/migrations/versions/2025_06_05_1356-d466c551816f_add_pipeline_info_5.py new file mode 100644 index 0000000000..56860d1f80 --- /dev/null +++ b/api/migrations/versions/2025_06_05_1356-d466c551816f_add_pipeline_info_5.py @@ -0,0 +1,35 @@ +"""add_pipeline_info_5 + +Revision ID: d466c551816f +Revises: e4fb49a4fe86 +Create Date: 2025-06-05 13:56:05.962215 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd466c551816f' +down_revision = 'e4fb49a4fe86' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('datasource_provider_plugin_id_provider_idx'), type_='unique') + batch_op.create_index('datasource_provider_auth_type_provider_idx', ['tenant_id', 'plugin_id', 'provider'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.drop_index('datasource_provider_auth_type_provider_idx') + batch_op.create_unique_constraint(batch_op.f('datasource_provider_plugin_id_provider_idx'), ['plugin_id', 'provider']) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_11_1155-224fba149d48_add_pipeline_info_6.py b/api/migrations/versions/2025_06_11_1155-224fba149d48_add_pipeline_info_6.py new file mode 100644 index 0000000000..d2cd61f9ec --- /dev/null +++ b/api/migrations/versions/2025_06_11_1155-224fba149d48_add_pipeline_info_6.py @@ -0,0 +1,43 @@ +"""add_pipeline_info_6 + +Revision ID: 224fba149d48 +Revises: d466c551816f +Create Date: 2025-06-11 11:55:01.179201 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '224fba149d48' +down_revision = 'd466c551816f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), nullable=False)) + batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), nullable=False)) + batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.drop_column('updated_by') + batch_op.drop_column('created_by') + + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.drop_column('updated_by') + batch_op.drop_column('created_by') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py b/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py new file mode 100644 index 0000000000..a695adc74a --- /dev/null +++ b/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py @@ -0,0 +1,45 @@ +"""add_pipeline_info_7 + +Revision ID: 70a0fc0c013f +Revises: 224fba149d48 +Create Date: 2025-06-17 19:05:39.920953 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '70a0fc0c013f' +down_revision = '224fba149d48' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('document_pipeline_execution_logs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('pipeline_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('datasource_type', sa.String(length=255), nullable=False), + sa.Column('datasource_info', sa.Text(), nullable=False), + sa.Column('input_data', sa.JSON(), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey') + ) + with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op: + batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op: + batch_op.drop_index('document_pipeline_execution_logs_document_id_idx') + + op.drop_table('document_pipeline_execution_logs') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1525-a1025f709c06_add_pipeline_info_8.py b/api/migrations/versions/2025_06_19_1525-a1025f709c06_add_pipeline_info_8.py new file mode 100644 index 0000000000..387aff54b0 --- /dev/null +++ b/api/migrations/versions/2025_06_19_1525-a1025f709c06_add_pipeline_info_8.py @@ -0,0 +1,33 @@ +"""add_pipeline_info_8 + +Revision ID: a1025f709c06 +Revises: 70a0fc0c013f +Create Date: 2025-06-19 15:25:41.263120 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'a1025f709c06' +down_revision = '70a0fc0c013f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op: + batch_op.add_column(sa.Column('datasource_node_id', sa.String(length=255), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op: + batch_op.drop_column('datasource_node_id') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_07_02_1132-15e40b74a6d2_add_pipeline_info_9.py b/api/migrations/versions/2025_07_02_1132-15e40b74a6d2_add_pipeline_info_9.py new file mode 100644 index 0000000000..82c5991775 --- /dev/null +++ b/api/migrations/versions/2025_07_02_1132-15e40b74a6d2_add_pipeline_info_9.py @@ -0,0 +1,33 @@ +"""add_pipeline_info_9 + +Revision ID: 15e40b74a6d2 +Revises: a1025f709c06 +Create Date: 2025-07-02 11:32:44.125790 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '15e40b74a6d2' +down_revision = 'a1025f709c06' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('name', sa.String(length=255), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.drop_column('name') + + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 83b50eb099..49b1b9254f 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -56,6 +56,7 @@ from .model import ( TraceAppConfig, UploadFile, ) +from .oauth import DatasourceOauthParamConfig, DatasourceProvider from .provider import ( LoadBalancingModelConfig, Provider, @@ -121,6 +122,8 @@ __all__ = [ "DatasetProcessRule", "DatasetQuery", "DatasetRetrieverResource", + "DatasourceOauthParamConfig", + "DatasourceProvider", "DifySetup", "Document", "DocumentSegment", diff --git a/api/models/dataset.py b/api/models/dataset.py index 1ec27203a0..5017472e89 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -60,9 +60,31 @@ class Dataset(Base): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) + keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10")) collection_binding_id = db.Column(StringUUID, nullable=True) retrieval_model = db.Column(JSONB, nullable=True) built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + icon_info = db.Column(JSONB, nullable=True) + runtime_mode = db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying")) + pipeline_id = db.Column(StringUUID, nullable=True) + chunk_structure = db.Column(db.String(255), nullable=True) + + @property + def total_documents(self): + return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() + + @property + def total_available_documents(self): + return ( + db.session.query(func.count(Document.id)) + .filter( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) + .scalar() + ) @property def dataset_keyword_table(self): @@ -147,6 +169,8 @@ class Dataset(Base): @property def doc_form(self): + if self.chunk_structure: + return self.chunk_structure document = db.session.query(Document).filter(Document.dataset_id == self.id).first() if document: return document.doc_form @@ -202,6 +226,14 @@ class Dataset(Base): "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), } + @property + def is_published(self): + if self.pipeline_id: + pipeline = db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first() + if pipeline: + return pipeline.is_published + return False + @property def doc_metadata(self): dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all() @@ -1142,3 +1174,100 @@ class DatasetMetadataBinding(Base): document_id = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_by = db.Column(StringUUID, nullable=False) + + +class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] + __tablename__ = "pipeline_built_in_templates" + __table_args__ = 
(db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + name = db.Column(db.String(255), nullable=False) + description = db.Column(db.Text, nullable=False) + chunk_structure = db.Column(db.String(255), nullable=False) + icon = db.Column(db.JSON, nullable=False) + yaml_content = db.Column(db.Text, nullable=False) + copyright = db.Column(db.String(255), nullable=False) + privacy_policy = db.Column(db.String(255), nullable=False) + position = db.Column(db.Integer, nullable=False) + install_count = db.Column(db.Integer, nullable=False, default=0) + language = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_by = db.Column(StringUUID, nullable=False) + updated_by = db.Column(StringUUID, nullable=True) + + @property + def created_user_name(self): + account = db.session.query(Account).filter(Account.id == self.created_by).first() + if account: + return account.name + return "" + + +class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] + __tablename__ = "pipeline_customized_templates" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"), + db.Index("pipeline_customized_template_tenant_idx", "tenant_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + name = db.Column(db.String(255), nullable=False) + description = db.Column(db.Text, nullable=False) + chunk_structure = db.Column(db.String(255), nullable=False) + icon = db.Column(db.JSON, nullable=False) + position = db.Column(db.Integer, nullable=False) + yaml_content = db.Column(db.Text, nullable=False) + install_count = db.Column(db.Integer, nullable=False, default=0) + language = db.Column(db.String(255), nullable=False) + created_by = db.Column(StringUUID, nullable=False) + updated_by = db.Column(StringUUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def created_user_name(self): + account = db.session.query(Account).filter(Account.id == self.created_by).first() + if account: + return account.name + return "" + + +class Pipeline(Base): # type: ignore[name-defined] + __tablename__ = "pipelines" + __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + name = db.Column(db.String(255), nullable=False) + description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) + workflow_id = db.Column(StringUUID, nullable=True) + is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_by = db.Column(StringUUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def dataset(self): + return 
db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first() + + +class DocumentPipelineExecutionLog(Base): + __tablename__ = "document_pipeline_execution_logs" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"), + db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + pipeline_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + datasource_type = db.Column(db.String(255), nullable=False) + datasource_info = db.Column(db.Text, nullable=False) + datasource_node_id = db.Column(db.String(255), nullable=False) + input_data = db.Column(db.JSON, nullable=False) + created_by = db.Column(StringUUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/enums.py b/api/models/enums.py index 4434c3fec8..0afa204b1f 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -14,6 +14,8 @@ class UserFrom(StrEnum): class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" APP_RUN = "app-run" + RAG_PIPELINE_RUN = "rag-pipeline-run" + RAG_PIPELINE_DEBUGGING = "rag-pipeline-debugging" class DraftVariableType(StrEnum): diff --git a/api/models/model.py b/api/models/model.py index 93737043d5..caad2ebe35 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -51,6 +51,7 @@ class AppMode(StrEnum): ADVANCED_CHAT = "advanced-chat" AGENT_CHAT = "agent-chat" CHANNEL = "channel" + RAG_PIPELINE = "rag-pipeline" @classmethod def value_of(cls, value: str) -> "AppMode": diff --git a/api/models/oauth.py b/api/models/oauth.py new file mode 100644 index 0000000000..84bc29931e --- /dev/null +++ b/api/models/oauth.py @@ -0,0 +1,38 @@ +from datetime import datetime + +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped + +from .base import Base +from .engine import db +from .types import StringUUID + + +class DatasourceOauthParamConfig(Base): # type: ignore[name-defined] + __tablename__ = "datasource_oauth_params" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"), + db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) + provider: Mapped[str] = db.Column(db.String(255), nullable=False) + system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) + + +class DatasourceProvider(Base): + __tablename__ = "datasource_providers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"), + db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"), + ) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + name: Mapped[str] = db.Column(db.String(255), nullable=False) + provider: Mapped[str] = db.Column(db.String(255), nullable=False) + plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False) + auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) + encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) + updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) diff 
--git a/api/models/tools.py b/api/models/tools.py index 03fbc3acb1..7daa06834f 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -51,6 +51,40 @@ class BuiltinToolProvider(Base): return cast(dict, json.loads(self.encrypted_credentials)) +class BuiltinDatasourceProvider(Base): + """ + This table stores the datasource provider information for built-in datasources for each tenant. + """ + + __tablename__ = "tool_builtin_datasource_providers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_builtin_datasource_provider_pkey"), + # one tenant can only have one tool provider with the same name + db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_datasource_provider"), + ) + + # id of the tool provider + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # id of the tenant + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) + # who created this tool provider + user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # name of the tool provider + provider: Mapped[str] = mapped_column(db.String(256), nullable=False) + # credential of the tool provider + encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) + created_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + updated_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + + @property + def credentials(self) -> dict: + return cast(dict, json.loads(self.encrypted_credentials)) + + class ApiToolProvider(Base): """ The table stores the api providers. diff --git a/api/models/workflow.py b/api/models/workflow.py index 7f01135af3..638885be8d 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -50,6 +50,7 @@ class WorkflowType(Enum): WORKFLOW = "workflow" CHAT = "chat" + RAG_PIPELINE = "rag-pipeline" @classmethod def value_of(cls, value: str) -> "WorkflowType": @@ -145,6 +146,9 @@ class Workflow(Base): _conversation_variables: Mapped[str] = mapped_column( "conversation_variables", db.Text, nullable=False, server_default="{}" ) + _rag_pipeline_variables: Mapped[str] = mapped_column( + "rag_pipeline_variables", db.Text, nullable=False, server_default="{}" + ) VERSION_DRAFT = "draft" @@ -161,6 +165,7 @@ class Workflow(Base): created_by: str, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], + rag_pipeline_variables: list[dict], marked_name: str = "", marked_comment: str = "", ) -> "Workflow": @@ -175,6 +180,7 @@ class Workflow(Base): workflow.created_by = created_by workflow.environment_variables = environment_variables or [] workflow.conversation_variables = conversation_variables or [] + workflow.rag_pipeline_variables = rag_pipeline_variables or [] workflow.marked_name = marked_name workflow.marked_comment = marked_comment workflow.created_at = datetime.now(UTC).replace(tzinfo=None) @@ -316,6 +322,12 @@ class Workflow(Base): return variables + def rag_pipeline_user_input_form(self) -> list: + # get user_input_form from start node + variables: list[Any] = self.rag_pipeline_variables + + return variables + @property def unique_hash(self) -> str: """ @@ -432,6 +444,7 @@ class Workflow(Base): "features": self.features_dict, "environment_variables": [var.model_dump(mode="json") for var in environment_variables], "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], + "rag_pipeline_variables": 
self.rag_pipeline_variables, } return result @@ -452,6 +465,23 @@ class Workflow(Base): ensure_ascii=False, ) + @property + def rag_pipeline_variables(self) -> list[dict]: + # TODO: find some way to init `self._conversation_variables` when instance created. + if self._rag_pipeline_variables is None: + self._rag_pipeline_variables = "{}" + + variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables) + results = list(variables_dict.values()) + return results + + @rag_pipeline_variables.setter + def rag_pipeline_variables(self, values: list[dict]) -> None: + self._rag_pipeline_variables = json.dumps( + {item["variable"]: item for item in values}, + ensure_ascii=False, + ) + @staticmethod def version_from_datetime(d: datetime) -> str: return str(d) @@ -616,6 +646,7 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum): SINGLE_STEP = "single-step" WORKFLOW_RUN = "workflow-run" + RAG_PIPELINE_RUN = "rag-pipeline-run" class WorkflowNodeExecutionModel(Base): diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index e42b5ace75..09dced8dba 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -40,6 +40,7 @@ from models.dataset import ( Document, DocumentSegment, ExternalKnowledgeBindings, + Pipeline, ) from models.model import UploadFile from models.source import DataSourceOauthBinding @@ -50,6 +51,10 @@ from services.entities.knowledge_entities.knowledge_entities import ( RetrievalModel, SegmentUpdateArgs, ) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + KnowledgeConfiguration, + RagPipelineDatasetCreateEntity, +) from services.errors.account import InvalidActionError, NoPermissionError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError @@ -62,6 +67,7 @@ from services.vector_service import VectorService from tasks.add_document_to_index_task import add_document_to_index_task from tasks.batch_clean_document_task import batch_clean_document_task from tasks.clean_notion_document_task import clean_notion_document_task +from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task @@ -238,6 +244,45 @@ class DatasetService: db.session.commit() return dataset + @staticmethod + def create_empty_rag_pipeline_dataset( + tenant_id: str, + rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, + ): + # check if dataset name already exists + if ( + db.session.query(Dataset) + .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) + .first() + ): + raise DatasetNameDuplicateError( + f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." 
+ ) + + pipeline = Pipeline( + tenant_id=tenant_id, + name=rag_pipeline_dataset_create_entity.name, + description=rag_pipeline_dataset_create_entity.description, + created_by=current_user.id, + ) + db.session.add(pipeline) + db.session.flush() + + dataset = Dataset( + tenant_id=tenant_id, + name=rag_pipeline_dataset_create_entity.name, + description=rag_pipeline_dataset_create_entity.description, + permission=rag_pipeline_dataset_create_entity.permission, + provider="vendor", + runtime_mode="rag_pipeline", + icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(), + created_by=current_user.id, + pipeline_id=pipeline.id, + ) + db.session.add(dataset) + db.session.commit() + return dataset + @staticmethod def get_dataset(dataset_id) -> Optional[Dataset]: dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first() @@ -316,6 +361,17 @@ class DatasetService: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise ValueError("Dataset not found") + # check if dataset name is exists + if ( + db.session.query(Dataset) + .filter( + Dataset.id != dataset_id, + Dataset.name == data.get("name", dataset.name), + Dataset.tenant_id == dataset.tenant_id, + ) + .first() + ): + raise ValueError("Dataset name already exists") # Verify user has permission to update this dataset DatasetService.check_dataset_permission(dataset, user) @@ -431,6 +487,9 @@ class DatasetService: filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) # update Retrieval model filtered_data["retrieval_model"] = data["retrieval_model"] + # update icon info + if data.get("icon_info"): + filtered_data["icon_info"] = data.get("icon_info") # Update dataset in database db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) @@ -623,6 +682,128 @@ class DatasetService: ) filtered_data["collection_binding_id"] = dataset_collection_binding.id + @staticmethod + def update_rag_pipeline_dataset_settings( + session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False + ): + dataset = session.merge(dataset) + if not has_published: + dataset.chunk_structure = knowledge_configuration.chunk_structure + dataset.indexing_technique = knowledge_configuration.indexing_technique + if knowledge_configuration.indexing_technique == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=knowledge_configuration.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=knowledge_configuration.embedding_model, + ) + dataset.embedding_model = embedding_model.model + dataset.embedding_model_provider = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + dataset.collection_binding_id = dataset_collection_binding.id + elif knowledge_configuration.indexing_technique == "economy": + dataset.keyword_number = knowledge_configuration.keyword_number + else: + raise ValueError("Invalid index method") + dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + session.add(dataset) + else: + if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure: + raise ValueError("Chunk structure is not allowed to be updated.") + action = None + if dataset.indexing_technique != knowledge_configuration.indexing_technique: + # if update 
indexing_technique + if knowledge_configuration.indexing_technique == "economy": + raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") + elif knowledge_configuration.indexing_technique == "high_quality": + action = "add" + # get embedding model setting + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=knowledge_configuration.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=knowledge_configuration.embedding_model, + ) + dataset.embedding_model = embedding_model.model + dataset.embedding_model_provider = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + dataset.collection_binding_id = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + else: + # add default plugin id to both setting sets, to make sure the plugin model provider is consistent + # Skip embedding model checks if not provided in the update request + if dataset.indexing_technique == "high_quality": + skip_embedding_update = False + try: + # Handle existing model provider + plugin_model_provider = dataset.embedding_model_provider + plugin_model_provider_str = None + if plugin_model_provider: + plugin_model_provider_str = str(ModelProviderID(plugin_model_provider)) + + # Handle new model provider from request + new_plugin_model_provider = knowledge_configuration.embedding_model_provider + new_plugin_model_provider_str = None + if new_plugin_model_provider: + new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider)) + + # Only update embedding model if both values are provided and different from current + if ( + plugin_model_provider_str != new_plugin_model_provider_str + or knowledge_configuration.embedding_model != dataset.embedding_model + ): + action = "update" + model_manager = ModelManager() + try: + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=knowledge_configuration.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=knowledge_configuration.embedding_model, + ) + except ProviderTokenNotInitError: + # If we can't get the embedding model, skip updating it + # and keep the existing settings if available + # Skip the rest of the embedding model update + skip_embedding_update = True + if not skip_embedding_update: + dataset.embedding_model = embedding_model.model + dataset.embedding_model_provider = embedding_model.provider + dataset_collection_binding = ( + DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + ) + dataset.collection_binding_id = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." 
+ ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + elif dataset.indexing_technique == "economy": + if dataset.keyword_number != knowledge_configuration.keyword_number: + dataset.keyword_number = knowledge_configuration.keyword_number + dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + session.add(dataset) + session.commit() + if action: + deal_dataset_index_update_task.delay(dataset.id, action) + @staticmethod def delete_dataset(dataset_id, user): dataset = DatasetService.get_dataset(dataset_id) @@ -1359,6 +1540,283 @@ class DocumentService: return documents, batch + # @staticmethod + # def save_document_with_dataset_id( + # dataset: Dataset, + # knowledge_config: KnowledgeConfig, + # account: Account | Any, + # dataset_process_rule: Optional[DatasetProcessRule] = None, + # created_from: str = "web", + # ): + # # check document limit + # features = FeatureService.get_features(current_user.current_tenant_id) + + # if features.billing.enabled: + # if not knowledge_config.original_document_id: + # count = 0 + # if knowledge_config.data_source: + # if knowledge_config.data_source.info_list.data_source_type == "upload_file": + # upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids + # # type: ignore + # count = len(upload_file_list) + # elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + # notion_info_list = knowledge_config.data_source.info_list.notion_info_list + # for notion_info in notion_info_list: # type: ignore + # count = count + len(notion_info.pages) + # elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + # website_info = knowledge_config.data_source.info_list.website_info_list + # count = len(website_info.urls) # type: ignore + # batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + + # if features.billing.subscription.plan == "sandbox" and count > 1: + # raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") + # if count > batch_upload_limit: + # raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + + # DocumentService.check_documents_upload_quota(count, features) + + # # if dataset is empty, update dataset data_source_type + # if not dataset.data_source_type: + # dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore + + # if not dataset.indexing_technique: + # if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + # raise ValueError("Indexing technique is invalid") + + # dataset.indexing_technique = knowledge_config.indexing_technique + # if knowledge_config.indexing_technique == "high_quality": + # model_manager = ModelManager() + # if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: + # dataset_embedding_model = knowledge_config.embedding_model + # dataset_embedding_model_provider = knowledge_config.embedding_model_provider + # else: + # embedding_model = model_manager.get_default_model_instance( + # tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + # ) + # dataset_embedding_model = embedding_model.model + # dataset_embedding_model_provider = embedding_model.provider + # dataset.embedding_model = dataset_embedding_model + # dataset.embedding_model_provider = dataset_embedding_model_provider + # dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + # dataset_embedding_model_provider, 
dataset_embedding_model + # ) + # dataset.collection_binding_id = dataset_collection_binding.id + # if not dataset.retrieval_model: + # default_retrieval_model = { + # "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + # "reranking_enable": False, + # "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + # "top_k": 2, + # "score_threshold_enabled": False, + # } + + # dataset.retrieval_model = ( + # knowledge_config.retrieval_model.model_dump() + # if knowledge_config.retrieval_model + # else default_retrieval_model + # ) # type: ignore + + # documents = [] + # if knowledge_config.original_document_id: + # document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) + # documents.append(document) + # batch = document.batch + # else: + # batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + # # save process rule + # if not dataset_process_rule: + # process_rule = knowledge_config.process_rule + # if process_rule: + # if process_rule.mode in ("custom", "hierarchical"): + # dataset_process_rule = DatasetProcessRule( + # dataset_id=dataset.id, + # mode=process_rule.mode, + # rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + # created_by=account.id, + # ) + # elif process_rule.mode == "automatic": + # dataset_process_rule = DatasetProcessRule( + # dataset_id=dataset.id, + # mode=process_rule.mode, + # rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + # created_by=account.id, + # ) + # else: + # logging.warn( + # f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" + # ) + # return + # db.session.add(dataset_process_rule) + # db.session.commit() + # lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) + # with redis_client.lock(lock_name, timeout=600): + # position = DocumentService.get_documents_position(dataset.id) + # document_ids = [] + # duplicate_document_ids = [] + # if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore + # upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + # for file_id in upload_file_list: + # file = ( + # db.session.query(UploadFile) + # .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + # .first() + # ) + + # # raise error if file not found + # if not file: + # raise FileNotExistsError() + + # file_name = file.name + # data_source_info = { + # "upload_file_id": file_id, + # } + # # check duplicate + # if knowledge_config.duplicate: + # document = Document.query.filter_by( + # dataset_id=dataset.id, + # tenant_id=current_user.current_tenant_id, + # data_source_type="upload_file", + # enabled=True, + # name=file_name, + # ).first() + # if document: + # document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + # document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + # document.created_from = created_from + # document.doc_form = knowledge_config.doc_form + # document.doc_language = knowledge_config.doc_language + # document.data_source_info = json.dumps(data_source_info) + # document.batch = batch + # document.indexing_status = "waiting" + # db.session.add(document) + # documents.append(document) + # duplicate_document_ids.append(document.id) + # continue + # document = DocumentService.build_document( + # dataset, + # dataset_process_rule.id, # type: ignore + # knowledge_config.data_source.info_list.data_source_type, # type: ignore + # 
knowledge_config.doc_form, + # knowledge_config.doc_language, + # data_source_info, + # created_from, + # position, + # account, + # file_name, + # batch, + # ) + # db.session.add(document) + # db.session.flush() + # document_ids.append(document.id) + # documents.append(document) + # position += 1 + # elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore + # notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + # if not notion_info_list: + # raise ValueError("No notion info list found.") + # exist_page_ids = [] + # exist_document = {} + # documents = Document.query.filter_by( + # dataset_id=dataset.id, + # tenant_id=current_user.current_tenant_id, + # data_source_type="notion_import", + # enabled=True, + # ).all() + # if documents: + # for document in documents: + # data_source_info = json.loads(document.data_source_info) + # exist_page_ids.append(data_source_info["notion_page_id"]) + # exist_document[data_source_info["notion_page_id"]] = document.id + # for notion_info in notion_info_list: + # workspace_id = notion_info.workspace_id + # data_source_binding = DataSourceOauthBinding.query.filter( + # db.and_( + # DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + # DataSourceOauthBinding.provider == "notion", + # DataSourceOauthBinding.disabled == False, + # DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + # ) + # ).first() + # if not data_source_binding: + # raise ValueError("Data source binding not found.") + # for page in notion_info.pages: + # if page.page_id not in exist_page_ids: + # data_source_info = { + # "notion_workspace_id": workspace_id, + # "notion_page_id": page.page_id, + # "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + # "type": page.type, + # } + # # Truncate page name to 255 characters to prevent DB field length errors + # truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" + # document = DocumentService.build_document( + # dataset, + # dataset_process_rule.id, # type: ignore + # knowledge_config.data_source.info_list.data_source_type, # type: ignore + # knowledge_config.doc_form, + # knowledge_config.doc_language, + # data_source_info, + # created_from, + # position, + # account, + # truncated_page_name, + # batch, + # ) + # db.session.add(document) + # db.session.flush() + # document_ids.append(document.id) + # documents.append(document) + # position += 1 + # else: + # exist_document.pop(page.page_id) + # # delete not selected documents + # if len(exist_document) > 0: + # clean_notion_document_task.delay(list(exist_document.values()), dataset.id) + # elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore + # website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + # if not website_info: + # raise ValueError("No website info list found.") + # urls = website_info.urls + # for url in urls: + # data_source_info = { + # "url": url, + # "provider": website_info.provider, + # "job_id": website_info.job_id, + # "only_main_content": website_info.only_main_content, + # "mode": "crawl", + # } + # if len(url) > 255: + # document_name = url[:200] + "..." 
+ # else: + # document_name = url + # document = DocumentService.build_document( + # dataset, + # dataset_process_rule.id, # type: ignore + # knowledge_config.data_source.info_list.data_source_type, # type: ignore + # knowledge_config.doc_form, + # knowledge_config.doc_language, + # data_source_info, + # created_from, + # position, + # account, + # document_name, + # batch, + # ) + # db.session.add(document) + # db.session.flush() + # document_ids.append(document.id) + # documents.append(document) + # position += 1 + # db.session.commit() + + # # trigger async task + # if document_ids: + # document_indexing_task.delay(dataset.id, document_ids) + # if duplicate_document_ids: + # duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + + # return documents, batch + @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size @@ -1370,7 +1828,7 @@ class DocumentService: @staticmethod def build_document( dataset: Dataset, - process_rule_id: str, + process_rule_id: str | None, data_source_type: str, document_form: str, document_language: str, diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py new file mode 100644 index 0000000000..228c18b7c2 --- /dev/null +++ b/api/services/datasource_provider_service.py @@ -0,0 +1,235 @@ +import logging + +from flask_login import current_user + +from constants import HIDDEN_VALUE +from core.helper import encrypter +from core.model_runtime.entities.provider_entities import FormType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.plugin.impl.datasource import PluginDatasourceManager +from extensions.ext_database import db +from models.oauth import DatasourceProvider + +logger = logging.getLogger(__name__) + + +class DatasourceProviderService: + """ + Model Provider Service + """ + + def __init__(self) -> None: + self.provider_manager = PluginDatasourceManager() + + def datasource_provider_credentials_validate( + self, tenant_id: str, provider: str, plugin_id: str, credentials: dict, name: str + ) -> None: + """ + validate datasource provider credentials. 
+ + :param tenant_id: workspace id + :param provider: provider name + :param credentials: credentials to validate + """ + # check if the authorization name already exists + datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=name).first() + if datasource_provider: + raise ValueError("Authorization name already exists") + + credential_valid = self.provider_manager.validate_provider_credentials( + tenant_id=tenant_id, + user_id=current_user.id, + provider=provider, + plugin_id=plugin_id, + credentials=credentials, + ) + if credential_valid: + # Get all provider configurations of the current workspace + datasource_provider = ( + db.session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider, auth_type="api_key") + .first() + ) + + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" + ) + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # encrypt secret inputs before storing them + credentials[key] = encrypter.encrypt_token(tenant_id, value) + datasource_provider = DatasourceProvider( + tenant_id=tenant_id, + name=name, + provider=provider, + plugin_id=plugin_id, + auth_type="api_key", + encrypted_credentials=credentials, + ) + db.session.add(datasource_provider) + db.session.commit() + else: + raise CredentialsValidateFailedError() + + def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]: + """ + Extract secret input form variables. + + :param provider_id: provider id in the form "plugin_id/provider" + :return: names of the secret input variables + """ + datasource_provider = self.provider_manager.fetch_datasource_provider( + tenant_id=tenant_id, provider_id=provider_id + ) + credential_form_schemas = datasource_provider.declaration.credentials_schema + secret_input_form_variables = [] + for credential_form_schema in credential_form_schemas: + if credential_form_schema.type == FormType.SECRET_INPUT: + secret_input_form_variables.append(credential_form_schema.name) + + return secret_input_form_variables + + def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]: + """ + get datasource credentials (obfuscated). + + :param tenant_id: workspace id + :param provider: provider name + :return: + """ + # Get all provider configurations of the current workspace + datasource_providers: list[DatasourceProvider] = ( + db.session.query(DatasourceProvider) + .filter( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + .all() + ) + if not datasource_providers: + return [] + copy_credentials_list = [] + for datasource_provider in datasource_providers: + encrypted_credentials = datasource_provider.encrypted_credentials + # Get provider credential secret variables + credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" + ) + + # Obfuscate provider credentials + copy_credentials = encrypted_credentials.copy() + for key, value in copy_credentials.items(): + if key in credential_secret_variables: + copy_credentials[key] = encrypter.obfuscated_token(value) + copy_credentials_list.append( + { + "credentials": copy_credentials, + "type": datasource_provider.auth_type, + "name": datasource_provider.name, + } + ) + + return copy_credentials_list + + def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]: + """ + get decrypted datasource credentials.
+ + :param tenant_id: workspace id + :param provider: provider name + :return: + """ + # Get all provider configurations of the current workspace + datasource_providers: list[DatasourceProvider] = ( + db.session.query(DatasourceProvider) + .filter( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + .all() + ) + if not datasource_providers: + return [] + copy_credentials_list = [] + for datasource_provider in datasource_providers: + encrypted_credentials = datasource_provider.encrypted_credentials + # Get provider credential secret variables + credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" + ) + + # Decrypt provider credentials + copy_credentials = encrypted_credentials.copy() + for key, value in copy_credentials.items(): + if key in credential_secret_variables: + copy_credentials[key] = encrypter.decrypt_token(tenant_id, value) + copy_credentials_list.append( + { + "credentials": copy_credentials, + "type": datasource_provider.auth_type, + } + ) + + return copy_credentials_list + + def update_datasource_credentials( + self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict + ) -> None: + """ + update datasource credentials. + """ + credential_valid = self.provider_manager.validate_provider_credentials( + tenant_id=tenant_id, + user_id=current_user.id, + provider=provider, + plugin_id=plugin_id, + credentials=credentials, + ) + if credential_valid: + # Get all provider configurations of the current workspace + datasource_provider = ( + db.session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) + .first() + ) + + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" + ) + if not datasource_provider: + raise ValueError("Datasource provider not found") + else: + original_credentials = datasource_provider.encrypted_credentials + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if [__HIDDEN__] is sent for a secret input, keep the original value + if value == HIDDEN_VALUE and key in original_credentials: + # keep the stored encrypted token instead of re-encrypting it + credentials[key] = original_credentials[key] + else: + credentials[key] = encrypter.encrypt_token(tenant_id, value) + + datasource_provider.encrypted_credentials = credentials + db.session.commit() + else: + raise CredentialsValidateFailedError() + + def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None: + """ + remove datasource credentials.
+ + :param tenant_id: workspace id + :param provider: provider name + :param plugin_id: plugin id + :return: + """ + datasource_provider = ( + db.session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) + .first() + ) + if datasource_provider: + db.session.delete(datasource_provider) + db.session.commit() diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py new file mode 100644 index 0000000000..620fb2426a --- /dev/null +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -0,0 +1,116 @@ +from typing import Literal, Optional + +from pydantic import BaseModel + + +class IconInfo(BaseModel): + icon: str + icon_background: Optional[str] = None + icon_type: Optional[str] = None + icon_url: Optional[str] = None + + +class PipelineTemplateInfoEntity(BaseModel): + name: str + description: str + icon_info: IconInfo + + +class RagPipelineDatasetCreateEntity(BaseModel): + name: str + description: str + icon_info: IconInfo + permission: str + partial_member_list: Optional[list[str]] = None + yaml_content: Optional[str] = None + + +class RerankingModelConfig(BaseModel): + """ + Reranking Model Config. + """ + + reranking_provider_name: str + reranking_model_name: str + + +class VectorSetting(BaseModel): + """ + Vector Setting. + """ + + vector_weight: float + embedding_provider_name: str + embedding_model_name: str + + +class KeywordSetting(BaseModel): + """ + Keyword Setting. + """ + + keyword_weight: float + + +class WeightedScoreConfig(BaseModel): + """ + Weighted score Config. + """ + + vector_setting: VectorSetting + keyword_setting: KeywordSetting + + +class EmbeddingSetting(BaseModel): + """ + Embedding Setting. + """ + + embedding_provider_name: str + embedding_model_name: str + + +class EconomySetting(BaseModel): + """ + Economy Setting. + """ + + keyword_number: int + + +class RetrievalSetting(BaseModel): + """ + Retrieval Setting. + """ + + search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"] + top_k: int + score_threshold: Optional[float] = 0.5 + score_threshold_enabled: bool = False + reranking_mode: str = "reranking_model" + reranking_enable: bool = True + reranking_model: Optional[RerankingModelConfig] = None + weights: Optional[WeightedScoreConfig] = None + + +class IndexMethod(BaseModel): + """ + Knowledge Index Setting. + """ + + indexing_technique: Literal["high_quality", "economy"] + embedding_setting: EmbeddingSetting + economy_setting: EconomySetting + + +class KnowledgeConfiguration(BaseModel): + """ + Knowledge Base Configuration. 
+ """ + + chunk_structure: str + indexing_technique: Literal["high_quality", "economy"] + embedding_model_provider: Optional[str] = "" + embedding_model: Optional[str] = "" + keyword_number: Optional[int] = 10 + retrieval_model: RetrievalSetting diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py new file mode 100644 index 0000000000..da67801877 --- /dev/null +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -0,0 +1,99 @@ +from collections.abc import Mapping +from typing import Any, Union + +from configs import dify_config +from core.app.apps.pipeline.pipeline_generator import PipelineGenerator +from core.app.entities.app_invoke_entities import InvokeFrom +from models.dataset import Pipeline +from models.model import Account, App, EndUser +from models.workflow import Workflow +from services.rag_pipeline.rag_pipeline import RagPipelineService + + +class PipelineGenerateService: + @classmethod + def generate( + cls, + pipeline: Pipeline, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ): + """ + Pipeline Content Generate + :param pipeline: pipeline + :param user: user + :param args: args + :param invoke_from: invoke from + :param streaming: streaming + :return: + """ + try: + workflow = cls._get_workflow(pipeline, invoke_from) + return PipelineGenerator.convert_to_event_stream( + PipelineGenerator().generate( + pipeline=pipeline, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=streaming, + call_depth=0, + workflow_thread_pool_id=None, + ), + ) + + except Exception: + raise + + @staticmethod + def _get_max_active_requests(app_model: App) -> int: + max_active_requests = app_model.max_active_requests + if max_active_requests is None: + max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS) + return max_active_requests + + @classmethod + def generate_single_iteration( + cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True + ): + workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER) + return PipelineGenerator.convert_to_event_stream( + PipelineGenerator().single_iteration_generate( + pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + ) + ) + + @classmethod + def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True): + workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER) + return PipelineGenerator.convert_to_event_stream( + PipelineGenerator().single_loop_generate( + pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + ) + ) + + @classmethod + def _get_workflow(cls, pipeline: Pipeline, invoke_from: InvokeFrom) -> Workflow: + """ + Get workflow + :param pipeline: pipeline + :param invoke_from: invoke from + :return: + """ + rag_pipeline_service = RagPipelineService() + if invoke_from == InvokeFrom.DEBUGGER: + # fetch draft workflow by app_model + workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + + if not workflow: + raise ValueError("Workflow not initialized") + else: + # fetch published workflow by app_model + workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline) + + if not workflow: + raise ValueError("Workflow not published") + + return workflow diff --git a/api/services/rag_pipeline/pipeline_template/__init__.py 
b/api/services/rag_pipeline/pipeline_template/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/rag_pipeline/pipeline_template/built_in/__init__.py b/api/services/rag_pipeline/pipeline_template/built_in/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py new file mode 100644 index 0000000000..b0fa54115c --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -0,0 +1,64 @@ +import json +from os import path +from pathlib import Path +from typing import Optional + +from flask import current_app + +from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase +from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType + + +class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): + """ + Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json + """ + + builtin_data: Optional[dict] = None + + def get_type(self) -> str: + return PipelineTemplateType.BUILTIN + + def get_pipeline_templates(self, language: str) -> dict: + result = self.fetch_pipeline_templates_from_builtin(language) + return result + + def get_pipeline_template_detail(self, template_id: str): + result = self.fetch_pipeline_template_detail_from_builtin(template_id) + return result + + @classmethod + def _get_builtin_data(cls) -> dict: + """ + Get builtin data. + :return: + """ + if cls.builtin_data: + return cls.builtin_data + + root_path = current_app.root_path + cls.builtin_data = json.loads( + Path(path.join(root_path, "constants", "pipeline_templates.json")).read_text(encoding="utf-8") + ) + + return cls.builtin_data or {} + + @classmethod + def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict: + """ + Fetch pipeline templates from builtin. + :param language: language + :return: + """ + builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + return builtin_data.get("pipeline_templates", {}).get(language, {}) + + @classmethod + def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> Optional[dict]: + """ + Fetch pipeline template detail from builtin. 
+ :param template_id: Template ID + :return: + """ + builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + return builtin_data.get("pipeline_templates", {}).get(template_id) diff --git a/api/services/rag_pipeline/pipeline_template/customized/__init__.py b/api/services/rag_pipeline/pipeline_template/customized/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py new file mode 100644 index 0000000000..3380d23ec4 --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -0,0 +1,83 @@ +from typing import Optional + +import yaml +from flask_login import current_user + +from extensions.ext_database import db +from models.dataset import PipelineCustomizedTemplate +from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase +from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType + + +class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): + """ + Retrieval recommended app from database + """ + + def get_pipeline_templates(self, language: str) -> dict: + result = self.fetch_pipeline_templates_from_customized( + tenant_id=current_user.current_tenant_id, language=language + ) + return result + + def get_pipeline_template_detail(self, template_id: str): + result = self.fetch_pipeline_template_detail_from_db(template_id) + return result + + def get_type(self) -> str: + return PipelineTemplateType.CUSTOMIZED + + @classmethod + def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict: + """ + Fetch pipeline templates from db. + :param tenant_id: tenant id + :param language: language + :return: + """ + pipeline_customized_templates = ( + db.session.query(PipelineCustomizedTemplate) + .filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) + .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc()) + .all() + ) + recommended_pipelines_results = [] + for pipeline_customized_template in pipeline_customized_templates: + recommended_pipeline_result = { + "id": pipeline_customized_template.id, + "name": pipeline_customized_template.name, + "description": pipeline_customized_template.description, + "icon": pipeline_customized_template.icon, + "position": pipeline_customized_template.position, + "chunk_structure": pipeline_customized_template.chunk_structure, + } + recommended_pipelines_results.append(recommended_pipeline_result) + + return {"pipeline_templates": recommended_pipelines_results} + + @classmethod + def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: + """ + Fetch pipeline template detail from db. 
+ :param template_id: Template ID + :return: + """ + pipeline_template = ( + db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + ) + if not pipeline_template: + return None + + dsl_data = yaml.safe_load(pipeline_template.yaml_content) + graph_data = dsl_data.get("workflow", {}).get("graph", {}) + + return { + "id": pipeline_template.id, + "name": pipeline_template.name, + "icon_info": pipeline_template.icon, + "description": pipeline_template.description, + "chunk_structure": pipeline_template.chunk_structure, + "export_data": pipeline_template.yaml_content, + "graph": graph_data, + "created_by": pipeline_template.created_user_name, + } diff --git a/api/services/rag_pipeline/pipeline_template/database/__init__.py b/api/services/rag_pipeline/pipeline_template/database/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py new file mode 100644 index 0000000000..b69a857c3a --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -0,0 +1,80 @@ +from typing import Optional + +import yaml + +from extensions.ext_database import db +from models.dataset import PipelineBuiltInTemplate +from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase +from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType + + +class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): + """ + Retrieval pipeline template from database + """ + + def get_pipeline_templates(self, language: str) -> dict: + result = self.fetch_pipeline_templates_from_db(language) + return result + + def get_pipeline_template_detail(self, pipeline_id: str): + result = self.fetch_pipeline_template_detail_from_db(pipeline_id) + return result + + def get_type(self) -> str: + return PipelineTemplateType.DATABASE + + @classmethod + def fetch_pipeline_templates_from_db(cls, language: str) -> dict: + """ + Fetch pipeline templates from db. + :param language: language + :return: + """ + + pipeline_built_in_templates: list[PipelineBuiltInTemplate] = ( + db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all() + ) + + recommended_pipelines_results = [] + for pipeline_built_in_template in pipeline_built_in_templates: + recommended_pipeline_result = { + "id": pipeline_built_in_template.id, + "name": pipeline_built_in_template.name, + "description": pipeline_built_in_template.description, + "icon": pipeline_built_in_template.icon, + "copyright": pipeline_built_in_template.copyright, + "privacy_policy": pipeline_built_in_template.privacy_policy, + "position": pipeline_built_in_template.position, + "chunk_structure": pipeline_built_in_template.chunk_structure, + } + recommended_pipelines_results.append(recommended_pipeline_result) + + return {"pipeline_templates": recommended_pipelines_results} + + @classmethod + def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]: + """ + Fetch pipeline template detail from db. 
+ :param pipeline_id: Pipeline ID + :return: + """ + # is in public recommended list + pipeline_template = ( + db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first() + ) + + if not pipeline_template: + return None + dsl_data = yaml.safe_load(pipeline_template.yaml_content) + graph_data = dsl_data.get("workflow", {}).get("graph", {}) + return { + "id": pipeline_template.id, + "name": pipeline_template.name, + "icon_info": pipeline_template.icon, + "description": pipeline_template.description, + "chunk_structure": pipeline_template.chunk_structure, + "export_data": pipeline_template.yaml_content, + "graph": graph_data, + "created_by": pipeline_template.created_user_name, + } diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py new file mode 100644 index 0000000000..fa6a38a357 --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod +from typing import Optional + + +class PipelineTemplateRetrievalBase(ABC): + """Interface for pipeline template retrieval.""" + + @abstractmethod + def get_pipeline_templates(self, language: str) -> dict: + raise NotImplementedError + + @abstractmethod + def get_pipeline_template_detail(self, template_id: str) -> Optional[dict]: + raise NotImplementedError + + @abstractmethod + def get_type(self) -> str: + raise NotImplementedError diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py new file mode 100644 index 0000000000..7b87ffe75b --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py @@ -0,0 +1,26 @@ +from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval +from services.rag_pipeline.pipeline_template.customized.customized_retrieval import CustomizedPipelineTemplateRetrieval +from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval +from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase +from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +from services.rag_pipeline.pipeline_template.remote.remote_retrieval import RemotePipelineTemplateRetrieval + + +class PipelineTemplateRetrievalFactory: + @staticmethod + def get_pipeline_template_factory(mode: str) -> type[PipelineTemplateRetrievalBase]: + match mode: + case PipelineTemplateType.REMOTE: + return RemotePipelineTemplateRetrieval + case PipelineTemplateType.CUSTOMIZED: + return CustomizedPipelineTemplateRetrieval + case PipelineTemplateType.DATABASE: + return DatabasePipelineTemplateRetrieval + case PipelineTemplateType.BUILTIN: + return BuiltInPipelineTemplateRetrieval + case _: + raise ValueError(f"invalid fetch recommended apps mode: {mode}") + + @staticmethod + def get_built_in_pipeline_template_retrieval(): + return BuiltInPipelineTemplateRetrieval diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_type.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_type.py new file mode 100644 index 0000000000..e914266d26 --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_type.py @@ -0,0 +1,8 @@ +from enum import StrEnum + + +class PipelineTemplateType(StrEnum): + REMOTE 
= "remote" + DATABASE = "database" + CUSTOMIZED = "customized" + BUILTIN = "builtin" diff --git a/api/services/rag_pipeline/pipeline_template/remote/__init__.py b/api/services/rag_pipeline/pipeline_template/remote/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py new file mode 100644 index 0000000000..5553d7c97e --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -0,0 +1,68 @@ +import logging +from typing import Optional + +import requests + +from configs import dify_config +from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase +from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval + +logger = logging.getLogger(__name__) + + +class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): + """ + Retrieval recommended app from dify official + """ + + def get_pipeline_template_detail(self, pipeline_id: str): + try: + result = self.fetch_pipeline_template_detail_from_dify_official(pipeline_id) + except Exception as e: + logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.") + result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(pipeline_id) + return result + + def get_pipeline_templates(self, language: str) -> dict: + try: + result = self.fetch_pipeline_templates_from_dify_official(language) + except Exception as e: + logger.warning(f"fetch pipeline templates from dify official failed: {e}, switch to built-in.") + result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language) + return result + + def get_type(self) -> str: + return PipelineTemplateType.REMOTE + + @classmethod + def fetch_pipeline_template_detail_from_dify_official(cls, pipeline_id: str) -> Optional[dict]: + """ + Fetch pipeline template detail from dify official. + :param pipeline_id: Pipeline ID + :return: + """ + domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN + url = f"{domain}/pipelines/{pipeline_id}" + response = requests.get(url, timeout=(3, 10)) + if response.status_code != 200: + return None + data: dict = response.json() + return data + + @classmethod + def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict: + """ + Fetch pipeline templates from dify official. 
+ :param language: language + :return: + """ + domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN + url = f"{domain}/pipelines?language={language}" + response = requests.get(url, timeout=(3, 10)) + if response.status_code != 200: + raise ValueError(f"fetch pipeline templates failed, status code: {response.status_code}") + + result: dict = response.json() + + return result diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py new file mode 100644 index 0000000000..0e1fad600f --- /dev/null +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -0,0 +1,1059 @@ +import json +import logging +import re +import threading +import time +from collections.abc import Callable, Generator, Mapping, Sequence +from datetime import UTC, datetime +from typing import Any, Optional, cast +from uuid import uuid4 + +from flask_login import current_user +from sqlalchemy import func, or_, select +from sqlalchemy.orm import Session + +import contexts +from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom +from core.datasource.entities.datasource_entities import ( + DatasourceMessage, + DatasourceProviderType, + GetOnlineDocumentPageContentRequest, + OnlineDocumentPagesMessage, + WebsiteCrawlMessage, +) +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin +from core.rag.entities.event import ( + BaseDatasourceEvent, + DatasourceCompletedEvent, + DatasourceErrorEvent, + DatasourceProcessingEvent, +) +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.variables.variables import Variable +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from core.workflow.enums import SystemVariableKey +from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.graph_engine.entities.event import InNodeEvent +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.event.event import RunCompletedEvent +from core.workflow.nodes.event.types import NodeEvent +from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from core.workflow.repositories.workflow_node_execution_repository import OrderConfig +from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.account import Account +from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore +from models.enums import WorkflowRunTriggeredFrom +from models.model import EndUser +from models.workflow import ( + Workflow, + WorkflowNodeExecutionModel, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, + WorkflowType, +) +from services.dataset_service import DatasetService +from services.datasource_provider_service import DatasourceProviderService +from services.entities.knowledge_entities.rag_pipeline_entities import ( + KnowledgeConfiguration, + PipelineTemplateInfoEntity, +) +from services.errors.app import WorkflowHashNotEqualError +from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory + +logger = 
logging.getLogger(__name__) + + +class RagPipelineService: + @classmethod + def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict: + if type == "built-in": + mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() + result = retrieval_instance.get_pipeline_templates(language) + if not result.get("pipeline_templates") and language != "en-US": + template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval() + result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US") + return result + else: + mode = "customized" + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() + result = retrieval_instance.get_pipeline_templates(language) + return result + + @classmethod + def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]: + """ + Get pipeline template detail. + :param template_id: template id + :return: + """ + if type == "built-in": + mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() + result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + else: + mode = "customized" + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() + result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + return result + + @classmethod + def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity): + """ + Update pipeline template. + :param template_id: template id + :param template_info: template info + """ + customized_template: PipelineCustomizedTemplate | None = ( + db.session.query(PipelineCustomizedTemplate) + .filter( + PipelineCustomizedTemplate.id == template_id, + PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, + ) + .first() + ) + if not customized_template: + raise ValueError("Customized pipeline template not found.") + # check whether the template name already exists + template_name = template_info.name + if template_name: + template = ( + db.session.query(PipelineCustomizedTemplate) + .filter( + PipelineCustomizedTemplate.name == template_name, + PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, + PipelineCustomizedTemplate.id != template_id, + ) + .first() + ) + if template: + raise ValueError("Template name already exists") + customized_template.name = template_info.name + customized_template.description = template_info.description + customized_template.icon = template_info.icon_info.model_dump() + customized_template.updated_by = current_user.id + db.session.commit() + return customized_template + + @classmethod + def delete_customized_pipeline_template(cls, template_id: str): + """ + Delete customized pipeline template. 
+ """ + customized_template: PipelineCustomizedTemplate | None = ( + db.session.query(PipelineCustomizedTemplate) + .filter( + PipelineCustomizedTemplate.id == template_id, + PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, + ) + .first() + ) + if not customized_template: + raise ValueError("Customized pipeline template not found.") + db.session.delete(customized_template) + db.session.commit() + + def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: + """ + Get draft workflow + """ + # fetch draft workflow by rag pipeline + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == "draft", + ) + .first() + ) + + # return draft workflow + return workflow + + def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: + """ + Get published workflow + """ + + if not pipeline.workflow_id: + return None + + # fetch published workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.id == pipeline.workflow_id, + ) + .first() + ) + + return workflow + + def get_all_published_workflow( + self, + *, + session: Session, + pipeline: Pipeline, + page: int, + limit: int, + user_id: str | None, + named_only: bool = False, + ) -> tuple[Sequence[Workflow], bool]: + """ + Get published workflow with pagination + """ + if not pipeline.workflow_id: + return [], False + + stmt = ( + select(Workflow) + .where(Workflow.app_id == pipeline.id) + .order_by(Workflow.version.desc()) + .limit(limit + 1) + .offset((page - 1) * limit) + ) + + if user_id: + stmt = stmt.where(Workflow.created_by == user_id) + + if named_only: + stmt = stmt.where(Workflow.marked_name != "") + + workflows = session.scalars(stmt).all() + + has_more = len(workflows) > limit + if has_more: + workflows = workflows[:-1] + + return workflows, has_more + + def sync_draft_workflow( + self, + *, + pipeline: Pipeline, + graph: dict, + unique_hash: Optional[str], + account: Account, + environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable], + rag_pipeline_variables: list, + ) -> Workflow: + """ + Sync draft workflow + :raises WorkflowHashNotEqualError + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(pipeline=pipeline) + + if workflow and workflow.unique_hash != unique_hash: + raise WorkflowHashNotEqualError() + + # create draft workflow if not found + if not workflow: + workflow = Workflow( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + features="{}", + type=WorkflowType.RAG_PIPELINE.value, + version="draft", + graph=json.dumps(graph), + created_by=account.id, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + rag_pipeline_variables=rag_pipeline_variables, + ) + db.session.add(workflow) + db.session.flush() + pipeline.workflow_id = workflow.id + # update draft workflow if found + else: + workflow.graph = json.dumps(graph) + workflow.updated_by = account.id + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.environment_variables = environment_variables + workflow.conversation_variables = conversation_variables + workflow.rag_pipeline_variables = rag_pipeline_variables + # commit db session changes + db.session.commit() + + # trigger workflow events TODO + # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow) + + # return draft 
workflow + return workflow + + def publish_workflow( + self, + *, + session: Session, + pipeline: Pipeline, + account: Account, + ) -> Workflow: + draft_workflow_stmt = select(Workflow).where( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == "draft", + ) + draft_workflow = session.scalar(draft_workflow_stmt) + if not draft_workflow: + raise ValueError("No valid workflow found.") + + # create new workflow + workflow = Workflow.new( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + type=draft_workflow.type, + version=str(datetime.now(UTC).replace(tzinfo=None)), + graph=draft_workflow.graph, + features=draft_workflow.features, + created_by=account.id, + environment_variables=draft_workflow.environment_variables, + conversation_variables=draft_workflow.conversation_variables, + rag_pipeline_variables=draft_workflow.rag_pipeline_variables, + marked_name="", + marked_comment="", + ) + # commit db session changes + session.add(workflow) + + graph = workflow.graph_dict + nodes = graph.get("nodes", []) + for node in nodes: + if node.get("data", {}).get("type") == "knowledge-index": + knowledge_configuration = node.get("data", {}) + knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + + # update dataset + dataset = pipeline.dataset + if not dataset: + raise ValueError("Dataset not found") + DatasetService.update_rag_pipeline_dataset_settings( + session=session, + dataset=dataset, + knowledge_configuration=knowledge_configuration, + has_published=pipeline.is_published, + ) + # return new workflow + return workflow + + def get_default_block_configs(self) -> list[dict]: + """ + Get default block configs + """ + # return default block config + default_block_configs = [] + for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values(): + node_class = node_class_mapping[LATEST_VERSION] + default_config = node_class.get_default_config() + if default_config: + default_block_configs.append(default_config) + + return default_block_configs + + def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: + """ + Get default config of node. + :param node_type: node type + :param filters: filter by node config parameters. 
+ :return: + """ + node_type_enum = NodeType(node_type) + + # return default block config + if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: + return None + + node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] + default_config = node_class.get_default_config(filters=filters) + if not default_config: + return None + + return default_config + + def run_draft_workflow_node( + self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account + ) -> WorkflowNodeExecution: + """ + Run draft workflow node + """ + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(pipeline=pipeline) + if not draft_workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + start_at = time.perf_counter() + + workflow_node_execution = self._handle_node_run_result( + getter=lambda: WorkflowEntry.single_step_run( + workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + ), + start_at=start_at, + tenant_id=pipeline.tenant_id, + node_id=node_id, + ) + workflow_node_execution.workflow_id = draft_workflow.id + + db.session.add(workflow_node_execution) + db.session.commit() + + return workflow_node_execution + + def run_published_workflow_node( + self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account + ) -> WorkflowNodeExecution: + """ + Run published workflow node + """ + # fetch published workflow by app_model + published_workflow = self.get_published_workflow(pipeline=pipeline) + if not published_workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + start_at = time.perf_counter() + + workflow_node_execution = self._handle_node_run_result( + getter=lambda: WorkflowEntry.single_step_run( + workflow=published_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + ), + start_at=start_at, + tenant_id=pipeline.tenant_id, + node_id=node_id, + ) + + workflow_node_execution.workflow_id = published_workflow.id + + db.session.add(workflow_node_execution) + db.session.commit() + + return workflow_node_execution + + def run_datasource_workflow_node( + self, + pipeline: Pipeline, + node_id: str, + user_inputs: dict, + account: Account, + datasource_type: str, + is_published: bool, + ) -> Generator[BaseDatasourceEvent, None, None]: + """ + Run published workflow datasource + """ + try: + if is_published: + # fetch published workflow by app_model + workflow = self.get_published_workflow(pipeline=pipeline) + else: + workflow = self.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + datasource_node_data = None + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if not datasource_node_data: + raise ValueError("Datasource node data not found") + + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + for key, value in datasource_parameters.items(): + if not user_inputs.get(key): + user_inputs[key] = value["value"] + + from core.datasource.datasource_manager import DatasourceManager + + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", + datasource_name=datasource_node_data.get("datasource_name"), + tenant_id=pipeline.tenant_id, + 
datasource_type=DatasourceProviderType(datasource_type), + ) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=pipeline.tenant_id, + provider=datasource_node_data.get("provider_name"), + plugin_id=datasource_node_data.get("plugin_id"), + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = ( + datasource_runtime.get_online_document_pages( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), + ) + ) + start_time = time.time() + start_event = DatasourceProcessingEvent( + total=0, + completed=0, + ) + yield start_event.model_dump() + try: + for message in online_document_result: + end_time = time.time() + online_document_event = DatasourceCompletedEvent( + data=message.result, time_consuming=round(end_time - start_time, 2) + ) + yield online_document_event.model_dump() + except Exception as e: + logger.exception("Error during online document.") + yield DatasourceErrorEvent(error=str(e)).model_dump() + case DatasourceProviderType.WEBSITE_CRAWL: + datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) + website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = ( + datasource_runtime.get_website_crawl( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), + ) + ) + start_time = time.time() + try: + for message in website_crawl_result: + end_time = time.time() + if message.result.status == "completed": + crawl_event = DatasourceCompletedEvent( + data=message.result.web_info_list, + total=message.result.total, + completed=message.result.completed, + time_consuming=round(end_time - start_time, 2), + ) + else: + crawl_event = DatasourceProcessingEvent( + total=message.result.total, + completed=message.result.completed, + ) + yield crawl_event.model_dump() + except Exception as e: + logger.exception("Error during website crawl.") + yield DatasourceErrorEvent(error=str(e)).model_dump() + case _: + raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + except Exception as e: + logger.exception("Error in run_datasource_workflow_node.") + yield DatasourceErrorEvent(error=str(e)).model_dump() + + def run_datasource_node_preview( + self, + pipeline: Pipeline, + node_id: str, + user_inputs: dict, + account: Account, + datasource_type: str, + is_published: bool, + ) -> Mapping[str, Any]: + """ + Run published workflow datasource + """ + try: + if is_published: + # fetch published workflow by app_model + workflow = self.get_published_workflow(pipeline=pipeline) + else: + workflow = self.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + datasource_node_data = None + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if not datasource_node_data: + raise ValueError("Datasource node data not found") + + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + for key, value 
in datasource_parameters.items(): + if not user_inputs.get(key): + user_inputs[key] = value["value"] + + from core.datasource.datasource_manager import DatasourceManager + + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", + datasource_name=datasource_node_data.get("datasource_name"), + tenant_id=pipeline.tenant_id, + datasource_type=DatasourceProviderType(datasource_type), + ) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=pipeline.tenant_id, + provider=datasource_node_data.get("provider_name"), + plugin_id=datasource_node_data.get("plugin_id"), + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: Generator[DatasourceMessage, None, None] = ( + datasource_runtime.get_online_document_page_content( + user_id=account.id, + datasource_parameters=GetOnlineDocumentPageContentRequest( + workspace_id=user_inputs.get("workspace_id"), + page_id=user_inputs.get("page_id"), + type=user_inputs.get("type"), + ), + provider_type=datasource_type, + ) + ) + try: + variables: dict[str, Any] = {} + for message in online_document_result: + if message.type == DatasourceMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + else: + variables[variable_name] = variable_value + return variables + except Exception as e: + logger.exception("Error during get online document content.") + raise RuntimeError(str(e)) + # TODO Online Drive + case _: + raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + except Exception as e: + logger.exception("Error in run_datasource_node_preview.") + raise RuntimeError(str(e)) + + def run_free_workflow_node( + self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] + ) -> WorkflowNodeExecution: + """ + Run draft workflow node + """ + # run draft workflow node + start_at = time.perf_counter() + + workflow_node_execution = self._handle_node_run_result( + getter=lambda: WorkflowEntry.run_free_node( + node_id=node_id, + node_data=node_data, + tenant_id=tenant_id, + user_id=user_id, + user_inputs=user_inputs, + ), + start_at=start_at, + tenant_id=tenant_id, + node_id=node_id, + ) + + return workflow_node_execution + + def _handle_node_run_result( + self, + getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]], + start_at: float, + tenant_id: str, + node_id: str, + ) -> WorkflowNodeExecution: + """ + Handle node run result + + :param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]] + :param start_at: float + :param tenant_id: str + :param node_id: str + """ + try: + node_instance, generator = getter() + + node_run_result: NodeRunResult | None = None + for event in generator: + if isinstance(event, 
RunCompletedEvent): + node_run_result = event.run_result + + # sign output files + node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + break + + if not node_run_result: + raise ValueError("Node run failed with no run result") + # single step debug mode error handling return + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: + node_error_args: dict[str, Any] = { + "status": WorkflowNodeExecutionStatus.EXCEPTION, + "error": node_run_result.error, + "inputs": node_run_result.inputs, + "metadata": {"error_strategy": node_instance.node_data.error_strategy}, + } + if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + **node_instance.node_data.default_value_dict, + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + else: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + run_succeeded = node_run_result.status in ( + WorkflowNodeExecutionStatus.SUCCEEDED, + WorkflowNodeExecutionStatus.EXCEPTION, + ) + error = node_run_result.error if not run_succeeded else None + except WorkflowNodeRunFailedError as e: + node_instance = e.node_instance + run_succeeded = False + node_run_result = None + error = e.error + + workflow_node_execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=node_instance.workflow_id, + index=1, + node_id=node_id, + node_type=node_instance.node_type, + title=node_instance.node_data.title, + elapsed_time=time.perf_counter() - start_at, + finished_at=datetime.now(UTC).replace(tzinfo=None), + created_at=datetime.now(UTC).replace(tzinfo=None), + ) + if run_succeeded and node_run_result: + # create workflow node execution + inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None + process_data = ( + WorkflowEntry.handle_special_values(node_run_result.process_data) + if node_run_result.process_data + else None + ) + outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None + + workflow_node_execution.inputs = inputs + workflow_node_execution.process_data = process_data + workflow_node_execution.outputs = outputs + workflow_node_execution.metadata = node_run_result.metadata + if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: + workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION + workflow_node_execution.error = node_run_result.error + else: + # create workflow node execution + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED + workflow_node_execution.error = error + # update document status + variable_pool = node_instance.graph_runtime_state.variable_pool + invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + if invoke_from: + if invoke_from.value == InvokeFrom.PUBLISHED.value: + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if document_id: + document = db.session.query(Document).filter(Document.id == document_id.value).first() + if document: + document.indexing_status = "error" + document.error = error + db.session.add(document) + db.session.commit() + + return workflow_node_execution + + def 
update_workflow( + self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict + ) -> Optional[Workflow]: + """ + Update workflow attributes + + :param session: SQLAlchemy database session + :param workflow_id: Workflow ID + :param tenant_id: Tenant ID + :param account_id: Account ID (for permission check) + :param data: Dictionary containing fields to update + :return: Updated workflow or None if not found + """ + stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) + workflow = session.scalar(stmt) + + if not workflow: + return None + + allowed_fields = ["marked_name", "marked_comment"] + + for field, value in data.items(): + if field in allowed_fields: + setattr(workflow, field, value) + + workflow.updated_by = account_id + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + + return workflow + + def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]: + """ + Get first step parameters of rag pipeline + """ + + workflow = ( + self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline) + ) + if not workflow: + raise ValueError("Workflow not initialized") + + datasource_node_data = None + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if not datasource_node_data: + raise ValueError("Datasource node data not found") + variables = workflow.rag_pipeline_variables + if variables: + variables_map = {item["variable"]: item for item in variables} + else: + return [] + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + + user_input_variables = [] + for key, value in datasource_parameters.items(): + if value.get("value") and isinstance(value.get("value"), str): + pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" + match = re.match(pattern, value["value"]) + if match: + full_path = match.group(1) + last_part = full_path.split(".")[-1] + user_input_variables.append(variables_map.get(last_part, {})) + return user_input_variables + + def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]: + """ + Get second step parameters of rag pipeline + """ + + workflow = ( + self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline) + ) + if not workflow: + raise ValueError("Workflow not initialized") + + # get second step node + rag_pipeline_variables = workflow.rag_pipeline_variables + if not rag_pipeline_variables: + return [] + variables_map = {item["variable"]: item for item in rag_pipeline_variables} + + # get datasource node data + datasource_node_data = None + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if datasource_node_data: + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + + for key, value in datasource_parameters.items(): + if value.get("value") and isinstance(value.get("value"), str): + pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" + match = re.match(pattern, value["value"]) + if match: + full_path = match.group(1) + last_part = full_path.split(".")[-1] + 
variables_map.pop(last_part) + all_second_step_variables = list(variables_map.values()) + datasource_provider_variables = [ + item + for item in all_second_step_variables + if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" + ] + return datasource_provider_variables + + def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination: + """ + Get debug workflow run list + Only return triggered_from == debugging + + :param app_model: app model + :param args: request args + """ + limit = int(args.get("limit", 20)) + + base_query = db.session.query(WorkflowRun).filter( + WorkflowRun.tenant_id == pipeline.tenant_id, + WorkflowRun.app_id == pipeline.id, + or_( + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value, + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value, + ), + ) + + if args.get("last_id"): + last_workflow_run = base_query.filter( + WorkflowRun.id == args.get("last_id"), + ).first() + + if not last_workflow_run: + raise ValueError("Last workflow run not exists") + + workflow_runs = ( + base_query.filter( + WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id + ) + .order_by(WorkflowRun.created_at.desc()) + .limit(limit) + .all() + ) + else: + workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() + + has_more = False + if len(workflow_runs) == limit: + current_page_first_workflow_run = workflow_runs[-1] + rest_count = base_query.filter( + WorkflowRun.created_at < current_page_first_workflow_run.created_at, + WorkflowRun.id != current_page_first_workflow_run.id, + ).count() + + if rest_count > 0: + has_more = True + + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) + + def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]: + """ + Get workflow run detail + + :param app_model: app model + :param run_id: workflow run id + """ + workflow_run = ( + db.session.query(WorkflowRun) + .filter( + WorkflowRun.tenant_id == pipeline.tenant_id, + WorkflowRun.app_id == pipeline.id, + WorkflowRun.id == run_id, + ) + .first() + ) + + return workflow_run + + def get_rag_pipeline_workflow_run_node_executions( + self, + pipeline: Pipeline, + run_id: str, + user: Account | EndUser, + ) -> list[WorkflowNodeExecutionModel]: + """ + Get workflow run node execution list + """ + workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id) + + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + if not workflow_run: + return [] + + # Use the repository to get the node execution + repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None + ) + + # Use the repository to get the node executions with ordering + order_config = OrderConfig(order_by=["index"], order_direction="desc") + node_executions = repository.get_db_models_by_workflow_run( + workflow_run_id=run_id, + order_config=order_config, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, + ) + + return list(node_executions) + + @classmethod + def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict): + """ + Publish customized pipeline template + """ + pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + if not pipeline: + raise ValueError("Pipeline not found") + if 
not pipeline.workflow_id: + raise ValueError("Pipeline workflow not found") + workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + if not workflow: + raise ValueError("Workflow not found") + dataset = pipeline.dataset + if not dataset: + raise ValueError("Dataset not found") + + # check template name is exist + template_name = args.get("name") + if template_name: + template = ( + db.session.query(PipelineCustomizedTemplate) + .filter( + PipelineCustomizedTemplate.name == template_name, + PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id, + ) + .first() + ) + if template: + raise ValueError("Template name is already exists") + + max_position = ( + db.session.query(func.max(PipelineCustomizedTemplate.position)) + .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id) + .scalar() + ) + + from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService + + dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True) + + pipeline_customized_template = PipelineCustomizedTemplate( + name=args.get("name"), + description=args.get("description"), + icon=args.get("icon_info"), + tenant_id=pipeline.tenant_id, + yaml_content=dsl, + position=max_position + 1 if max_position else 1, + chunk_structure=dataset.chunk_structure, + language="en-US", + created_by=current_user.id, + ) + db.session.add(pipeline_customized_template) + db.session.commit() diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py new file mode 100644 index 0000000000..fb311482d8 --- /dev/null +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -0,0 +1,895 @@ +import base64 +import hashlib +import json +import logging +import uuid +from collections.abc import Mapping +from datetime import UTC, datetime +from enum import StrEnum +from typing import Optional, cast +from urllib.parse import urlparse +from uuid import uuid4 + +import yaml # type: ignore +from Crypto.Cipher import AES +from Crypto.Util.Padding import pad, unpad +from flask_login import current_user +from packaging import version +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.helper import ssrf_proxy +from core.model_runtime.utils.encoders import jsonable_encoder +from core.plugin.entities.plugin import PluginDependency +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from core.workflow.nodes.llm.entities import LLMNodeData +from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData +from core.workflow.nodes.tool.entities import ToolNodeData +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from factories import variable_factory +from models import Account +from models.dataset import Dataset, DatasetCollectionBinding, Pipeline +from models.workflow import Workflow, WorkflowType +from services.entities.knowledge_entities.rag_pipeline_entities import ( + KnowledgeConfiguration, + RagPipelineDatasetCreateEntity, +) +from services.plugin.dependencies_analysis import DependenciesAnalysisService + +logger = logging.getLogger(__name__) + +IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" +CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" +IMPORT_INFO_REDIS_EXPIRY = 10 * 60 
# 10 minutes +DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB +CURRENT_DSL_VERSION = "0.1.0" + + +class ImportMode(StrEnum): + YAML_CONTENT = "yaml-content" + YAML_URL = "yaml-url" + + +class ImportStatus(StrEnum): + COMPLETED = "completed" + COMPLETED_WITH_WARNINGS = "completed-with-warnings" + PENDING = "pending" + FAILED = "failed" + + +class RagPipelineImportInfo(BaseModel): + id: str + status: ImportStatus + pipeline_id: Optional[str] = None + current_dsl_version: str = CURRENT_DSL_VERSION + imported_dsl_version: str = "" + error: str = "" + dataset_id: Optional[str] = None + + +class CheckDependenciesResult(BaseModel): + leaked_dependencies: list[PluginDependency] = Field(default_factory=list) + + +def _check_version_compatibility(imported_version: str) -> ImportStatus: + """Determine import status based on version comparison""" + try: + current_ver = version.parse(CURRENT_DSL_VERSION) + imported_ver = version.parse(imported_version) + except version.InvalidVersion: + return ImportStatus.FAILED + + # If imported version is newer than current, always return PENDING + if imported_ver > current_ver: + return ImportStatus.PENDING + + # If imported version is older than current's major, return PENDING + if imported_ver.major < current_ver.major: + return ImportStatus.PENDING + + # If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS + if imported_ver.minor < current_ver.minor: + return ImportStatus.COMPLETED_WITH_WARNINGS + + # If imported version equals or is older than current's micro, return COMPLETED + return ImportStatus.COMPLETED + + +class RagPipelinePendingData(BaseModel): + import_mode: str + yaml_content: str + pipeline_id: str | None + + +class CheckDependenciesPendingData(BaseModel): + dependencies: list[PluginDependency] + pipeline_id: str | None + + +class RagPipelineDslService: + def __init__(self, session: Session): + self._session = session + + def import_rag_pipeline( + self, + *, + account: Account, + import_mode: str, + yaml_content: Optional[str] = None, + yaml_url: Optional[str] = None, + pipeline_id: Optional[str] = None, + dataset: Optional[Dataset] = None, + ) -> RagPipelineImportInfo: + """Import an app from YAML content or URL.""" + import_id = str(uuid.uuid4()) + + # Validate import mode + try: + mode = ImportMode(import_mode) + except ValueError: + raise ValueError(f"Invalid import_mode: {import_mode}") + + # Get YAML content + content: str = "" + if mode == ImportMode.YAML_URL: + if not yaml_url: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_url is required when import_mode is yaml-url", + ) + try: + parsed_url = urlparse(yaml_url) + if ( + parsed_url.scheme == "https" + and parsed_url.netloc == "github.com" + and parsed_url.path.endswith((".yml", ".yaml")) + ): + yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com") + yaml_url = yaml_url.replace("/blob/", "/") + response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) + response.raise_for_status() + content = response.content.decode() + + if len(content) > DSL_MAX_SIZE: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="File size exceeds the limit of 10MB", + ) + + if not content: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Empty content from url", + ) + except Exception as e: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Error fetching YAML from URL: {str(e)}", + 
) + elif mode == ImportMode.YAML_CONTENT: + if not yaml_content: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_content is required when import_mode is yaml-content", + ) + content = yaml_content + + # Process YAML content + try: + # Parse YAML to validate format + data = yaml.safe_load(content) + if not isinstance(data, dict): + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid YAML format: content must be a mapping", + ) + + # Validate and fix DSL version + if not data.get("version"): + data["version"] = "0.1.0" + if not data.get("kind") or data.get("kind") != "rag_pipeline": + data["kind"] = "rag_pipeline" + + imported_version = data.get("version", "0.1.0") + # check if imported_version is a float-like string + if not isinstance(imported_version, str): + raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}") + status = _check_version_compatibility(imported_version) + + # Extract app data + pipeline_data = data.get("rag_pipeline") + if not pipeline_data: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Missing rag_pipeline data in YAML content", + ) + + # If app_id is provided, check if it exists + pipeline = None + if pipeline_id: + stmt = select(Pipeline).where( + Pipeline.id == pipeline_id, + Pipeline.tenant_id == account.current_tenant_id, + ) + pipeline = self._session.scalar(stmt) + + if not pipeline: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Pipeline not found", + ) + + # If major version mismatch, store import info in Redis + if status == ImportStatus.PENDING: + pending_data = RagPipelinePendingData( + import_mode=import_mode, + yaml_content=content, + pipeline_id=pipeline_id, + ) + redis_client.setex( + f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}", + IMPORT_INFO_REDIS_EXPIRY, + pending_data.model_dump_json(), + ) + + return RagPipelineImportInfo( + id=import_id, + status=status, + pipeline_id=pipeline_id, + imported_dsl_version=imported_version, + ) + + # Extract dependencies + dependencies = data.get("dependencies", []) + check_dependencies_pending_data = None + if dependencies: + check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies] + + # Create or update pipeline + pipeline = self._create_or_update_pipeline( + pipeline=pipeline, + data=data, + account=account, + dependencies=check_dependencies_pending_data, + ) + # create dataset + name = pipeline.name + description = pipeline.description + icon_type = data.get("rag_pipeline", {}).get("icon_type") + icon = data.get("rag_pipeline", {}).get("icon") + icon_background = data.get("rag_pipeline", {}).get("icon_background") + icon_url = data.get("rag_pipeline", {}).get("icon_url") + workflow = data.get("workflow", {}) + graph = workflow.get("graph", {}) + nodes = graph.get("nodes", []) + dataset_id = None + for node in nodes: + if node.get("data", {}).get("type") == "knowledge-index": + knowledge_configuration = KnowledgeConfiguration(**node.get("data", {})) + if ( + dataset + and pipeline.is_published + and dataset.chunk_structure != knowledge_configuration.chunk_structure + ): + raise ValueError("Chunk structure is not compatible with the published pipeline") + else: + dataset = Dataset( + tenant_id=account.current_tenant_id, + name=name, + description=description, + icon_info={ + "type": icon_type, + "icon": icon, + "background": icon_background, + "url": icon_url, + }, + 
indexing_technique=knowledge_configuration.indexing_technique, + created_by=account.id, + retrieval_model=knowledge_configuration.retrieval_model.model_dump(), + runtime_mode="rag_pipeline", + chunk_structure=knowledge_configuration.chunk_structure, + ) + if knowledge_configuration.indexing_technique == "high_quality": + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.provider_name + == knowledge_configuration.embedding_model_provider, + DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, + DatasetCollectionBinding.type == "dataset", + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) + + if not dataset_collection_binding: + dataset_collection_binding = DatasetCollectionBinding( + provider_name=knowledge_configuration.embedding_model_provider, + model_name=knowledge_configuration.embedding_model, + collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), + type="dataset", + ) + db.session.add(dataset_collection_binding) + db.session.commit() + dataset_collection_binding_id = dataset_collection_binding.id + dataset.collection_binding_id = dataset_collection_binding_id + dataset.embedding_model = knowledge_configuration.embedding_model + dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider + elif knowledge_configuration.indexing_technique == "economy": + dataset.keyword_number = knowledge_configuration.keyword_number + dataset.pipeline_id = pipeline.id + self._session.add(dataset) + self._session.commit() + dataset_id = dataset.id + if not dataset_id: + raise ValueError("DSL is not valid, please check the Knowledge Index node.") + + return RagPipelineImportInfo( + id=import_id, + status=status, + pipeline_id=pipeline.id, + dataset_id=dataset_id, + imported_dsl_version=imported_version, + ) + + except yaml.YAMLError as e: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Invalid YAML format: {str(e)}", + ) + + except Exception as e: + logger.exception("Failed to import app") + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def confirm_import(self, *, import_id: str, account: Account) -> RagPipelineImportInfo: + """ + Confirm an import that requires confirmation + """ + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + pending_data = redis_client.get(redis_key) + + if not pending_data: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Import information expired or does not exist", + ) + + try: + if not isinstance(pending_data, str | bytes): + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid import information", + ) + pending_data = RagPipelinePendingData.model_validate_json(pending_data) + data = yaml.safe_load(pending_data.yaml_content) + + pipeline = None + if pending_data.pipeline_id: + stmt = select(Pipeline).where( + Pipeline.id == pending_data.pipeline_id, + Pipeline.tenant_id == account.current_tenant_id, + ) + pipeline = self._session.scalar(stmt) + + # Create or update app + pipeline = self._create_or_update_pipeline( + pipeline=pipeline, + data=data, + account=account, + ) + + # create dataset + name = pipeline.name + description = pipeline.description + icon_type = data.get("rag_pipeline", {}).get("icon_type") + icon = data.get("rag_pipeline", {}).get("icon") + icon_background = data.get("rag_pipeline", {}).get("icon_background") + icon_url = 
data.get("rag_pipeline", {}).get("icon_url") + workflow = data.get("workflow", {}) + graph = workflow.get("graph", {}) + nodes = graph.get("nodes", []) + dataset_id = None + for node in nodes: + if node.get("data", {}).get("type") == "knowledge_index": + knowledge_configuration = KnowledgeConfiguration(**node.get("data", {})) + if not dataset: + dataset = Dataset( + tenant_id=account.current_tenant_id, + name=name, + description=description, + icon_info={ + "type": icon_type, + "icon": icon, + "background": icon_background, + "url": icon_url, + }, + indexing_technique=knowledge_configuration.indexing_technique, + created_by=account.id, + retrieval_model=knowledge_configuration.retrieval_model.model_dump(), + runtime_mode="rag_pipeline", + chunk_structure=knowledge_configuration.chunk_structure, + ) + else: + dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + dataset.runtime_mode = "rag_pipeline" + dataset.chunk_structure = knowledge_configuration.chunk_structure + if knowledge_configuration.indexing_technique == "high_quality": + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.provider_name + == knowledge_configuration.embedding_model_provider, + DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, + DatasetCollectionBinding.type == "dataset", + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) + + if not dataset_collection_binding: + dataset_collection_binding = DatasetCollectionBinding( + provider_name=knowledge_configuration.embedding_model_provider, + model_name=knowledge_configuration.embedding_model, + collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), + type="dataset", + ) + db.session.add(dataset_collection_binding) + db.session.commit() + dataset_collection_binding_id = dataset_collection_binding.id + dataset.collection_binding_id = dataset_collection_binding_id + dataset.embedding_model = knowledge_configuration.embedding_model + dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider + elif knowledge_configuration.indexing_technique == "economy": + dataset.keyword_number = knowledge_configuration.keyword_number + dataset.pipeline_id = pipeline.id + self._session.add(dataset) + self._session.commit() + dataset_id = dataset.id + if not dataset_id: + raise ValueError("DSL is not valid, please check the Knowledge Index node.") + + # Delete import info from Redis + redis_client.delete(redis_key) + + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.COMPLETED, + pipeline_id=pipeline.id, + dataset_id=dataset_id, + current_dsl_version=CURRENT_DSL_VERSION, + imported_dsl_version=data.get("version", "0.1.0"), + ) + + except Exception as e: + logger.exception("Error confirming import") + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def check_dependencies( + self, + *, + pipeline: Pipeline, + ) -> CheckDependenciesResult: + """Check dependencies""" + # Get dependencies from Redis + redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}" + dependencies = redis_client.get(redis_key) + if not dependencies: + return CheckDependenciesResult() + + # Extract dependencies + dependencies = CheckDependenciesPendingData.model_validate_json(dependencies) + + # Get leaked dependencies + leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies( + 
tenant_id=pipeline.tenant_id, dependencies=dependencies.dependencies + ) + return CheckDependenciesResult( + leaked_dependencies=leaked_dependencies, + ) + + def _create_or_update_pipeline( + self, + *, + pipeline: Optional[Pipeline], + data: dict, + account: Account, + dependencies: Optional[list[PluginDependency]] = None, + ) -> Pipeline: + """Create a new app or update an existing one.""" + pipeline_data = data.get("rag_pipeline", {}) + # Set icon type + icon_type_value = pipeline_data.get("icon_type") + if icon_type_value in ["emoji", "link"]: + icon_type = icon_type_value + else: + icon_type = "emoji" + icon = str(pipeline_data.get("icon", "")) + + # Initialize pipeline based on mode + workflow_data = data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise ValueError("Missing workflow data for rag pipeline") + + environment_variables_list = workflow_data.get("environment_variables", []) + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables", []) + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) + + graph = workflow_data.get("graph", {}) + for node in graph.get("nodes", []): + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + decrypted_id + for dataset_id in dataset_ids + if ( + decrypted_id := self.decrypt_dataset_id( + encrypted_data=dataset_id, + tenant_id=account.current_tenant_id, + ) + ) + ] + + if pipeline: + # Update existing pipeline + pipeline.name = pipeline_data.get("name", pipeline.name) + pipeline.description = pipeline_data.get("description", pipeline.description) + pipeline.updated_by = account.id + + else: + if account.current_tenant_id is None: + raise ValueError("Current tenant is not set") + + # Create new app + pipeline = Pipeline() + pipeline.id = str(uuid4()) + pipeline.tenant_id = account.current_tenant_id + pipeline.name = pipeline_data.get("name", "") + pipeline.description = pipeline_data.get("description", "") + pipeline.created_by = account.id + pipeline.updated_by = account.id + + self._session.add(pipeline) + self._session.commit() + # save dependencies + if dependencies: + redis_client.setex( + f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}", + IMPORT_INFO_REDIS_EXPIRY, + CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(), + ) + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == "draft", + ) + .first() + ) + + # create draft workflow if not found + if not workflow: + workflow = Workflow( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + features="{}", + type=WorkflowType.RAG_PIPELINE.value, + version="draft", + graph=json.dumps(graph), + created_by=account.id, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + rag_pipeline_variables=rag_pipeline_variables_list, + ) + db.session.add(workflow) + db.session.flush() + pipeline.workflow_id = workflow.id + else: + workflow.graph = json.dumps(graph) + workflow.updated_by = account.id + workflow.updated_at = 
datetime.now(UTC).replace(tzinfo=None) + workflow.environment_variables = environment_variables + workflow.conversation_variables = conversation_variables + workflow.rag_pipeline_variables = rag_pipeline_variables_list + # commit db session changes + db.session.commit() + + return pipeline + + @classmethod + def export_rag_pipeline_dsl(cls, pipeline: Pipeline, include_secret: bool = False) -> str: + """ + Export pipeline + :param pipeline: Pipeline instance + :param include_secret: Whether include secret variable + :return: + """ + dataset = pipeline.dataset + if not dataset: + raise ValueError("Missing dataset for rag pipeline") + icon_info = dataset.icon_info + export_data = { + "version": CURRENT_DSL_VERSION, + "kind": "rag_pipeline", + "rag_pipeline": { + "name": pipeline.name, + "icon": icon_info.get("icon", "📙") if icon_info else "📙", + "icon_type": icon_info.get("icon_type", "emoji") if icon_info else "emoji", + "icon_background": icon_info.get("icon_background", "#FFEAD5") if icon_info else "#FFEAD5", + "description": pipeline.description, + }, + } + + cls._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret) + + return yaml.dump(export_data, allow_unicode=True) # type: ignore + + @classmethod + def _append_workflow_export_data(cls, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None: + """ + Append workflow export data + :param export_data: export data + :param pipeline: Pipeline instance + """ + + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == "draft", + ) + .first() + ) + if not workflow: + raise ValueError("Missing draft workflow configuration, please check.") + + workflow_dict = workflow.to_dict(include_secret=include_secret) + for node in workflow_dict.get("graph", {}).get("nodes", []): + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id) + for dataset_id in dataset_ids + ] + export_data["workflow"] = workflow_dict + dependencies = cls._extract_dependencies_from_workflow(workflow) + export_data["dependencies"] = [ + jsonable_encoder(d.model_dump()) + for d in DependenciesAnalysisService.generate_dependencies( + tenant_id=pipeline.tenant_id, dependencies=dependencies + ) + ] + + @classmethod + def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]: + """ + Extract dependencies from workflow + :param workflow: Workflow instance + :return: dependencies list format like ["langgenius/google"] + """ + graph = workflow.graph_dict + dependencies = cls._extract_dependencies_from_workflow_graph(graph) + return dependencies + + @classmethod + def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]: + """ + Extract dependencies from workflow graph + :param graph: Workflow graph + :return: dependencies list format like ["langgenius/google"] + """ + dependencies = [] + for node in graph.get("nodes", []): + try: + typ = node.get("data", {}).get("type") + match typ: + case NodeType.TOOL.value: + tool_entity = ToolNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), + ) + case NodeType.LLM.value: + llm_entity = LLMNodeData(**node["data"]) + dependencies.append( + 
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), + ) + case NodeType.QUESTION_CLASSIFIER.value: + question_classifier_entity = QuestionClassifierNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + question_classifier_entity.model.provider + ), + ) + case NodeType.PARAMETER_EXTRACTOR.value: + parameter_extractor_entity = ParameterExtractorNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + parameter_extractor_entity.model.provider + ), + ) + case NodeType.KNOWLEDGE_RETRIEVAL.value: + knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"]) + if knowledge_retrieval_entity.retrieval_mode == "multiple": + if knowledge_retrieval_entity.multiple_retrieval_config: + if ( + knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode + == "reranking_model" + ): + if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model: + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider + ), + ) + elif ( + knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode + == "weighted_score" + ): + if knowledge_retrieval_entity.multiple_retrieval_config.weights: + vector_setting = ( + knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting + ) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + vector_setting.embedding_provider_name + ), + ) + elif knowledge_retrieval_entity.retrieval_mode == "single": + model_config = knowledge_retrieval_entity.single_retrieval_config + if model_config: + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + model_config.model.provider + ), + ) + case _: + # TODO: Handle default case or unknown node types + pass + except Exception as e: + logger.exception("Error extracting node dependency", exc_info=e) + + return dependencies + + @classmethod + def _extract_dependencies_from_model_config(cls, model_config: Mapping) -> list[str]: + """ + Extract dependencies from model config + :param model_config: model config dict + :return: dependencies list format like ["langgenius/google"] + """ + dependencies = [] + + try: + # completion model + model_dict = model_config.get("model", {}) + if model_dict: + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", "")) + ) + + # reranking model + dataset_configs = model_config.get("dataset_configs", {}) + if dataset_configs: + for dataset_config in dataset_configs.get("datasets", {}).get("datasets", []): + if dataset_config.get("reranking_model"): + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + dataset_config.get("reranking_model", {}) + .get("reranking_provider_name", {}) + .get("provider") + ) + ) + + # tools + agent_configs = model_config.get("agent_mode", {}) + if agent_configs: + for agent_config in agent_configs.get("tools", []): + dependencies.append( + DependenciesAnalysisService.analyze_tool_dependency(agent_config.get("provider_id")) + ) + + except Exception as e: + logger.exception("Error extracting model config dependency", exc_info=e) + + return dependencies + + @classmethod + def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]: + """ + Returns the leaked dependencies in 
current workspace + """ + dependencies = [PluginDependency(**dep) for dep in dsl_dependencies] + if not dependencies: + return [] + + return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies) + + @staticmethod + def _generate_aes_key(tenant_id: str) -> bytes: + """Generate AES key based on tenant_id""" + return hashlib.sha256(tenant_id.encode()).digest() + + @classmethod + def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str: + """Encrypt dataset_id using AES-CBC mode""" + key = cls._generate_aes_key(tenant_id) + iv = key[:16] + cipher = AES.new(key, AES.MODE_CBC, iv) + ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size)) + return base64.b64encode(ct_bytes).decode() + + @classmethod + def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None: + """AES decryption""" + try: + key = cls._generate_aes_key(tenant_id) + iv = key[:16] + cipher = AES.new(key, AES.MODE_CBC, iv) + pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size) + return pt.decode() + except Exception: + return None + + @staticmethod + def create_rag_pipeline_dataset( + tenant_id: str, + rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, + ): + # check if dataset name already exists + if ( + db.session.query(Dataset) + .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) + .first() + ): + raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.") + + with Session(db.engine) as session: + rag_pipeline_dsl_service = RagPipelineDslService(session) + account = cast(Account, current_user) + rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline( + account=account, + import_mode=ImportMode.YAML_CONTENT.value, + yaml_content=rag_pipeline_dataset_create_entity.yaml_content, + dataset=None, + ) + return { + "id": rag_pipeline_import_info.id, + "dataset_id": rag_pipeline_import_info.dataset_id, + "pipeline_id": rag_pipeline_import_info.pipeline_id, + "status": rag_pipeline_import_info.status, + "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version, + "current_dsl_version": rag_pipeline_import_info.current_dsl_version, + "error": rag_pipeline_import_info.error, + } diff --git a/api/services/rag_pipeline/rag_pipeline_manage_service.py b/api/services/rag_pipeline/rag_pipeline_manage_service.py new file mode 100644 index 0000000000..0908d30c12 --- /dev/null +++ b/api/services/rag_pipeline/rag_pipeline_manage_service.py @@ -0,0 +1,23 @@ +from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity +from core.plugin.impl.datasource import PluginDatasourceManager +from services.datasource_provider_service import DatasourceProviderService + + +class RagPipelineManageService: + @staticmethod + def list_rag_pipeline_datasources(tenant_id: str) -> list[PluginDatasourceProviderEntity]: + """ + list rag pipeline datasources + """ + + # get all builtin providers + manager = PluginDatasourceManager() + datasources = manager.fetch_datasource_providers(tenant_id) + for datasource in datasources: + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_datasource_credentials( + tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id + ) + if credentials: + datasource.is_authorized = True + return datasources diff --git a/api/services/tools/tools_transform_service.py 
b/api/services/tools/tools_transform_service.py index 367121125b..282728153a 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,6 +5,7 @@ from typing import Optional, Union, cast from yarl import URL from configs import dify_config +from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -56,7 +57,7 @@ class ToolTransformService: return "" @staticmethod - def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]): + def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]): """ repack provider @@ -77,6 +78,18 @@ class ToolTransformService: provider.icon = ToolTransformService.get_tool_provider_icon_url( provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon ) + elif isinstance(provider, PluginDatasourceProviderEntity): + if provider.plugin_id: + if isinstance(provider.declaration.identity.icon, str): + provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url( + tenant_id=tenant_id, filename=provider.declaration.identity.icon + ) + else: + provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url( + provider_type=provider.type.value, + provider_name=provider.name, + icon=provider.declaration.identity.icon, + ) @classmethod def builtin_provider_to_user_provider( diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2be57fd51c..39bb836172 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -257,7 +257,6 @@ class WorkflowService: type=draft_workflow.type, version=Workflow.version_from_datetime(datetime.now(UTC).replace(tzinfo=None)), graph=draft_workflow.graph, - features=draft_workflow.features, created_by=account.id, environment_variables=draft_workflow.environment_variables, conversation_variables=draft_workflow.conversation_variables, diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py new file mode 100644 index 0000000000..dc266aef65 --- /dev/null +++ b/api/tasks/deal_dataset_index_update_task.py @@ -0,0 +1,171 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument + + +@shared_task(queue="dataset") +def deal_dataset_index_update_task(dataset_id: str, action: str): + """ + Async deal dataset from index + :param dataset_id: dataset_id + :param action: action + Usage: deal_dataset_index_update_task.delay(dataset_id, action) + """ + logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green")) + start_at = time.perf_counter() + + try: + dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() + + if not dataset: + raise Exception("Dataset not found") + index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX + index_processor = IndexProcessorFactory(index_type).init_index_processor() + if action == "upgrade": + 
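+            # "upgrade": re-embed every completed, enabled, non-archived document so the stored vectors match the new index settings (keywords are cleaned, then fresh vectors are written per document).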
dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) + + if dataset_documents: + dataset_documents_ids = [doc.id for doc in dataset_documents] + db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) + db.session.commit() + + for dataset_document in dataset_documents: + try: + # add from vector index + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) + if segments: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + + documents.append(document) + # save vector index + # clean keywords + index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False) + index_processor.load(dataset, documents, with_keywords=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) + db.session.commit() + except Exception as e: + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) + db.session.commit() + elif action == "update": + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) + # add new index + if dataset_documents: + # update document status + dataset_documents_ids = [doc.id for doc in dataset_documents] + db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) + db.session.commit() + + # clean index + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) + + for dataset_document in dataset_documents: + # update from vector index + try: + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) + if segments: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.get_child_chunks() + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + documents.append(document) + # save vector index 
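+                        # The old vectors were removed by the single clean() call above; each segment (plus any child chunks for parent-child indexes) is re-embedded into the refreshed collection below.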
+ index_processor.load(dataset, documents, with_keywords=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) + db.session.commit() + except Exception as e: + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) + db.session.commit() + else: + # clean collection + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) + + end_at = time.perf_counter() + logging.info( + click.style("Deal dataset index update: {} latency: {}".format(dataset_id, end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Deal dataset index update failed") + finally: + db.session.close() diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 0b1885755b..4761e73178 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -71,7 +71,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.1.3-local + image: langgenius/dify-plugin-daemon:deploy-dev-local restart: always env_file: - ./middleware.env @@ -94,7 +94,6 @@ services: PLUGIN_REMOTE_INSTALLING_HOST: ${PLUGIN_DEBUGGING_HOST:-0.0.0.0} PLUGIN_REMOTE_INSTALLING_PORT: ${PLUGIN_DEBUGGING_PORT:-5003} PLUGIN_WORKING_PATH: ${PLUGIN_WORKING_PATH:-/app/storage/cwd} - FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true} PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} @@ -126,6 +125,9 @@ services: VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-} VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} + THIRD_PARTY_SIGNATURE_VERIFICATION_ENABLED: "true" + THIRD_PARTY_SIGNATURE_VERIFICATION_PUBLIC_KEYS: /app/keys/publickey.pem + FORCE_VERIFYING_SIGNATURE: "false" ports: - "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}" - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index a0848d98fa..8631eb58e3 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -83,7 +83,6 @@ import Confirm from '@/app/components/base/confirm' import DatasetsDetailProvider from './datasets-detail-store/provider' import { HooksStoreContextProvider } from './hooks-store' import type { Shape as HooksStoreShape } from './hooks-store' -import PluginDependency from './plugin-dependency' const nodeTypes = { [CUSTOM_NODE]: CustomNode, @@ -182,6 +181,7 @@ export const Workflow: FC = memo(({ setAutoFreeze(true) } }, []) + useEffect(() => { return () => { handleSyncWorkflowDraft(true, true) @@ -323,7 +323,6 @@ export const Workflow: FC = memo(({ ) } - {children}