From 9987774471132237fe723ef179f6d6d2acbd7815 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Thu, 10 Apr 2025 18:00:22 +0800
Subject: [PATCH 001/155] r2

---
 api/controllers/console/datasets/pipeline.py | 42 ++++++++++++++++++++
 api/models/dataset.py                        | 39 ++++++++++++++++++
 api/services/rag_pipeline/rag_pipeline.py    | 20 ++++++++++
 3 files changed, 101 insertions(+)
 create mode 100644 api/controllers/console/datasets/pipeline.py
 create mode 100644 api/services/rag_pipeline/rag_pipeline.py

diff --git a/api/controllers/console/datasets/pipeline.py b/api/controllers/console/datasets/pipeline.py
new file mode 100644
index 0000000000..20a3df8a1b
--- /dev/null
+++ b/api/controllers/console/datasets/pipeline.py
@@ -0,0 +1,42 @@
+from flask import request
+from flask_login import current_user  # type: ignore # type: ignore
+from flask_restful import Resource, marshal  # type: ignore
+
+from controllers.console import api
+from controllers.console.wraps import (
+    account_initialization_required,
+    enterprise_license_required,
+    setup_required,
+)
+from core.model_runtime.entities.model_entities import ModelType
+from core.plugin.entities.plugin import ModelProviderID
+from core.provider_manager import ProviderManager
+from fields.dataset_fields import dataset_detail_fields
+from libs.login import login_required
+from services.dataset_service import DatasetPermissionService, DatasetService
+
+
+def _validate_name(name):
+    if not name or len(name) < 1 or len(name) > 40:
+        raise ValueError("Name must be between 1 to 40 characters.")
+    return name
+
+
+def _validate_description_length(description):
+    if len(description) > 400:
+        raise ValueError("Description cannot exceed 400 characters.")
+    return description
+
+
+class PipelineTemplateListApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    def get(self):
+        type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
+        # get pipeline templates
+        return response, 200
+
+
+api.add_resource(PipelineTemplateListApi, "/rag/pipeline/templates")
diff --git a/api/models/dataset.py b/api/models/dataset.py
index d6708ac88b..1e274a31f8 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -1138,3 +1138,42 @@ class DatasetMetadataBinding(db.Model):  # type: ignore[name-defined]
     document_id = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     created_by = db.Column(StringUUID, nullable=False)
+
+class PipelineBuiltInTemplate(db.Model):  # type: ignore[name-defined]
+    __tablename__ = "pipeline_built_in_templates"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    app_id = db.Column(StringUUID, nullable=False)
+    name = db.Column(db.String(255), nullable=False)
+    description = db.Column(db.Text, nullable=False)
+    icon = db.Column(db.JSON, nullable=False)
+    copyright = db.Column(db.String(255), nullable=False)
+    privacy_policy = db.Column(db.String(255), nullable=False)
+    position = db.Column(db.Integer, nullable=False)
+    install_count = db.Column(db.Integer, nullable=False, default=0)
+    language = db.Column(db.String(255), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+
+
+class PipelineCustomizedTemplate(db.Model):  # type: ignore[name-defined]
+    __tablename__ = "pipeline_customized_templates"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
+        db.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    app_id = db.Column(StringUUID, nullable=False)
+    name = db.Column(db.String(255), nullable=False)
+    description = db.Column(db.Text, nullable=False)
+    icon = db.Column(db.JSON, nullable=False)
+    position = db.Column(db.Integer, nullable=False)
+    install_count = db.Column(db.Integer, nullable=False, default=0)
+    language = db.Column(db.String(255), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
new file mode 100644
index 0000000000..c6d1769679
--- /dev/null
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -0,0 +1,20 @@
+import datetime
+import hashlib
+import os
+import uuid
+from typing import Any, List, Literal, Union
+
+from flask_login import current_user
+
+from models.dataset import PipelineBuiltInTemplate, PipelineCustomizedTemplate  # type: ignore
+
+
+class RagPipelineService:
+    @staticmethod
+    def get_pipeline_templates(
+        type: Literal["built-in", "customized"] = "built-in",
+    ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
+        if type == "built-in":
+            return PipelineBuiltInTemplate.query.all()
+        else:
+            return PipelineCustomizedTemplate.query.all()

From 3340775052e930777718f38284b6bde9adc4da27 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Mon, 14 Apr 2025 11:10:44 +0800
Subject: [PATCH 002/155] r2

---
 api/controllers/console/datasets/pipeline.py  |   1 +
 api/models/dataset.py                         |  24 +++-
 .../pipeline_template/__init__.py             |   0
 .../pipeline_template/customized/__init__.py  |   0
 .../customized/customized_retrieval.py        | 105 ++++++++++++++++++
 .../pipeline_template/database/__init__.py    |   0
 .../database/database_retrieval.py            | 105 ++++++++++++++++++
 .../pipeline_template_base.py                 |  17 +++
 .../pipeline_template_factory.py              |  21 ++++
 .../pipeline_template_type.py                 |   7 ++
 .../pipeline_template/remote/__init__.py      |   0
 .../remote/remote_retrieval.py                |  70 ++++++++++++
 api/services/rag_pipeline/rag_pipeline.py     |  34 +++++-
 13 files changed, 381 insertions(+), 3 deletions(-)
 create mode 100644 api/services/rag_pipeline/pipeline_template/__init__.py
 create mode 100644 api/services/rag_pipeline/pipeline_template/customized/__init__.py
 create mode 100644 api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
 create mode 100644 api/services/rag_pipeline/pipeline_template/database/__init__.py
 create mode 100644 api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
 create mode 100644 api/services/rag_pipeline/pipeline_template/pipeline_template_base.py
 create mode 100644 api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py
 create mode 100644 api/services/rag_pipeline/pipeline_template/pipeline_template_type.py
 create mode 100644 api/services/rag_pipeline/pipeline_template/remote/__init__.py
 create mode 100644 api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py

diff --git a/api/controllers/console/datasets/pipeline.py b/api/controllers/console/datasets/pipeline.py
index 20a3df8a1b..5826fd8ab6 100644
--- a/api/controllers/console/datasets/pipeline.py
+++ b/api/controllers/console/datasets/pipeline.py
@@ -35,6 +35,7 @@ class PipelineTemplateListApi(Resource):
     @enterprise_license_required
     def get(self):
         type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
+        language = request.args.get("language", default="en-US", type=str)
         # get pipeline templates
         return response, 200
 
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 1e274a31f8..fb4908f936 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -1146,7 +1146,7 @@ class PipelineBuiltInTemplate(db.Model):  # type: ignore[name-defined]
     )
 
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    app_id = db.Column(StringUUID, nullable=False)
+    pipeline_id = db.Column(StringUUID, nullable=False)
     name = db.Column(db.String(255), nullable=False)
     description = db.Column(db.Text, nullable=False)
     icon = db.Column(db.JSON, nullable=False)
@@ -1168,7 +1168,7 @@ class PipelineCustomizedTemplate(db.Model):  # type: ignore[name-defined]
 
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     tenant_id = db.Column(StringUUID, nullable=False)
-    app_id = db.Column(StringUUID, nullable=False)
+    pipeline_id = db.Column(StringUUID, nullable=False)
     name = db.Column(db.String(255), nullable=False)
     description = db.Column(db.Text, nullable=False)
     icon = db.Column(db.JSON, nullable=False)
@@ -1177,3 +1177,23 @@ class PipelineCustomizedTemplate(db.Model):  # type: ignore[name-defined]
     language = db.Column(db.String(255), nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+
+
+class Pipeline(db.Model):  # type: ignore[name-defined]
+    __tablename__ = "pipelines"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="pipeline_pkey"),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
+    name = db.Column(db.String(255), nullable=False)
+    description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
+    mode = db.Column(db.String(255), nullable=False)
+    workflow_id = db.Column(StringUUID, nullable=True)
+    is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
+    is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
+    created_by = db.Column(StringUUID, nullable=True)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_by = db.Column(StringUUID, nullable=True)
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/services/rag_pipeline/pipeline_template/__init__.py b/api/services/rag_pipeline/pipeline_template/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/services/rag_pipeline/pipeline_template/customized/__init__.py b/api/services/rag_pipeline/pipeline_template/customized/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
new file mode 100644
index 0000000000..fccf09ef0a
--- /dev/null
+++ 
b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -0,0 +1,105 @@ +from typing import Optional + +from constants.languages import languages +from extensions.ext_database import db +from models.model import App, RecommendedApp +from services.app_dsl_service import AppDslService +from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase +from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType + + +class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): + """ + Retrieval recommended app from database + """ + + def get_pipeline_templates(self, language: str) -> dict: + result = self.fetch_pipeline_templates_from_db(language) + return result + + def get_pipeline_template_detail(self, pipeline_id: str): + result = self.fetch_pipeline_template_detail_from_db(pipeline_id) + return result + + def get_type(self) -> str: + return PipelineTemplateType.CUSTOMIZED + + @classmethod + def fetch_recommended_apps_from_db(cls, language: str) -> dict: + """ + Fetch recommended apps from db. + :param language: language + :return: + """ + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) + .all() + ) + + if len(recommended_apps) == 0: + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) + .all() + ) + + categories = set() + recommended_apps_result = [] + for recommended_app in recommended_apps: + app = recommended_app.app + if not app or not app.is_public: + continue + + site = app.site + if not site: + continue + + recommended_app_result = { + "id": recommended_app.id, + "app": recommended_app.app, + "app_id": recommended_app.app_id, + "description": site.description, + "copyright": site.copyright, + "privacy_policy": site.privacy_policy, + "custom_disclaimer": site.custom_disclaimer, + "category": recommended_app.category, + "position": recommended_app.position, + "is_listed": recommended_app.is_listed, + } + recommended_apps_result.append(recommended_app_result) + + categories.add(recommended_app.category) + + return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} + + @classmethod + def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]: + """ + Fetch recommended app detail from db. 
+ :param app_id: App ID + :return: + """ + # is in public recommended list + recommended_app = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) + .first() + ) + + if not recommended_app: + return None + + # get app detail + app_model = db.session.query(App).filter(App.id == app_id).first() + if not app_model or not app_model.is_public: + return None + + return { + "id": app_model.id, + "name": app_model.name, + "icon": app_model.icon, + "icon_background": app_model.icon_background, + "mode": app_model.mode, + "export_data": AppDslService.export_dsl(app_model=app_model), + } diff --git a/api/services/rag_pipeline/pipeline_template/database/__init__.py b/api/services/rag_pipeline/pipeline_template/database/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py new file mode 100644 index 0000000000..3158e9f91c --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -0,0 +1,105 @@ +from typing import Optional + +from constants.languages import languages +from extensions.ext_database import db +from models.model import App, RecommendedApp +from services.app_dsl_service import AppDslService +from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase +from services.recommend_app.recommend_app_type import RecommendAppType + + +class DatabasePipelineTemplateRetrieval(RecommendAppRetrievalBase): + """ + Retrieval recommended app from database + """ + + def get_recommended_apps_and_categories(self, language: str) -> dict: + result = self.fetch_recommended_apps_from_db(language) + return result + + def get_recommend_app_detail(self, app_id: str): + result = self.fetch_recommended_app_detail_from_db(app_id) + return result + + def get_type(self) -> str: + return RecommendAppType.DATABASE + + @classmethod + def fetch_recommended_apps_from_db(cls, language: str) -> dict: + """ + Fetch recommended apps from db. + :param language: language + :return: + """ + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) + .all() + ) + + if len(recommended_apps) == 0: + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) + .all() + ) + + categories = set() + recommended_apps_result = [] + for recommended_app in recommended_apps: + app = recommended_app.app + if not app or not app.is_public: + continue + + site = app.site + if not site: + continue + + recommended_app_result = { + "id": recommended_app.id, + "app": recommended_app.app, + "app_id": recommended_app.app_id, + "description": site.description, + "copyright": site.copyright, + "privacy_policy": site.privacy_policy, + "custom_disclaimer": site.custom_disclaimer, + "category": recommended_app.category, + "position": recommended_app.position, + "is_listed": recommended_app.is_listed, + } + recommended_apps_result.append(recommended_app_result) + + categories.add(recommended_app.category) + + return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} + + @classmethod + def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]: + """ + Fetch recommended app detail from db. 
+ :param app_id: App ID + :return: + """ + # is in public recommended list + recommended_app = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) + .first() + ) + + if not recommended_app: + return None + + # get app detail + app_model = db.session.query(App).filter(App.id == app_id).first() + if not app_model or not app_model.is_public: + return None + + return { + "id": app_model.id, + "name": app_model.name, + "icon": app_model.icon, + "icon_background": app_model.icon_background, + "mode": app_model.mode, + "export_data": AppDslService.export_dsl(app_model=app_model), + } diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py new file mode 100644 index 0000000000..45860f9d74 --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod + + +class PipelineTemplateRetrievalBase(ABC): + """Interface for pipeline template retrieval.""" + + @abstractmethod + def get_pipeline_templates(self, language: str) -> dict: + raise NotImplementedError + + @abstractmethod + def get_pipeline_template_detail(self, pipeline_id: str): + raise NotImplementedError + + @abstractmethod + def get_type(self) -> str: + raise NotImplementedError diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py new file mode 100644 index 0000000000..0348387119 --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py @@ -0,0 +1,21 @@ +from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase +from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType + + + +class RecommendAppRetrievalFactory: + @staticmethod + def get_pipeline_template_factory(mode: str) -> type[PipelineTemplateRetrievalBase]: + match mode: + case PipelineTemplateType.REMOTE: + return RemotePipelineTemplateRetrieval + case PipelineTemplateType.CUSTOMIZED: + return DatabasePipelineTemplateRetrieval + case PipelineTemplateType.BUILTIN: + return BuildInPipelineTemplateRetrieval + case _: + raise ValueError(f"invalid fetch recommended apps mode: {mode}") + + @staticmethod + def get_buildin_recommend_app_retrieval(): + return BuildInRecommendAppRetrieval diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_type.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_type.py new file mode 100644 index 0000000000..98bc109a67 --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_type.py @@ -0,0 +1,7 @@ +from enum import StrEnum + + +class PipelineTemplateType(StrEnum): + REMOTE = "remote" + BUILTIN = "builtin" + CUSTOMIZED = "customized" diff --git a/api/services/rag_pipeline/pipeline_template/remote/__init__.py b/api/services/rag_pipeline/pipeline_template/remote/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py new file mode 100644 index 0000000000..3ba6b0e64c --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -0,0 +1,70 @@ +import logging +from typing import Optional + +import requests + +from configs import dify_config +from 
services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval +from services.recommend_app.recommend_app_type import RecommendAppType +from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase +logger = logging.getLogger(__name__) + + +class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): + """ + Retrieval recommended app from dify official + """ + + def get_pipeline_template_detail(self, pipeline_id: str): + try: + result = self.fetch_pipeline_template_detail_from_dify_official(pipeline_id) + except Exception as e: + logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.") + result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(pipeline_id) + return result + + def get_recommended_apps_and_categories(self, language: str) -> dict: + try: + result = self.fetch_recommended_apps_from_dify_official(language) + except Exception as e: + logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.") + result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language) + return result + + def get_type(self) -> str: + return RecommendAppType.REMOTE + + @classmethod + def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]: + """ + Fetch recommended app detail from dify official. + :param app_id: App ID + :return: + """ + domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN + url = f"{domain}/apps/{app_id}" + response = requests.get(url, timeout=(3, 10)) + if response.status_code != 200: + return None + data: dict = response.json() + return data + + @classmethod + def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: + """ + Fetch recommended apps from dify official. + :param language: language + :return: + """ + domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN + url = f"{domain}/apps?language={language}" + response = requests.get(url, timeout=(3, 10)) + if response.status_code != 200: + raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") + + result: dict = response.json() + + if "categories" in result: + result["categories"] = sorted(result["categories"]) + + return result diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index c6d1769679..c215606817 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -7,7 +7,7 @@ from typing import Any, List, Literal, Union from flask_login import current_user from models.dataset import PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore - +from configs import dify_config class RagPipelineService: @staticmethod @@ -18,3 +18,35 @@ class RagPipelineService: return PipelineBuiltInTemplate.query.all() else: return PipelineCustomizedTemplate.query.all() + + @staticmethod + def get_pipeline_templates(cls, type: Literal["built-in", "customized"] = "built-in", language: str) -> dict: + """ + Get pipeline templates. 
+ :param type: type + :param language: language + :return: + """ + mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE + retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() + result = retrieval_instance.get_recommended_apps_and_categories(language) + if not result.get("recommended_apps") and language != "en-US": + result = ( + RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval().fetch_recommended_apps_from_builtin( + "en-US" + ) + ) + + return result + + @classmethod + def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: + """ + Get recommend app detail. + :param app_id: app id + :return: + """ + mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE + retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() + result: dict = retrieval_instance.get_recommend_app_detail(app_id) + return result From 9f8e05d9f09937846b59d3a3a8832930684ffac3 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 14 Apr 2025 18:17:17 +0800 Subject: [PATCH 003/155] r2 --- .../feature/hosted_service/__init__.py | 17 + api/controllers/console/datasets/error.py | 6 + api/controllers/console/datasets/pipeline.py | 744 +++++++++++++++++- api/controllers/console/datasets/wraps.py | 41 + api/models/dataset.py | 9 +- api/models/workflow.py | 3 +- .../rag_pipeline_entities.py | 16 + .../pipeline_template/built_in/__init__.py | 0 .../built_in/built_in_retrieval.py | 64 ++ .../customized/customized_retrieval.py | 92 +-- .../database/database_retrieval.py | 88 +-- .../pipeline_template_base.py | 3 +- .../pipeline_template_factory.py | 16 +- .../pipeline_template_type.py | 3 +- .../remote/remote_retrieval.py | 36 +- api/services/rag_pipeline/rag_pipeline.py | 589 +++++++++++++- 16 files changed, 1522 insertions(+), 205 deletions(-) create mode 100644 api/controllers/console/datasets/wraps.py create mode 100644 api/services/entities/knowledge_entities/rag_pipeline_entities.py create mode 100644 api/services/rag_pipeline/pipeline_template/built_in/__init__.py create mode 100644 api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 18ef1ed45b..7633ffcf8a 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -222,11 +222,28 @@ class HostedFetchAppTemplateConfig(BaseSettings): ) +class HostedFetchPipelineTemplateConfig(BaseSettings): + """ + Configuration for fetching pipeline templates + """ + + HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field( + description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,", + default="remote", + ) + + HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field( + description="Domain for fetching remote pipeline templates", + default="https://tmpl.dify.ai", + ) + + class HostedServiceConfig( # place the configs in alphabet order HostedAnthropicConfig, HostedAzureOpenAiConfig, HostedFetchAppTemplateConfig, + HostedFetchPipelineTemplateConfig, HostedMinmaxConfig, HostedOpenAiConfig, HostedSparkConfig, diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 2f00a84de6..9ee9cbd52a 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -101,3 +101,9 @@ class ChildChunkDeleteIndexError(BaseHTTPException): error_code = "child_chunk_delete_index_error" description = "Delete child chunk 
index failed: {message}" code = 500 + + +class PipelineNotFoundError(BaseHTTPException): + error_code = "pipeline_not_found" + description = "Pipeline not found." + code = 404 \ No newline at end of file diff --git a/api/controllers/console/datasets/pipeline.py b/api/controllers/console/datasets/pipeline.py index 5826fd8ab6..72e819fa12 100644 --- a/api/controllers/console/datasets/pipeline.py +++ b/api/controllers/console/datasets/pipeline.py @@ -1,19 +1,49 @@ -from flask import request -from flask_login import current_user # type: ignore # type: ignore -from flask_restful import Resource, marshal # type: ignore +import json +import logging +from typing import cast +from flask import abort, request +from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import services +from configs import dify_config from controllers.console import api +from controllers.console.app.error import ( + ConversationCompletedError, + DraftWorkflowNotExist, + DraftWorkflowNotSync, +) +from controllers.console.app.wraps import get_app_model +from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( account_initialization_required, enterprise_license_required, setup_required, ) -from core.model_runtime.entities.model_entities import ModelType -from core.plugin.entities.plugin import ModelProviderID -from core.provider_manager import ProviderManager -from fields.dataset_fields import dataset_detail_fields -from libs.login import login_required -from services.dataset_service import DatasetPermissionService, DatasetService +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from factories import variable_factory +from fields.workflow_fields import workflow_fields, workflow_pagination_fields +from fields.workflow_run_fields import workflow_run_node_execution_fields +from libs import helper +from libs.helper import TimestampField +from libs.login import current_user, login_required +from models import App +from models.account import Account +from models.dataset import Pipeline +from models.model import AppMode +from services.app_generate_service import AppGenerateService +from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity +from services.errors.app import WorkflowHashNotEqualError +from services.errors.llm import InvokeRateLimitError +from services.rag_pipeline.rag_pipeline import RagPipelineService +from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService + +logger = logging.getLogger(__name__) def _validate_name(name): @@ -37,7 +67,699 @@ class PipelineTemplateListApi(Resource): type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"]) language = request.args.get("language", default="en-US", type=str) # get pipeline templates - return response, 200 + pipeline_templates = RagPipelineService.get_pipeline_templates(type, language) + return pipeline_templates, 200 -api.add_resource(PipelineTemplateListApi, "/rag/pipeline/templates") +class PipelineTemplateDetailApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def get(self, 
pipeline_id: str): + pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id) + return pipeline_template, 200 + + +class CustomizedPipelineTemplateApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def patch(self, template_id: str): + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) + parser.add_argument( + "icon_info", + type=dict, + location="json", + nullable=True, + ) + args = parser.parse_args() + pipeline_template_info = PipelineTemplateInfoEntity(**args) + pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) + return pipeline_template, 200 + + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def delete(self, template_id: str): + RagPipelineService.delete_customized_pipeline_template(template_id) + return 200 + + +class DraftRagPipelineApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_fields) + def get(self, pipeline: Pipeline): + """ + Get draft workflow + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + # fetch draft workflow by app_model + rag_pipeline_service = RagPipelineService() + workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + + if not workflow: + raise DraftWorkflowNotExist() + + # return workflow, if not found, return None (initiate graph by frontend) + return workflow + + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline): + """ + Sync draft workflow + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + content_type = request.headers.get("Content-Type", "") + + if "application/json" in content_type: + parser = reqparse.RequestParser() + parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") + parser.add_argument("features", type=dict, required=True, nullable=False, location="json") + parser.add_argument("hash", type=str, required=False, location="json") + parser.add_argument("environment_variables", type=list, required=False, location="json") + parser.add_argument("conversation_variables", type=list, required=False, location="json") + args = parser.parse_args() + elif "text/plain" in content_type: + try: + data = json.loads(request.data.decode("utf-8")) + if "graph" not in data or "features" not in data: + raise ValueError("graph or features not found in data") + + if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict): + raise ValueError("graph or features is not a dict") + + args = { + "graph": data.get("graph"), + "features": data.get("features"), + "hash": data.get("hash"), + "environment_variables": data.get("environment_variables"), + "conversation_variables": data.get("conversation_variables"), + } + except json.JSONDecodeError: + return {"message": "Invalid JSON data"}, 400 + else: + abort(415) + + if not isinstance(current_user, Account): + raise Forbidden() + + workflow_service = WorkflowService() + + try: + 
environment_variables_list = args.get("environment_variables") or [] + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = args.get("conversation_variables") or [] + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + rag_pipeline_service = RagPipelineService() + workflow = rag_pipeline_service.sync_draft_workflow( + pipeline=pipeline, + graph=args["graph"], + features=args["features"], + unique_hash=args.get("hash"), + account=current_user, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + except WorkflowHashNotEqualError: + raise DraftWorkflowNotSync() + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + +class RagPipelineDraftRunIterationNodeApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run draft workflow iteration node + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, location="json") + args = parser.parse_args() + + try: + response = AppGenerateService.generate_single_iteration( + pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True + ) + + return helper.compact_generate_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except ValueError as e: + raise e + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + +class RagPipelineDraftRunLoopNodeApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + def post(self, app_model: App, node_id: str): + """ + Run draft workflow loop node + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, location="json") + args = parser.parse_args() + + try: + response = AppGenerateService.generate_single_loop( + app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True + ) + + return helper.compact_generate_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except ValueError as e: + raise e + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + +class DraftRagPipelineRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + def post(self, app_model: App): + """ + Run draft workflow + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not 
current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") + args = parser.parse_args() + + try: + response = AppGenerateService.generate( + app_model=app_model, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER, + streaming=True, + ) + + return helper.compact_generate_response(response) + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) + + +class RagPipelineTaskStopApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App, task_id: str): + """ + Stop workflow task + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) + + return {"result": "success"} + + +class RagPipelineNodeRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_node_execution_fields) + def post(self, app_model: App, node_id: str): + """ + Run draft workflow node + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + if inputs == None: + raise ValueError("missing inputs") + + workflow_service = WorkflowService() + workflow_node_execution = workflow_service.run_draft_workflow_node( + app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user + ) + + return workflow_node_execution + + +class PublishedRagPipelineApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_fields) + def get(self, pipeline: Pipeline): + """ + Get published pipeline + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + # fetch published workflow by pipeline + rag_pipeline_service = RagPipelineService() + workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline) + + # return workflow, if not found, return None + return workflow + + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline): + """ + Publish workflow + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("marked_name", type=str, required=False, default="", location="json") + parser.add_argument("marked_comment", type=str, required=False, default="", location="json") + args = parser.parse_args() + + # Validate name and comment length + if args.marked_name and len(args.marked_name) > 20: + raise ValueError("Marked name cannot exceed 20 characters") + if 
args.marked_comment and len(args.marked_comment) > 100: + raise ValueError("Marked comment cannot exceed 100 characters") + + rag_pipeline_service = RagPipelineService() + with Session(db.engine) as session: + workflow = rag_pipeline_service.publish_workflow( + session=session, + pipeline=pipeline, + account=current_user, + marked_name=args.marked_name or "", + marked_comment=args.marked_comment or "", + ) + + pipeline.workflow_id = workflow.id + db.session.commit() + + workflow_created_at = TimestampField().format(workflow.created_at) + + session.commit() + + return { + "result": "success", + "created_at": workflow_created_at, + } + + +class DefaultRagPipelineBlockConfigsApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline): + """ + Get default block config + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + # Get default block configs + rag_pipeline_service = RagPipelineService() + return rag_pipeline_service.get_default_block_configs() + + +class DefaultRagPipelineBlockConfigApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline, block_type: str): + """ + Get default block config + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("q", type=str, location="args") + args = parser.parse_args() + + q = args.get("q") + + filters = None + if q: + try: + filters = json.loads(args.get("q", "")) + except json.JSONDecodeError: + raise ValueError("Invalid filters") + + # Get default block configs + rag_pipeline_service = RagPipelineService() + return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) + + +class ConvertToRagPipelineApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline): + """ + Convert basic mode of chatbot app to workflow mode + Convert expert mode of chatbot app to workflow mode + Convert Completion App to Workflow App + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + if request.data: + parser = reqparse.RequestParser() + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") + args = parser.parse_args() + else: + args = {} + + # convert to workflow mode + rag_pipeline_service = RagPipelineService() + new_app_model = rag_pipeline_service.convert_to_workflow(pipeline=pipeline, account=current_user, args=args) + + # return app id + return { + "new_app_id": new_app_model.id, + } + + +class RagPipelineConfigApi(Resource): + """Resource for rag pipeline configuration.""" + + @setup_required + @login_required + @account_initialization_required + def get(self): + return { + "parallel_depth_limit": 
dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, + } + + +class PublishedAllRagPipelineApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_pagination_fields) + def get(self, pipeline: Pipeline): + """ + Get published workflows + """ + if not current_user.is_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + parser.add_argument("user_id", type=str, required=False, location="args") + parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args") + args = parser.parse_args() + page = int(args.get("page", 1)) + limit = int(args.get("limit", 10)) + user_id = args.get("user_id") + named_only = args.get("named_only", False) + + if user_id: + if user_id != current_user.id: + raise Forbidden() + user_id = cast(str, user_id) + + rag_pipeline_service = RagPipelineService() + with Session(db.engine) as session: + workflows, has_more = rag_pipeline_service.get_all_published_workflow( + session=session, + pipeline=pipeline, + page=page, + limit=limit, + user_id=user_id, + named_only=named_only, + ) + + return { + "items": workflows, + "page": page, + "limit": limit, + "has_more": has_more, + } + + +class RagPipelineByIdApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_fields) + def patch(self, pipeline: Pipeline, workflow_id: str): + """ + Update workflow attributes + """ + # Check permission + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("marked_name", type=str, required=False, location="json") + parser.add_argument("marked_comment", type=str, required=False, location="json") + args = parser.parse_args() + + # Validate name and comment length + if args.marked_name and len(args.marked_name) > 20: + raise ValueError("Marked name cannot exceed 20 characters") + if args.marked_comment and len(args.marked_comment) > 100: + raise ValueError("Marked comment cannot exceed 100 characters") + args = parser.parse_args() + + # Prepare update data + update_data = {} + if args.get("marked_name") is not None: + update_data["marked_name"] = args["marked_name"] + if args.get("marked_comment") is not None: + update_data["marked_comment"] = args["marked_comment"] + + if not update_data: + return {"message": "No valid fields to update"}, 400 + + rag_pipeline_service = RagPipelineService() + + # Create a session and manage the transaction + with Session(db.engine, expire_on_commit=False) as session: + workflow = rag_pipeline_service.update_workflow( + session=session, + workflow_id=workflow_id, + tenant_id=pipeline.tenant_id, + account_id=current_user.id, + data=update_data, + ) + + if not workflow: + raise NotFound("Workflow not found") + + # Commit the transaction in the controller + session.commit() + + return workflow + + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def delete(self, pipeline: Pipeline, workflow_id: str): + """ + Delete workflow + """ + # Check permission + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + rag_pipeline_service = 
RagPipelineService() + + # Create a session and manage the transaction + with Session(db.engine) as session: + try: + rag_pipeline_service.delete_workflow( + session=session, workflow_id=workflow_id, tenant_id=pipeline.tenant_id + ) + # Commit the transaction in the controller + session.commit() + except WorkflowInUseError as e: + abort(400, description=str(e)) + except DraftWorkflowDeletionError as e: + abort(400, description=str(e)) + except ValueError as e: + raise NotFound(str(e)) + + return None, 204 + + +api.add_resource( + DraftRagPipelineApi, + "/rag/pipelines//workflows/draft", +) +api.add_resource( + RagPipelineConfigApi, + "/rag/pipelines//workflows/draft/config", +) +api.add_resource( + DraftRagPipelineRunApi, + "/rag/pipelines//workflows/draft/run", +) +api.add_resource( + RagPipelineTaskStopApi, + "/rag/pipelines//workflow-runs/tasks//stop", +) +api.add_resource( + RagPipelineNodeRunApi, + "/rag/pipelines//workflows/draft/nodes//run", +) + +api.add_resource( + RagPipelineDraftRunIterationNodeApi, + "/rag/pipelines//workflows/draft/iteration/nodes//run", +) +api.add_resource( + RagPipelineDraftRunLoopNodeApi, + "/rag/pipelines//workflows/draft/loop/nodes//run", +) + +api.add_resource( + PublishedRagPipelineApi, + "/rag/pipelines//workflows/publish", +) +api.add_resource( + PublishedAllRagPipelineApi, + "/rag/pipelines//workflows", +) +api.add_resource( + DefaultRagPipelineBlockConfigsApi, + "/rag/pipelines//workflows/default-workflow-block-configs", +) +api.add_resource( + DefaultRagPipelineBlockConfigApi, + "/rag/pipelines//workflows/default-workflow-block-configs/", +) +api.add_resource( + ConvertToRagPipelineApi, + "/rag/pipelines//convert-to-workflow", +) +api.add_resource( + RagPipelineByIdApi, + "/rag/pipelines//workflows/", +) + +api.add_resource( + PipelineTemplateListApi, + "/rag/pipeline/templates", +) +api.add_resource( + PipelineTemplateDetailApi, + "/rag/pipeline/templates/", +) +api.add_resource( + CustomizedPipelineTemplateApi, + "/rag/pipeline/templates/", +) diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py new file mode 100644 index 0000000000..aa8ae86860 --- /dev/null +++ b/api/controllers/console/datasets/wraps.py @@ -0,0 +1,41 @@ +from collections.abc import Callable +from functools import wraps +from typing import Optional + +from controllers.console.datasets.error import PipelineNotFoundError +from extensions.ext_database import db +from libs.login import current_user +from models.dataset import Pipeline + + +def get_rag_pipeline(view: Optional[Callable] = None,): + def decorator(view_func): + @wraps(view_func) + def decorated_view(*args, **kwargs): + if not kwargs.get("pipeline_id"): + raise ValueError("missing pipeline_id in path parameters") + + pipeline_id = kwargs.get("pipeline_id") + pipeline_id = str(pipeline_id) + + del kwargs["pipeline_id"] + + pipeline = ( + db.session.query(Pipeline) + .filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id) + .first() + ) + + if not pipeline: + raise PipelineNotFoundError() + + kwargs["pipeline"] = pipeline + + return view_func(*args, **kwargs) + + return decorated_view + + if view is None: + return decorator + else: + return decorator(view) diff --git a/api/models/dataset.py b/api/models/dataset.py index fb4908f936..a344ab2964 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1139,11 +1139,10 @@ class DatasetMetadataBinding(db.Model): # type: ignore[name-defined] created_at = db.Column(db.DateTime, 
nullable=False, server_default=func.current_timestamp()) created_by = db.Column(StringUUID, nullable=False) + class PipelineBuiltInTemplate(db.Model): # type: ignore[name-defined] __tablename__ = "pipeline_built_in_templates" - __table_args__ = ( - db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"), - ) + __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) pipeline_id = db.Column(StringUUID, nullable=False) @@ -1181,9 +1180,7 @@ class PipelineCustomizedTemplate(db.Model): # type: ignore[name-defined] class Pipeline(db.Model): # type: ignore[name-defined] __tablename__ = "pipelines" - __table_args__ = ( - db.PrimaryKeyConstraint("id", name="pipeline_pkey"), - ) + __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) diff --git a/api/models/workflow.py b/api/models/workflow.py index dbcb859823..c85f335f37 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -38,7 +38,8 @@ class WorkflowType(Enum): WORKFLOW = "workflow" CHAT = "chat" - + RAG_PIPELINE = "rag_pipeline" + @classmethod def value_of(cls, value: str) -> "WorkflowType": """ diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py new file mode 100644 index 0000000000..d59d47bbce --- /dev/null +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -0,0 +1,16 @@ +from typing import Optional + +from pydantic import BaseModel + + +class IconInfo(BaseModel): + icon: str + icon_background: Optional[str] = None + icon_type: Optional[str] = None + icon_url: Optional[str] = None + + +class PipelineTemplateInfoEntity(BaseModel): + name: str + description: str + icon_info: IconInfo diff --git a/api/services/rag_pipeline/pipeline_template/built_in/__init__.py b/api/services/rag_pipeline/pipeline_template/built_in/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py new file mode 100644 index 0000000000..70c72014f2 --- /dev/null +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -0,0 +1,64 @@ +import json +from os import path +from pathlib import Path +from typing import Optional + +from flask import current_app + +from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase +from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType + + +class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): + """ + Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json + """ + + builtin_data: Optional[dict] = None + + def get_type(self) -> str: + return PipelineTemplateType.BUILTIN + + def get_pipeline_templates(self, language: str) -> dict: + result = self.fetch_pipeline_templates_from_builtin(language) + return result + + def get_pipeline_template_detail(self, pipeline_id: str): + result = self.fetch_pipeline_template_detail_from_builtin(pipeline_id) + return result + + @classmethod + def _get_builtin_data(cls) -> dict: + """ + Get builtin data. 
+ :return: + """ + if cls.builtin_data: + return cls.builtin_data + + root_path = current_app.root_path + cls.builtin_data = json.loads( + Path(path.join(root_path, "constants", "pipeline_templates.json")).read_text(encoding="utf-8") + ) + + return cls.builtin_data or {} + + @classmethod + def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict: + """ + Fetch pipeline templates from builtin. + :param language: language + :return: + """ + builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + return builtin_data.get("pipeline_templates", {}).get(language, {}) + + @classmethod + def fetch_pipeline_template_detail_from_builtin(cls, pipeline_id: str) -> Optional[dict]: + """ + Fetch pipeline template detail from builtin. + :param pipeline_id: Pipeline ID + :return: + """ + builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + return builtin_data.get("pipeline_templates", {}).get(pipeline_id) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index fccf09ef0a..de69373ba4 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,8 +1,9 @@ from typing import Optional -from constants.languages import languages +from flask_login import current_user + from extensions.ext_database import db -from models.model import App, RecommendedApp +from models.dataset import Pipeline, PipelineCustomizedTemplate from services.app_dsl_service import AppDslService from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType @@ -14,92 +15,57 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ def get_pipeline_templates(self, language: str) -> dict: - result = self.fetch_pipeline_templates_from_db(language) + result = self.fetch_pipeline_templates_from_customized( + tenant_id=current_user.current_tenant_id, language=language + ) return result - def get_pipeline_template_detail(self, pipeline_id: str): - result = self.fetch_pipeline_template_detail_from_db(pipeline_id) + def get_pipeline_template_detail(self, template_id: str): + result = self.fetch_pipeline_template_detail_from_db(template_id) return result def get_type(self) -> str: return PipelineTemplateType.CUSTOMIZED @classmethod - def fetch_recommended_apps_from_db(cls, language: str) -> dict: + def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict: """ - Fetch recommended apps from db. + Fetch pipeline templates from db. 
+ :param tenant_id: tenant id :param language: language :return: """ - recommended_apps = ( - db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) + pipeline_templates = ( + db.session.query(PipelineCustomizedTemplate) + .filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .all() ) - if len(recommended_apps) == 0: - recommended_apps = ( - db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) - .all() - ) - - categories = set() - recommended_apps_result = [] - for recommended_app in recommended_apps: - app = recommended_app.app - if not app or not app.is_public: - continue - - site = app.site - if not site: - continue - - recommended_app_result = { - "id": recommended_app.id, - "app": recommended_app.app, - "app_id": recommended_app.app_id, - "description": site.description, - "copyright": site.copyright, - "privacy_policy": site.privacy_policy, - "custom_disclaimer": site.custom_disclaimer, - "category": recommended_app.category, - "position": recommended_app.position, - "is_listed": recommended_app.is_listed, - } - recommended_apps_result.append(recommended_app_result) - - categories.add(recommended_app.category) - - return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} + return {"pipeline_templates": pipeline_templates} @classmethod - def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: """ - Fetch recommended app detail from db. - :param app_id: App ID + Fetch pipeline template detail from db. + :param template_id: Template ID :return: """ - # is in public recommended list - recommended_app = ( - db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) - .first() + pipeline_template = ( + db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() ) - if not recommended_app: + if not pipeline_template: return None - # get app detail - app_model = db.session.query(App).filter(App.id == app_id).first() - if not app_model or not app_model.is_public: + # get pipeline detail + pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first() + if not pipeline or not pipeline.is_public: return None return { - "id": app_model.id, - "name": app_model.name, - "icon": app_model.icon, - "icon_background": app_model.icon_background, - "mode": app_model.mode, - "export_data": AppDslService.export_dsl(app_model=app_model), + "id": pipeline.id, + "name": pipeline.name, + "icon": pipeline.icon, + "mode": pipeline.mode, + "export_data": AppDslService.export_dsl(app_model=pipeline), } diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 3158e9f91c..10dd044493 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,8 +1,7 @@ from typing import Optional -from constants.languages import languages from extensions.ext_database import db -from models.model import App, RecommendedApp +from models.dataset import Pipeline, PipelineBuiltInTemplate from services.app_dsl_service import AppDslService from 
services.recommend_app.recommend_app_base import RecommendAppRetrievalBase from services.recommend_app.recommend_app_type import RecommendAppType @@ -10,96 +9,57 @@ from services.recommend_app.recommend_app_type import RecommendAppType class DatabasePipelineTemplateRetrieval(RecommendAppRetrievalBase): """ - Retrieval recommended app from database + Retrieval pipeline template from database """ - def get_recommended_apps_and_categories(self, language: str) -> dict: - result = self.fetch_recommended_apps_from_db(language) + def get_pipeline_templates(self, language: str) -> dict: + result = self.fetch_pipeline_templates_from_db(language) return result - def get_recommend_app_detail(self, app_id: str): - result = self.fetch_recommended_app_detail_from_db(app_id) + def get_pipeline_template_detail(self, pipeline_id: str): + result = self.fetch_pipeline_template_detail_from_db(pipeline_id) return result def get_type(self) -> str: return RecommendAppType.DATABASE @classmethod - def fetch_recommended_apps_from_db(cls, language: str) -> dict: + def fetch_pipeline_templates_from_db(cls, language: str) -> dict: """ - Fetch recommended apps from db. + Fetch pipeline templates from db. :param language: language :return: """ - recommended_apps = ( - db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) - .all() + pipeline_templates = ( + db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all() ) - if len(recommended_apps) == 0: - recommended_apps = ( - db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) - .all() - ) - - categories = set() - recommended_apps_result = [] - for recommended_app in recommended_apps: - app = recommended_app.app - if not app or not app.is_public: - continue - - site = app.site - if not site: - continue - - recommended_app_result = { - "id": recommended_app.id, - "app": recommended_app.app, - "app_id": recommended_app.app_id, - "description": site.description, - "copyright": site.copyright, - "privacy_policy": site.privacy_policy, - "custom_disclaimer": site.custom_disclaimer, - "category": recommended_app.category, - "position": recommended_app.position, - "is_listed": recommended_app.is_listed, - } - recommended_apps_result.append(recommended_app_result) - - categories.add(recommended_app.category) - - return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} + return {"pipeline_templates": pipeline_templates} @classmethod - def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]: """ - Fetch recommended app detail from db. - :param app_id: App ID + Fetch pipeline template detail from db. 
+ :param pipeline_id: Pipeline ID :return: """ # is in public recommended list - recommended_app = ( - db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) - .first() + pipeline_template = ( + db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first() ) - if not recommended_app: + if not pipeline_template: return None # get app detail - app_model = db.session.query(App).filter(App.id == app_id).first() - if not app_model or not app_model.is_public: + pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first() + if not pipeline or not pipeline.is_public: return None return { - "id": app_model.id, - "name": app_model.name, - "icon": app_model.icon, - "icon_background": app_model.icon_background, - "mode": app_model.mode, - "export_data": AppDslService.export_dsl(app_model=app_model), + "id": pipeline.id, + "name": pipeline.name, + "icon": pipeline.icon, + "mode": pipeline.mode, + "export_data": AppDslService.export_dsl(app_model=pipeline), } diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py index 45860f9d74..fa6a38a357 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Optional class PipelineTemplateRetrievalBase(ABC): @@ -9,7 +10,7 @@ class PipelineTemplateRetrievalBase(ABC): raise NotImplementedError @abstractmethod - def get_pipeline_template_detail(self, pipeline_id: str): + def get_pipeline_template_detail(self, template_id: str) -> Optional[dict]: raise NotImplementedError @abstractmethod diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py index 0348387119..37e40bf6a0 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py @@ -1,9 +1,11 @@ +from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval +from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +from services.rag_pipeline.pipeline_template.remote.remote_retrieval import RemotePipelineTemplateRetrieval - -class RecommendAppRetrievalFactory: +class PipelineTemplateRetrievalFactory: @staticmethod def get_pipeline_template_factory(mode: str) -> type[PipelineTemplateRetrievalBase]: match mode: @@ -11,11 +13,13 @@ class RecommendAppRetrievalFactory: return RemotePipelineTemplateRetrieval case PipelineTemplateType.CUSTOMIZED: return DatabasePipelineTemplateRetrieval - case PipelineTemplateType.BUILTIN: - return BuildInPipelineTemplateRetrieval + case PipelineTemplateType.DATABASE: + return DatabasePipelineTemplateRetrieval + case PipelineTemplateType.BUILT_IN: + return BuiltInPipelineTemplateRetrieval case _: raise ValueError(f"invalid fetch recommended apps mode: {mode}") @staticmethod - def get_buildin_recommend_app_retrieval(): - return BuildInRecommendAppRetrieval + def 
get_built_in_pipeline_template_retrieval(): + return BuiltInPipelineTemplateRetrieval diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_type.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_type.py index 98bc109a67..e914266d26 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_type.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_type.py @@ -3,5 +3,6 @@ from enum import StrEnum class PipelineTemplateType(StrEnum): REMOTE = "remote" - BUILTIN = "builtin" + DATABASE = "database" CUSTOMIZED = "customized" + BUILTIN = "builtin" diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index 3ba6b0e64c..5553d7c97e 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -4,9 +4,10 @@ from typing import Optional import requests from configs import dify_config -from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval -from services.recommend_app.recommend_app_type import RecommendAppType from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase +from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval + logger = logging.getLogger(__name__) @@ -23,26 +24,26 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(pipeline_id) return result - def get_recommended_apps_and_categories(self, language: str) -> dict: + def get_pipeline_templates(self, language: str) -> dict: try: - result = self.fetch_recommended_apps_from_dify_official(language) + result = self.fetch_pipeline_templates_from_dify_official(language) except Exception as e: - logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.") + logger.warning(f"fetch pipeline templates from dify official failed: {e}, switch to built-in.") result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language) return result def get_type(self) -> str: - return RecommendAppType.REMOTE + return PipelineTemplateType.REMOTE @classmethod - def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_dify_official(cls, pipeline_id: str) -> Optional[dict]: """ - Fetch recommended app detail from dify official. - :param app_id: App ID + Fetch pipeline template detail from dify official. + :param pipeline_id: Pipeline ID :return: """ - domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f"{domain}/apps/{app_id}" + domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN + url = f"{domain}/pipelines/{pipeline_id}" response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: return None @@ -50,21 +51,18 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return data @classmethod - def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: + def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict: """ - Fetch recommended apps from dify official. + Fetch pipeline templates from dify official. 
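The payload of the hosted endpoint is not documented in this patch; callers only rely on the response being JSON with a "pipeline_templates" key, so the exchange below is an assumption rather than a defined contract (the domain and ids are placeholders):

# Hypothetical exchange, inferred from how the result is consumed:
#   GET {HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN}/pipelines?language=en-US
#   200 -> {"pipeline_templates": [{"id": "example-id", "name": "Example", "language": "en-US"}]}
# A non-200 status raises ValueError, and get_pipeline_templates() then falls back
# to the built-in data instead of surfacing the error.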
:param language: language :return: """ - domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f"{domain}/apps?language={language}" + domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN + url = f"{domain}/pipelines?language={language}" response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: - raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") + raise ValueError(f"fetch pipeline templates failed, status code: {response.status_code}") result: dict = response.json() - if "categories" in result: - result["categories"] = sorted(result["categories"]) - return result diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index c215606817..c2c9c56e9d 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1,52 +1,575 @@ -import datetime -import hashlib -import os -import uuid -from typing import Any, List, Literal, Union +import json +import time +from collections.abc import Callable, Generator, Sequence +from datetime import UTC, datetime +from typing import Any, Literal, Optional from flask_login import current_user +from sqlalchemy import select +from sqlalchemy.orm import Session -from models.dataset import PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore from configs import dify_config +from core.variables.variables import Variable +from core.workflow.graph_engine.entities.event import InNodeEvent +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event.types import NodeEvent +from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from core.workflow.workflow_entry import WorkflowEntry +from extensions.db import db +from models.account import Account +from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore +from models.workflow import Workflow, WorkflowNodeExecution, WorkflowType +from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity +from services.errors.app import WorkflowHashNotEqualError +from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory + class RagPipelineService: @staticmethod def get_pipeline_templates( - type: Literal["built-in", "customized"] = "built-in", + type: Literal["built-in", "customized"] = "built-in", language: str = "en-US" ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]: if type == "built-in": - return PipelineBuiltInTemplate.query.all() - else: - return PipelineCustomizedTemplate.query.all() - - @staticmethod - def get_pipeline_templates(cls, type: Literal["built-in", "customized"] = "built-in", language: str) -> dict: - """ - Get pipeline templates. 
- :param type: type - :param language: language - :return: - """ - mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() - result = retrieval_instance.get_recommended_apps_and_categories(language) - if not result.get("recommended_apps") and language != "en-US": - result = ( - RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval().fetch_recommended_apps_from_builtin( + mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode) + result = retrieval_instance.get_pipeline_templates(language) + if not result.get("pipeline_templates") and language != "en-US": + result = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval().fetch_pipeline_templates_from_builtin( "en-US" ) - ) + return result.get("pipeline_templates") + else: + mode = "customized" + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode) + result = retrieval_instance.get_pipeline_templates(language) + return result.get("pipeline_templates") + @classmethod + def get_pipeline_template_detail(cls, pipeline_id: str) -> Optional[dict]: + """ + Get pipeline template detail. + :param pipeline_id: pipeline id + :return: + """ + mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode) + result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(pipeline_id) return result @classmethod - def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: + def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity): """ - Get recommend app detail. - :param app_id: app id + Update pipeline template. + :param template_id: template id + :param template_info: template info + """ + customized_template: PipelineCustomizedTemplate | None = ( + db.query(PipelineCustomizedTemplate) + .filter( + PipelineCustomizedTemplate.id == template_id, + PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, + ) + .first() + ) + if not customized_template: + raise ValueError("Customized pipeline template not found.") + customized_template.name = template_info.name + customized_template.description = template_info.description + customized_template.icon = template_info.icon_info.model_dump() + db.commit() + return customized_template + + @classmethod + def delete_customized_pipeline_template(cls, template_id: str): + """ + Delete customized pipeline template. 
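Taken together, a short sketch of how these service methods are meant to be driven (ids and field values are placeholders; the update and delete calls also read current_user.current_tenant_id, so they need an authenticated request context):

# Usage sketch with placeholder ids and values.
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService

templates = RagPipelineService.get_pipeline_templates(type="customized", language="en-US")
detail = RagPipelineService.get_pipeline_template_detail("example-pipeline-id")
RagPipelineService.update_customized_pipeline_template(
    "example-template-id",
    PipelineTemplateInfoEntity(name="My pipeline", description="Updated copy", icon_info=IconInfo(icon="📚")),
)
RagPipelineService.delete_customized_pipeline_template("example-template-id")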
+ """ + customized_template: PipelineCustomizedTemplate | None = ( + db.query(PipelineCustomizedTemplate) + .filter( + PipelineCustomizedTemplate.id == template_id, + PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, + ) + .first() + ) + if not customized_template: + raise ValueError("Customized pipeline template not found.") + db.delete(customized_template) + db.commit() + + + def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: + """ + Get draft workflow + """ + # fetch draft workflow by rag pipeline + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == "draft", + ) + .first() + ) + + # return draft workflow + return workflow + + def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: + """ + Get published workflow + """ + + if not pipeline.workflow_id: + return None + + # fetch published workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.id == pipeline.workflow_id, + ) + .first() + ) + + return workflow + + def get_all_published_workflow( + self, + *, + session: Session, + pipeline: Pipeline, + page: int, + limit: int, + user_id: str | None, + named_only: bool = False, + ) -> tuple[Sequence[Workflow], bool]: + """ + Get published workflow with pagination + """ + if not pipeline.workflow_id: + return [], False + + stmt = ( + select(Workflow) + .where(Workflow.app_id == pipeline.id) + .order_by(Workflow.version.desc()) + .limit(limit + 1) + .offset((page - 1) * limit) + ) + + if user_id: + stmt = stmt.where(Workflow.created_by == user_id) + + if named_only: + stmt = stmt.where(Workflow.marked_name != "") + + workflows = session.scalars(stmt).all() + + has_more = len(workflows) > limit + if has_more: + workflows = workflows[:-1] + + return workflows, has_more + + def sync_draft_workflow( + self, + *, + pipeline: Pipeline, + graph: dict, + features: dict, + unique_hash: Optional[str], + account: Account, + environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable], + ) -> Workflow: + """ + Sync draft workflow + :raises WorkflowHashNotEqualError + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(pipeline=pipeline) + + if workflow and workflow.unique_hash != unique_hash: + raise WorkflowHashNotEqualError() + + # validate features structure + self.validate_features_structure(pipeline=pipeline, features=features) + + # create draft workflow if not found + if not workflow: + workflow = Workflow( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + type=WorkflowType.RAG_PIPELINE.value, + version="draft", + graph=json.dumps(graph), + features=json.dumps(features), + created_by=account.id, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + db.session.add(workflow) + # update draft workflow if found + else: + workflow.graph = json.dumps(graph) + workflow.features = json.dumps(features) + workflow.updated_by = account.id + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.environment_variables = environment_variables + workflow.conversation_variables = conversation_variables + + # commit db session changes + db.session.commit() + + # trigger app workflow events + app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow) + + # return draft workflow + return workflow + + def 
publish_workflow( + self, + *, + session: Session, + pipeline: Pipeline, + account: Account, + marked_name: str = "", + marked_comment: str = "", + ) -> Workflow: + draft_workflow_stmt = select(Workflow).where( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == "draft", + ) + draft_workflow = session.scalar(draft_workflow_stmt) + if not draft_workflow: + raise ValueError("No valid workflow found.") + + # create new workflow + workflow = Workflow.new( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + type=draft_workflow.type, + version=str(datetime.now(UTC).replace(tzinfo=None)), + graph=draft_workflow.graph, + features=draft_workflow.features, + created_by=account.id, + environment_variables=draft_workflow.environment_variables, + conversation_variables=draft_workflow.conversation_variables, + marked_name=marked_name, + marked_comment=marked_comment, + ) + + # commit db session changes + session.add(workflow) + + # trigger app workflow events + app_published_workflow_was_updated.send(pipeline, published_workflow=workflow) + + # return new workflow + return workflow + + def get_default_block_configs(self) -> list[dict]: + """ + Get default block configs + """ + # return default block config + default_block_configs = [] + for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values(): + node_class = node_class_mapping[LATEST_VERSION] + default_config = node_class.get_default_config() + if default_config: + default_block_configs.append(default_config) + + return default_block_configs + + def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: + """ + Get default config of node. + :param node_type: node type + :param filters: filter by node config parameters. 
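A minimal usage sketch (assumes "llm" is one of the node types registered in NODE_TYPE_CLASSES_MAPPING; unregistered types simply yield None):

# Usage sketch; "llm" is an assumed example of a registered NodeType value.
rag_pipeline_service = RagPipelineService()
default_config = rag_pipeline_service.get_default_block_config(node_type="llm")
# None means the type is unknown or the node class defines no default config.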
:return: """ - mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() - result: dict = retrieval_instance.get_recommend_app_detail(app_id) - return result + node_type_enum = NodeType(node_type) + + # return default block config + if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: + return None + + node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] + default_config = node_class.get_default_config(filters=filters) + if not default_config: + return None + + return default_config + + def run_draft_workflow_node( + self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account + ) -> WorkflowNodeExecution: + """ + Run draft workflow node + """ + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(pipeline=pipeline) + if not draft_workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + start_at = time.perf_counter() + + workflow_node_execution = self._handle_node_run_result( + getter=lambda: WorkflowEntry.single_step_run( + workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + ), + start_at=start_at, + tenant_id=pipeline.tenant_id, + node_id=node_id, + ) + + workflow_node_execution.app_id = pipeline.id + workflow_node_execution.created_by = account.id + workflow_node_execution.workflow_id = draft_workflow.id + + db.session.add(workflow_node_execution) + db.session.commit() + + return workflow_node_execution + + def run_free_workflow_node( + self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] + ) -> WorkflowNodeExecution: + """ + Run draft workflow node + """ + # run draft workflow node + start_at = time.perf_counter() + + workflow_node_execution = self._handle_node_run_result( + getter=lambda: WorkflowEntry.run_free_node( + node_id=node_id, + node_data=node_data, + tenant_id=tenant_id, + user_id=user_id, + user_inputs=user_inputs, + ), + start_at=start_at, + tenant_id=tenant_id, + node_id=node_id, + ) + + return workflow_node_execution + + def _handle_node_run_result( + self, + getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]], + start_at: float, + tenant_id: str, + node_id: str, + ) -> WorkflowNodeExecution: + """ + Handle node run result + + :param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]] + :param start_at: float + :param tenant_id: str + :param node_id: str + """ + try: + node_instance, generator = getter() + + node_run_result: NodeRunResult | None = None + for event in generator: + if isinstance(event, RunCompletedEvent): + node_run_result = event.run_result + + # sign output files + node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + break + + if not node_run_result: + raise ValueError("Node run failed with no run result") + # single step debug mode error handling return + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: + node_error_args: dict[str, Any] = { + "status": WorkflowNodeExecutionStatus.EXCEPTION, + "error": node_run_result.error, + "inputs": node_run_result.inputs, + "metadata": {"error_strategy": node_instance.node_data.error_strategy}, + } + if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + **node_instance.node_data.default_value_dict, + "error_message": 
node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + else: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + run_succeeded = node_run_result.status in ( + WorkflowNodeExecutionStatus.SUCCEEDED, + WorkflowNodeExecutionStatus.EXCEPTION, + ) + error = node_run_result.error if not run_succeeded else None + except WorkflowNodeRunFailedError as e: + node_instance = e.node_instance + run_succeeded = False + node_run_result = None + error = e.error + + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = str(uuid4()) + workflow_node_execution.tenant_id = tenant_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value + workflow_node_execution.index = 1 + workflow_node_execution.node_id = node_id + workflow_node_execution.node_type = node_instance.node_type + workflow_node_execution.title = node_instance.node_data.title + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value + workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) + if run_succeeded and node_run_result: + # create workflow node execution + inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None + process_data = ( + WorkflowEntry.handle_special_values(node_run_result.process_data) + if node_run_result.process_data + else None + ) + outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None + + workflow_node_execution.inputs = json.dumps(inputs) + workflow_node_execution.process_data = json.dumps(process_data) + workflow_node_execution.outputs = json.dumps(outputs) + workflow_node_execution.execution_metadata = ( + json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None + ) + if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: + workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value + workflow_node_execution.error = node_run_result.error + else: + # create workflow node execution + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + + return workflow_node_execution + + def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: + """ + Basic mode of chatbot app(expert mode) to workflow + Completion App to Workflow App + + :param app_model: App instance + :param account: Account instance + :param args: dict + :return: + """ + # chatbot convert to workflow mode + workflow_converter = WorkflowConverter() + + if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: + raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") + + # convert to workflow + new_app: App = workflow_converter.convert_to_workflow( + app_model=app_model, + account=account, + name=args.get("name", "Default Name"), + icon_type=args.get("icon_type", "emoji"), + icon=args.get("icon", "🤖"), + icon_background=args.get("icon_background", "#FFEAD5"), + ) + + return new_app + + def validate_features_structure(self, app_model: 
App, features: dict) -> dict: + if app_model.mode == AppMode.ADVANCED_CHAT.value: + return AdvancedChatAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True + ) + elif app_model.mode == AppMode.WORKFLOW.value: + return WorkflowAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True + ) + else: + raise ValueError(f"Invalid app mode: {app_model.mode}") + + def update_workflow( + self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict + ) -> Optional[Workflow]: + """ + Update workflow attributes + + :param session: SQLAlchemy database session + :param workflow_id: Workflow ID + :param tenant_id: Tenant ID + :param account_id: Account ID (for permission check) + :param data: Dictionary containing fields to update + :return: Updated workflow or None if not found + """ + stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) + workflow = session.scalar(stmt) + + if not workflow: + return None + + allowed_fields = ["marked_name", "marked_comment"] + + for field, value in data.items(): + if field in allowed_fields: + setattr(workflow, field, value) + + workflow.updated_by = account_id + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + + return workflow + + def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool: + """ + Delete a workflow + + :param session: SQLAlchemy database session + :param workflow_id: Workflow ID + :param tenant_id: Tenant ID + :return: True if successful + :raises: ValueError if workflow not found + :raises: WorkflowInUseError if workflow is in use + :raises: DraftWorkflowDeletionError if workflow is a draft version + """ + stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) + workflow = session.scalar(stmt) + + if not workflow: + raise ValueError(f"Workflow with ID {workflow_id} not found") + + # Check if workflow is a draft version + if workflow.version == "draft": + raise DraftWorkflowDeletionError("Cannot delete draft workflow versions") + + # Check if this workflow is currently referenced by an app + stmt = select(App).where(App.workflow_id == workflow_id) + app = session.scalar(stmt) + if app: + # Cannot delete a workflow that's currently in use by an app + raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'") + + session.delete(workflow) + return True From 5c4bf2a9e451643a8d72d70f3443cc6c2d4240e2 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 17 Apr 2025 15:07:23 +0800 Subject: [PATCH 004/155] r2 --- api/controllers/console/__init__.py | 2 +- .../datasets/rag_pipeline/rag_pipeline.py | 136 +++ .../rag_pipeline_workflow.py} | 199 ++-- api/core/datasource/__base/tool.py | 222 +++++ api/core/datasource/__base/tool_provider.py | 109 +++ api/core/datasource/__base/tool_runtime.py | 36 + api/core/datasource/__init__.py | 0 .../datasource/entities/agent_entities.py | 0 api/core/datasource/entities/api_entities.py | 72 ++ .../datasource/entities/common_entities.py | 23 + api/core/datasource/entities/constants.py | 1 + api/core/datasource/entities/file_entities.py | 1 + api/core/datasource/entities/tool_bundle.py | 29 + api/core/datasource/entities/tool_entities.py | 427 +++++++++ api/core/datasource/entities/values.py | 111 +++ api/core/datasource/errors.py | 37 + api/core/datasource/plugin_tool/provider.py | 79 ++ 
api/core/datasource/plugin_tool/tool.py | 89 ++ api/core/datasource/tool_engine.py | 357 +++++++ api/core/datasource/tool_file_manager.py | 234 +++++ api/core/datasource/tool_label_manager.py | 101 ++ api/core/datasource/tool_manager.py | 870 ++++++++++++++++++ api/core/datasource/utils/__init__.py | 0 api/core/datasource/utils/configuration.py | 265 ++++++ .../dataset_multi_retriever_tool.py | 199 ++++ .../dataset_retriever_base_tool.py | 33 + .../dataset_retriever_tool.py | 202 ++++ .../utils/dataset_retriever_tool.py | 134 +++ .../datasource/utils/message_transformer.py | 121 +++ .../utils/model_invocation_utils.py | 169 ++++ api/core/datasource/utils/parser.py | 389 ++++++++ api/core/datasource/utils/rag_web_reader.py | 17 + .../datasource/utils/text_processing_utils.py | 17 + api/core/datasource/utils/uuid_utils.py | 9 + api/core/datasource/utils/web_reader_tool.py | 375 ++++++++ .../utils/workflow_configuration_sync.py | 43 + api/core/datasource/utils/yaml_utils.py | 35 + api/core/workflow/constants.py | 1 + api/core/workflow/entities/node_entities.py | 1 + .../workflow/nodes/datasource/__init__.py | 3 + .../nodes/datasource/datasource_node.py | 406 ++++++++ .../workflow/nodes/datasource/entities.py | 56 ++ api/core/workflow/nodes/datasource/exc.py | 16 + api/core/workflow/nodes/enums.py | 1 + api/core/workflow/nodes/tool/tool_node.py | 2 +- api/factories/variable_factory.py | 10 +- api/fields/workflow_fields.py | 11 + api/models/workflow.py | 21 + api/services/rag_pipeline/rag_pipeline.py | 60 +- 49 files changed, 5609 insertions(+), 122 deletions(-) create mode 100644 api/controllers/console/datasets/rag_pipeline/rag_pipeline.py rename api/controllers/console/datasets/{pipeline.py => rag_pipeline/rag_pipeline_workflow.py} (86%) create mode 100644 api/core/datasource/__base/tool.py create mode 100644 api/core/datasource/__base/tool_provider.py create mode 100644 api/core/datasource/__base/tool_runtime.py create mode 100644 api/core/datasource/__init__.py create mode 100644 api/core/datasource/entities/agent_entities.py create mode 100644 api/core/datasource/entities/api_entities.py create mode 100644 api/core/datasource/entities/common_entities.py create mode 100644 api/core/datasource/entities/constants.py create mode 100644 api/core/datasource/entities/file_entities.py create mode 100644 api/core/datasource/entities/tool_bundle.py create mode 100644 api/core/datasource/entities/tool_entities.py create mode 100644 api/core/datasource/entities/values.py create mode 100644 api/core/datasource/errors.py create mode 100644 api/core/datasource/plugin_tool/provider.py create mode 100644 api/core/datasource/plugin_tool/tool.py create mode 100644 api/core/datasource/tool_engine.py create mode 100644 api/core/datasource/tool_file_manager.py create mode 100644 api/core/datasource/tool_label_manager.py create mode 100644 api/core/datasource/tool_manager.py create mode 100644 api/core/datasource/utils/__init__.py create mode 100644 api/core/datasource/utils/configuration.py create mode 100644 api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py create mode 100644 api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py create mode 100644 api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py create mode 100644 api/core/datasource/utils/dataset_retriever_tool.py create mode 100644 api/core/datasource/utils/message_transformer.py create mode 100644 api/core/datasource/utils/model_invocation_utils.py create mode 100644 
api/core/datasource/utils/parser.py create mode 100644 api/core/datasource/utils/rag_web_reader.py create mode 100644 api/core/datasource/utils/text_processing_utils.py create mode 100644 api/core/datasource/utils/uuid_utils.py create mode 100644 api/core/datasource/utils/web_reader_tool.py create mode 100644 api/core/datasource/utils/workflow_configuration_sync.py create mode 100644 api/core/datasource/utils/yaml_utils.py create mode 100644 api/core/workflow/nodes/datasource/__init__.py create mode 100644 api/core/workflow/nodes/datasource/datasource_node.py create mode 100644 api/core/workflow/nodes/datasource/entities.py create mode 100644 api/core/workflow/nodes/datasource/exc.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index a974c63e35..74e5da9435 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -1,5 +1,6 @@ from flask import Blueprint +from .datasets.rag_pipeline import data_source from libs.external_api import ExternalApi from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi @@ -75,7 +76,6 @@ from .billing import billing, compliance # Import datasets controllers from .datasets import ( - data_source, datasets, datasets_document, datasets_segments, diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py new file mode 100644 index 0000000000..4ff2f07bb6 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -0,0 +1,136 @@ +import json +import logging +from typing import cast + +from flask import abort, request +from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import services +from configs import dify_config +from controllers.console import api +from controllers.console.app.error import ( + ConversationCompletedError, + DraftWorkflowNotExist, + DraftWorkflowNotSync, +) +from controllers.console.app.wraps import get_app_model +from controllers.console.datasets.wraps import get_rag_pipeline +from controllers.console.wraps import ( + account_initialization_required, + enterprise_license_required, + setup_required, +) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from factories import variable_factory +from fields.workflow_fields import workflow_fields, workflow_pagination_fields +from fields.workflow_run_fields import workflow_run_node_execution_fields +from libs import helper +from libs.helper import TimestampField +from libs.login import current_user, login_required +from models import App +from models.account import Account +from models.dataset import Pipeline +from models.model import AppMode +from services.app_generate_service import AppGenerateService +from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity +from services.errors.app import WorkflowHashNotEqualError +from services.errors.llm import InvokeRateLimitError +from services.rag_pipeline.rag_pipeline import RagPipelineService +from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService + +logger = logging.getLogger(__name__) + + +def 
_validate_name(name): + if not name or len(name) < 1 or len(name) > 40: + raise ValueError("Name must be between 1 to 40 characters.") + return name + + +def _validate_description_length(description): + if len(description) > 400: + raise ValueError("Description cannot exceed 400 characters.") + return description + + +class PipelineTemplateListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def get(self): + type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"]) + language = request.args.get("language", default="en-US", type=str) + # get pipeline templates + pipeline_templates = RagPipelineService.get_pipeline_templates(type, language) + return pipeline_templates, 200 + + +class PipelineTemplateDetailApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def get(self, pipeline_id: str): + pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id) + return pipeline_template, 200 + + +class CustomizedPipelineTemplateApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def patch(self, template_id: str): + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) + parser.add_argument( + "icon_info", + type=dict, + location="json", + nullable=True, + ) + args = parser.parse_args() + pipeline_template_info = PipelineTemplateInfoEntity(**args) + pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) + return pipeline_template, 200 + + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def delete(self, template_id: str): + RagPipelineService.delete_customized_pipeline_template(template_id) + return 200 + + +api.add_resource( + PipelineTemplateListApi, + "/rag/pipeline/templates", +) +api.add_resource( + PipelineTemplateDetailApi, + "/rag/pipeline/templates/", +) +api.add_resource( + CustomizedPipelineTemplateApi, + "/rag/pipeline/templates/", +) diff --git a/api/controllers/console/datasets/pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py similarity index 86% rename from api/controllers/console/datasets/pipeline.py rename to api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 72e819fa12..d33531b447 100644 --- a/api/controllers/console/datasets/pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -15,11 +15,9 @@ from controllers.console.app.error import ( DraftWorkflowNotExist, DraftWorkflowNotSync, ) -from controllers.console.app.wraps import get_app_model from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( account_initialization_required, - enterprise_license_required, setup_required, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError @@ -32,96 +30,17 @@ from fields.workflow_run_fields import workflow_run_node_execution_fields from libs import helper from libs.helper import TimestampField from libs.login import current_user, login_required -from models import App from models.account import Account from models.dataset import 
Pipeline -from models.model import AppMode from services.app_generate_service import AppGenerateService -from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.rag_pipeline import RagPipelineService -from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService +from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError logger = logging.getLogger(__name__) -def _validate_name(name): - if not name or len(name) < 1 or len(name) > 40: - raise ValueError("Name must be between 1 to 40 characters.") - return name - - -def _validate_description_length(description): - if len(description) > 400: - raise ValueError("Description cannot exceed 400 characters.") - return description - - -class PipelineTemplateListApi(Resource): - @setup_required - @login_required - @account_initialization_required - @enterprise_license_required - def get(self): - type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"]) - language = request.args.get("language", default="en-US", type=str) - # get pipeline templates - pipeline_templates = RagPipelineService.get_pipeline_templates(type, language) - return pipeline_templates, 200 - - -class PipelineTemplateDetailApi(Resource): - @setup_required - @login_required - @account_initialization_required - @enterprise_license_required - def get(self, pipeline_id: str): - pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id) - return pipeline_template, 200 - - -class CustomizedPipelineTemplateApi(Resource): - @setup_required - @login_required - @account_initialization_required - @enterprise_license_required - def patch(self, template_id: str): - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 40 characters.", - type=_validate_name, - ) - parser.add_argument( - "description", - type=str, - nullable=True, - required=False, - default="", - ) - parser.add_argument( - "icon_info", - type=dict, - location="json", - nullable=True, - ) - args = parser.parse_args() - pipeline_template_info = PipelineTemplateInfoEntity(**args) - pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) - return pipeline_template, 200 - - @setup_required - @login_required - @account_initialization_required - @enterprise_license_required - def delete(self, template_id: str): - RagPipelineService.delete_customized_pipeline_template(template_id) - return 200 - - class DraftRagPipelineApi(Resource): @setup_required @login_required @@ -130,7 +49,7 @@ class DraftRagPipelineApi(Resource): @marshal_with(workflow_fields) def get(self, pipeline: Pipeline): """ - Get draft workflow + Get draft rag pipeline's workflow """ # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: @@ -167,6 +86,7 @@ class DraftRagPipelineApi(Resource): parser.add_argument("hash", type=str, required=False, location="json") parser.add_argument("environment_variables", type=list, required=False, location="json") parser.add_argument("conversation_variables", type=list, required=False, location="json") + parser.add_argument("pipeline_variables", type=dict, required=False, location="json") args = parser.parse_args() elif "text/plain" in 
content_type: try: @@ -183,6 +103,7 @@ class DraftRagPipelineApi(Resource): "hash": data.get("hash"), "environment_variables": data.get("environment_variables"), "conversation_variables": data.get("conversation_variables"), + "pipeline_variables": data.get("pipeline_variables"), } except json.JSONDecodeError: return {"message": "Invalid JSON data"}, 400 @@ -192,8 +113,6 @@ class DraftRagPipelineApi(Resource): if not isinstance(current_user, Account): raise Forbidden() - workflow_service = WorkflowService() - try: environment_variables_list = args.get("environment_variables") or [] environment_variables = [ @@ -203,6 +122,11 @@ class DraftRagPipelineApi(Resource): conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] + pipeline_variables_list = args.get("pipeline_variables") or {} + pipeline_variables = { + k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v] + for k, v in pipeline_variables_list.items() + } rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, @@ -212,6 +136,7 @@ class DraftRagPipelineApi(Resource): account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, + pipeline_variables=pipeline_variables, ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() @@ -263,8 +188,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.WORKFLOW]) - def post(self, app_model: App, node_id: str): + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): """ Run draft workflow loop node """ @@ -281,7 +206,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource): try: response = AppGenerateService.generate_single_loop( - app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True + pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True ) return helper.compact_generate_response(response) @@ -300,8 +225,8 @@ class DraftRagPipelineRunApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.WORKFLOW]) - def post(self, app_model: App): + @get_rag_pipeline + def post(self, pipeline: Pipeline): """ Run draft workflow """ @@ -319,7 +244,7 @@ class DraftRagPipelineRunApi(Resource): try: response = AppGenerateService.generate( - app_model=app_model, + pipeline=pipeline, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, @@ -330,32 +255,45 @@ class DraftRagPipelineRunApi(Resource): except InvokeRateLimitError as ex: raise InvokeRateLimitHttpError(ex.description) - -class RagPipelineTaskStopApi(Resource): +class RagPipelineDatasourceNodeRunApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - def post(self, app_model: App, task_id: str): + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): """ - Stop workflow task + Run rag pipeline datasource """ # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) + if not isinstance(current_user, Account): + raise Forbidden() - return {"result": "success"} + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, 
location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + if inputs == None: + raise ValueError("missing inputs") + + rag_pipeline_service = RagPipelineService() + workflow_node_execution = rag_pipeline_service.run_datasource_workflow_node( + pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user + ) + + return workflow_node_execution -class RagPipelineNodeRunApi(Resource): +class RagPipelineDraftNodeRunApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @get_rag_pipeline @marshal_with(workflow_run_node_execution_fields) - def post(self, app_model: App, node_id: str): + def post(self, pipeline: Pipeline, node_id: str): """ Run draft workflow node """ @@ -374,13 +312,29 @@ class RagPipelineNodeRunApi(Resource): if inputs == None: raise ValueError("missing inputs") - workflow_service = WorkflowService() - workflow_node_execution = workflow_service.run_draft_workflow_node( - app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user + rag_pipeline_service = RagPipelineService() + workflow_node_execution = rag_pipeline_service.run_draft_workflow_node( + pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user ) return workflow_node_execution +class RagPipelineTaskStopApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, task_id: str): + """ + Stop workflow task + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) + + return {"result": "success"} class PublishedRagPipelineApi(Resource): @setup_required @@ -695,6 +649,25 @@ class RagPipelineByIdApi(Resource): return None, 204 +class RagPipelineSecondStepApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline): + """ + Get second step parameters of rag pipeline + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + datasource_provider = request.args.get("datasource_provider", required=True, type=str) + + rag_pipeline_service = RagPipelineService() + return rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, + datasource_provider=datasource_provider + ) + api.add_resource( DraftRagPipelineApi, @@ -713,9 +686,13 @@ api.add_resource( "/rag/pipelines//workflow-runs/tasks//stop", ) api.add_resource( - RagPipelineNodeRunApi, + RagPipelineDraftNodeRunApi, "/rag/pipelines//workflows/draft/nodes//run", ) +api.add_resource( + RagPipelinePublishedNodeRunApi, + "/rag/pipelines//workflows/published/nodes//run", +) api.add_resource( RagPipelineDraftRunIterationNodeApi, @@ -751,15 +728,3 @@ api.add_resource( "/rag/pipelines//workflows/", ) -api.add_resource( - PipelineTemplateListApi, - "/rag/pipeline/templates", -) -api.add_resource( - PipelineTemplateDetailApi, - "/rag/pipeline/templates/", -) -api.add_resource( - CustomizedPipelineTemplateApi, - "/rag/pipeline/templates/", -) diff --git a/api/core/datasource/__base/tool.py b/api/core/datasource/__base/tool.py new file mode 100644 index 0000000000..35e16b5c8f --- /dev/null +++ b/api/core/datasource/__base/tool.py @@ -0,0 +1,222 @@ +from abc import ABC, abstractmethod +from collections.abc import 
Generator +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from models.model import File + +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) + + +class Tool(ABC): + """ + The base class of a tool + """ + + entity: ToolEntity + runtime: ToolRuntime + + def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None: + self.entity = entity + self.runtime = runtime + + def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool": + """ + fork a new tool with metadata + :return: the new tool + """ + return self.__class__( + entity=self.entity.model_copy(), + runtime=runtime, + ) + + @abstractmethod + def tool_provider_type(self) -> ToolProviderType: + """ + get the tool provider type + + :return: the tool provider type + """ + + def invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage]: + if self.runtime and self.runtime.runtime_parameters: + tool_parameters.update(self.runtime.runtime_parameters) + + # try parse tool parameters into the correct type + tool_parameters = self._transform_tool_parameters_type(tool_parameters) + + result = self._invoke( + user_id=user_id, + tool_parameters=tool_parameters, + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + ) + + if isinstance(result, ToolInvokeMessage): + + def single_generator() -> Generator[ToolInvokeMessage, None, None]: + yield result + + return single_generator() + elif isinstance(result, list): + + def generator() -> Generator[ToolInvokeMessage, None, None]: + yield from result + + return generator() + else: + return result + + def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: + """ + Transform tool parameters type + """ + # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials + result = deepcopy(tool_parameters) + for parameter in self.entity.parameters or []: + if parameter.name in tool_parameters: + result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) + + return result + + @abstractmethod + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]: + pass + + def get_runtime_parameters( + self, + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> list[ToolParameter]: + """ + get the runtime parameters + + interface for developer to dynamic change the parameters of a tool depends on the variables pool + + :return: the runtime parameters + """ + return self.entity.parameters + + def get_merged_runtime_parameters( + self, + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> list[ToolParameter]: + """ + get merged runtime parameters + + :return: merged runtime parameters + """ + parameters = self.entity.parameters + parameters = parameters.copy() + user_parameters = self.get_runtime_parameters() or [] + user_parameters = user_parameters.copy() + + # override parameters + for parameter in user_parameters: + # check 
if parameter in tool parameters + for tool_parameter in parameters: + if tool_parameter.name == parameter.name: + # override parameter + tool_parameter.type = parameter.type + tool_parameter.form = parameter.form + tool_parameter.required = parameter.required + tool_parameter.default = parameter.default + tool_parameter.options = parameter.options + tool_parameter.llm_description = parameter.llm_description + break + else: + # add new parameter + parameters.append(parameter) + + return parameters + + def create_image_message( + self, + image: str, + ) -> ToolInvokeMessage: + """ + create an image message + + :param image: the url of the image + :return: the image message + """ + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image) + ) + + def create_file_message(self, file: "File") -> ToolInvokeMessage: + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.FILE, + message=ToolInvokeMessage.FileMessage(), + meta={"file": file}, + ) + + def create_link_message(self, link: str) -> ToolInvokeMessage: + """ + create a link message + + :param link: the url of the link + :return: the link message + """ + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link) + ) + + def create_text_message(self, text: str) -> ToolInvokeMessage: + """ + create a text message + + :param text: the text + :return: the text message + """ + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.TEXT, + message=ToolInvokeMessage.TextMessage(text=text), + ) + + def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage: + """ + create a blob message + + :param blob: the blob + :param meta: the meta info of blob object + :return: the blob message + """ + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BLOB, + message=ToolInvokeMessage.BlobMessage(blob=blob), + meta=meta, + ) + + def create_json_message(self, object: dict) -> ToolInvokeMessage: + """ + create a json message + """ + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object) + ) diff --git a/api/core/datasource/__base/tool_provider.py b/api/core/datasource/__base/tool_provider.py new file mode 100644 index 0000000000..d096fc7df7 --- /dev/null +++ b/api/core/datasource/__base/tool_provider.py @@ -0,0 +1,109 @@ +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import Any + +from core.entities.provider_entities import ProviderConfig +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ( + ToolProviderEntity, + ToolProviderType, +) +from core.tools.errors import ToolProviderCredentialValidationError + + +class ToolProviderController(ABC): + entity: ToolProviderEntity + + def __init__(self, entity: ToolProviderEntity) -> None: + self.entity = entity + + def get_credentials_schema(self) -> list[ProviderConfig]: + """ + returns the credentials schema of the provider + + :return: the credentials schema + """ + return deepcopy(self.entity.credentials_schema) + + @abstractmethod + def get_tool(self, tool_name: str) -> Tool: + """ + returns a tool that the provider can provide + + :return: tool + """ + pass + + @property + def provider_type(self) -> ToolProviderType: + """ + returns the type of the provider + + :return: type of the provider + """ + return ToolProviderType.BUILT_IN + + def validate_credentials_format(self, credentials: dict[str, 
Any]) -> None: + """ + validate the format of the credentials of the provider and set the default value if needed + + :param credentials: the credentials of the tool + """ + credentials_schema = dict[str, ProviderConfig]() + if credentials_schema is None: + return + + for credential in self.entity.credentials_schema: + credentials_schema[credential.name] = credential + + credentials_need_to_validate: dict[str, ProviderConfig] = {} + for credential_name in credentials_schema: + credentials_need_to_validate[credential_name] = credentials_schema[credential_name] + + for credential_name in credentials: + if credential_name not in credentials_need_to_validate: + raise ToolProviderCredentialValidationError( + f"credential {credential_name} not found in provider {self.entity.identity.name}" + ) + + # check type + credential_schema = credentials_need_to_validate[credential_name] + if not credential_schema.required and credentials[credential_name] is None: + continue + + if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}: + if not isinstance(credentials[credential_name], str): + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + + elif credential_schema.type == ProviderConfig.Type.SELECT: + if not isinstance(credentials[credential_name], str): + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + + options = credential_schema.options + if not isinstance(options, list): + raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list") + + if credentials[credential_name] not in [x.value for x in options]: + raise ToolProviderCredentialValidationError( + f"credential {credential_name} should be one of {options}" + ) + + credentials_need_to_validate.pop(credential_name) + + for credential_name in credentials_need_to_validate: + credential_schema = credentials_need_to_validate[credential_name] + if credential_schema.required: + raise ToolProviderCredentialValidationError(f"credential {credential_name} is required") + + # the credential is not set currently, set the default value if needed + if credential_schema.default is not None: + default_value = credential_schema.default + # parse default value into the correct type + if credential_schema.type in { + ProviderConfig.Type.SECRET_INPUT, + ProviderConfig.Type.TEXT_INPUT, + ProviderConfig.Type.SELECT, + }: + default_value = str(default_value) + + credentials[credential_name] = default_value diff --git a/api/core/datasource/__base/tool_runtime.py b/api/core/datasource/__base/tool_runtime.py new file mode 100644 index 0000000000..c9e157cb77 --- /dev/null +++ b/api/core/datasource/__base/tool_runtime.py @@ -0,0 +1,36 @@ +from typing import Any, Optional + +from openai import BaseModel +from pydantic import Field + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.entities.tool_entities import ToolInvokeFrom + + +class ToolRuntime(BaseModel): + """ + Meta data of a tool call processing + """ + + tenant_id: str + tool_id: Optional[str] = None + invoke_from: Optional[InvokeFrom] = None + tool_invoke_from: Optional[ToolInvokeFrom] = None + credentials: dict[str, Any] = Field(default_factory=dict) + runtime_parameters: dict[str, Any] = Field(default_factory=dict) + + +class FakeToolRuntime(ToolRuntime): + """ + Fake tool runtime for testing + """ + + def __init__(self): + super().__init__( + tenant_id="fake_tenant_id", + tool_id="fake_tool_id", + 
invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.AGENT, + credentials={}, + runtime_parameters={}, + ) diff --git a/api/core/datasource/__init__.py b/api/core/datasource/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/datasource/entities/agent_entities.py b/api/core/datasource/entities/agent_entities.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py new file mode 100644 index 0000000000..b96c994cff --- /dev/null +++ b/api/core/datasource/entities/api_entities.py @@ -0,0 +1,72 @@ +from typing import Literal, Optional + +from pydantic import BaseModel, Field, field_validator + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.__base.tool import ToolParameter +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolProviderType + + +class ToolApiEntity(BaseModel): + author: str + name: str # identifier + label: I18nObject # label + description: I18nObject + parameters: Optional[list[ToolParameter]] = None + labels: list[str] = Field(default_factory=list) + output_schema: Optional[dict] = None + + +ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] + + +class ToolProviderApiEntity(BaseModel): + id: str + author: str + name: str # identifier + description: I18nObject + icon: str | dict + label: I18nObject # label + type: ToolProviderType + masked_credentials: Optional[dict] = None + original_credentials: Optional[dict] = None + is_team_authorization: bool = False + allow_delete: bool = True + plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool") + plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") + tools: list[ToolApiEntity] = Field(default_factory=list) + labels: list[str] = Field(default_factory=list) + + @field_validator("tools", mode="before") + @classmethod + def convert_none_to_empty_list(cls, v): + return v if v is not None else [] + + def to_dict(self) -> dict: + # ------------- + # overwrite tool parameter types for temp fix + tools = jsonable_encoder(self.tools) + for tool in tools: + if tool.get("parameters"): + for parameter in tool.get("parameters"): + if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: + parameter["type"] = "files" + # ------------- + + return { + "id": self.id, + "author": self.author, + "name": self.name, + "plugin_id": self.plugin_id, + "plugin_unique_identifier": self.plugin_unique_identifier, + "description": self.description.to_dict(), + "icon": self.icon, + "label": self.label.to_dict(), + "type": self.type.value, + "team_credentials": self.masked_credentials, + "is_team_authorization": self.is_team_authorization, + "allow_delete": self.allow_delete, + "tools": tools, + "labels": self.labels, + } diff --git a/api/core/datasource/entities/common_entities.py b/api/core/datasource/entities/common_entities.py new file mode 100644 index 0000000000..924e6fc0cf --- /dev/null +++ b/api/core/datasource/entities/common_entities.py @@ -0,0 +1,23 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class I18nObject(BaseModel): + """ + Model class for i18n object. 
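+    Missing locale fields fall back to en_US (see __init__ below).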
+ """ + + en_US: str + zh_Hans: Optional[str] = Field(default=None) + pt_BR: Optional[str] = Field(default=None) + ja_JP: Optional[str] = Field(default=None) + + def __init__(self, **data): + super().__init__(**data) + self.zh_Hans = self.zh_Hans or self.en_US + self.pt_BR = self.pt_BR or self.en_US + self.ja_JP = self.ja_JP or self.en_US + + def to_dict(self) -> dict: + return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} diff --git a/api/core/datasource/entities/constants.py b/api/core/datasource/entities/constants.py new file mode 100644 index 0000000000..199c9f0d53 --- /dev/null +++ b/api/core/datasource/entities/constants.py @@ -0,0 +1 @@ +TOOL_SELECTOR_MODEL_IDENTITY = "__dify__tool_selector__" diff --git a/api/core/datasource/entities/file_entities.py b/api/core/datasource/entities/file_entities.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/core/datasource/entities/file_entities.py @@ -0,0 +1 @@ + diff --git a/api/core/datasource/entities/tool_bundle.py b/api/core/datasource/entities/tool_bundle.py new file mode 100644 index 0000000000..ffeeabbc1c --- /dev/null +++ b/api/core/datasource/entities/tool_bundle.py @@ -0,0 +1,29 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.tools.entities.tool_entities import ToolParameter + + +class ApiToolBundle(BaseModel): + """ + This class is used to store the schema information of an api based tool. + such as the url, the method, the parameters, etc. + """ + + # server_url + server_url: str + # method + method: str + # summary + summary: Optional[str] = None + # operation_id + operation_id: Optional[str] = None + # parameters + parameters: Optional[list[ToolParameter]] = None + # author + author: str + # icon + icon: Optional[str] = None + # openapi operation + openapi: dict diff --git a/api/core/datasource/entities/tool_entities.py b/api/core/datasource/entities/tool_entities.py new file mode 100644 index 0000000000..d756763137 --- /dev/null +++ b/api/core/datasource/entities/tool_entities.py @@ -0,0 +1,427 @@ +import base64 +import enum +from collections.abc import Mapping +from enum import Enum +from typing import Any, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator + +from core.entities.provider_entities import ProviderConfig +from core.plugin.entities.parameters import ( + PluginParameter, + PluginParameterOption, + PluginParameterType, + as_normal_type, + cast_parameter_value, + init_frontend_parameter, +) +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY + + +class ToolLabelEnum(Enum): + SEARCH = "search" + IMAGE = "image" + VIDEOS = "videos" + WEATHER = "weather" + FINANCE = "finance" + DESIGN = "design" + TRAVEL = "travel" + SOCIAL = "social" + NEWS = "news" + MEDICAL = "medical" + PRODUCTIVITY = "productivity" + EDUCATION = "education" + BUSINESS = "business" + ENTERTAINMENT = "entertainment" + UTILITIES = "utilities" + OTHER = "other" + + +class ToolProviderType(enum.StrEnum): + """ + Enum class for tool provider + """ + + PLUGIN = "plugin" + BUILT_IN = "builtin" + WORKFLOW = "workflow" + API = "api" + APP = "app" + DATASET_RETRIEVAL = "dataset-retrieval" + + @classmethod + def value_of(cls, value: str) -> "ToolProviderType": + """ + Get value of given mode. 
+ + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f"invalid mode value {value}") + + +class ApiProviderSchemaType(Enum): + """ + Enum class for api provider schema type. + """ + + OPENAPI = "openapi" + SWAGGER = "swagger" + OPENAI_PLUGIN = "openai_plugin" + OPENAI_ACTIONS = "openai_actions" + + @classmethod + def value_of(cls, value: str) -> "ApiProviderSchemaType": + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f"invalid mode value {value}") + + +class ApiProviderAuthType(Enum): + """ + Enum class for api provider auth type. + """ + + NONE = "none" + API_KEY = "api_key" + + @classmethod + def value_of(cls, value: str) -> "ApiProviderAuthType": + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f"invalid mode value {value}") + + +class ToolInvokeMessage(BaseModel): + class TextMessage(BaseModel): + text: str + + class JsonMessage(BaseModel): + json_object: dict + + class BlobMessage(BaseModel): + blob: bytes + + class FileMessage(BaseModel): + pass + + class VariableMessage(BaseModel): + variable_name: str = Field(..., description="The name of the variable") + variable_value: Any = Field(..., description="The value of the variable") + stream: bool = Field(default=False, description="Whether the variable is streamed") + + @model_validator(mode="before") + @classmethod + def transform_variable_value(cls, values) -> Any: + """ + Only basic types and lists are allowed. + """ + value = values.get("variable_value") + if not isinstance(value, dict | list | str | int | float | bool): + raise ValueError("Only basic types and lists are allowed.") + + # if stream is true, the value must be a string + if values.get("stream"): + if not isinstance(value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + + return values + + @field_validator("variable_name", mode="before") + @classmethod + def transform_variable_name(cls, value: str) -> str: + """ + The variable name must be a string. 
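+            Reserved output names ("json", "text", "files") are rejected.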
+ """ + if value in {"json", "text", "files"}: + raise ValueError(f"The variable name '{value}' is reserved.") + return value + + class LogMessage(BaseModel): + class LogStatus(Enum): + START = "start" + ERROR = "error" + SUCCESS = "success" + + id: str + label: str = Field(..., description="The label of the log") + parent_id: Optional[str] = Field(default=None, description="Leave empty for root log") + error: Optional[str] = Field(default=None, description="The error message") + status: LogStatus = Field(..., description="The status of the log") + data: Mapping[str, Any] = Field(..., description="Detailed log data") + metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") + + class MessageType(Enum): + TEXT = "text" + IMAGE = "image" + LINK = "link" + BLOB = "blob" + JSON = "json" + IMAGE_LINK = "image_link" + BINARY_LINK = "binary_link" + VARIABLE = "variable" + FILE = "file" + LOG = "log" + + type: MessageType = MessageType.TEXT + """ + plain text, image url or link url + """ + message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage + meta: dict[str, Any] | None = None + + @field_validator("message", mode="before") + @classmethod + def decode_blob_message(cls, v): + if isinstance(v, dict) and "blob" in v: + try: + v["blob"] = base64.b64decode(v["blob"]) + except Exception: + pass + return v + + @field_serializer("message") + def serialize_message(self, v): + if isinstance(v, self.BlobMessage): + return {"blob": base64.b64encode(v.blob).decode("utf-8")} + return v + + +class ToolInvokeMessageBinary(BaseModel): + mimetype: str = Field(..., description="The mimetype of the binary") + url: str = Field(..., description="The url of the binary") + file_var: Optional[dict[str, Any]] = None + + +class ToolParameter(PluginParameter): + """ + Overrides type + """ + + class ToolParameterType(enum.StrEnum): + """ + removes TOOLS_SELECTOR from PluginParameterType + """ + + STRING = PluginParameterType.STRING.value + NUMBER = PluginParameterType.NUMBER.value + BOOLEAN = PluginParameterType.BOOLEAN.value + SELECT = PluginParameterType.SELECT.value + SECRET_INPUT = PluginParameterType.SECRET_INPUT.value + FILE = PluginParameterType.FILE.value + FILES = PluginParameterType.FILES.value + APP_SELECTOR = PluginParameterType.APP_SELECTOR.value + MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value + + # deprecated, should not use. 
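+        # superseded by the FILE / FILES parameter types above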
+ SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value + + def as_normal_type(self): + return as_normal_type(self) + + def cast_value(self, value: Any): + return cast_parameter_value(self, value) + + class ToolParameterForm(Enum): + SCHEMA = "schema" # should be set while adding tool + FORM = "form" # should be set before invoking tool + LLM = "llm" # will be set by LLM + + type: ToolParameterType = Field(..., description="The type of the parameter") + human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") + form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") + llm_description: Optional[str] = None + + @classmethod + def get_simple_instance( + cls, + name: str, + llm_description: str, + typ: ToolParameterType, + required: bool, + options: Optional[list[str]] = None, + ) -> "ToolParameter": + """ + get a simple tool parameter + + :param name: the name of the parameter + :param llm_description: the description presented to the LLM + :param typ: the type of the parameter + :param required: if the parameter is required + :param options: the options of the parameter + """ + # convert options to ToolParameterOption + # FIXME fix the type error + if options: + option_objs = [ + PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + for option in options + ] + else: + option_objs = [] + + return cls( + name=name, + label=I18nObject(en_US="", zh_Hans=""), + placeholder=None, + human_description=I18nObject(en_US="", zh_Hans=""), + type=typ, + form=cls.ToolParameterForm.LLM, + llm_description=llm_description, + required=required, + options=option_objs, + ) + + def init_frontend_parameter(self, value: Any): + return init_frontend_parameter(self, self.type, value) + + +class ToolProviderIdentity(BaseModel): + author: str = Field(..., description="The author of the tool") + name: str = Field(..., description="The name of the tool") + description: I18nObject = Field(..., description="The description of the tool") + icon: str = Field(..., description="The icon of the tool") + label: I18nObject = Field(..., description="The label of the tool") + tags: Optional[list[ToolLabelEnum]] = Field( + default=[], + description="The tags of the tool", + ) + + +class ToolIdentity(BaseModel): + author: str = Field(..., description="The author of the tool") + name: str = Field(..., description="The name of the tool") + label: I18nObject = Field(..., description="The label of the tool") + provider: str = Field(..., description="The provider of the tool") + icon: Optional[str] = None + + +class ToolDescription(BaseModel): + human: I18nObject = Field(..., description="The description presented to the user") + llm: str = Field(..., description="The description presented to the LLM") + + +class ToolEntity(BaseModel): + identity: ToolIdentity + parameters: list[ToolParameter] = Field(default_factory=list) + description: Optional[ToolDescription] = None + output_schema: Optional[dict] = None + has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) + + @field_validator("parameters", mode="before") + @classmethod + def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: + return v or [] + + +class ToolProviderEntity(BaseModel): + identity: ToolProviderIdentity + plugin_id: Optional[str] = None + credentials_schema: list[ProviderConfig] = 
Field(default_factory=list) + + +class ToolProviderEntityWithPlugin(ToolProviderEntity): + tools: list[ToolEntity] = Field(default_factory=list) + + +class WorkflowToolParameterConfiguration(BaseModel): + """ + Workflow tool configuration + """ + + name: str = Field(..., description="The name of the parameter") + description: str = Field(..., description="The description of the parameter") + form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter") + + +class ToolInvokeMeta(BaseModel): + """ + Tool invoke meta + """ + + time_cost: float = Field(..., description="The time cost of the tool invoke") + error: Optional[str] = None + tool_config: Optional[dict] = None + + @classmethod + def empty(cls) -> "ToolInvokeMeta": + """ + Get an empty instance of ToolInvokeMeta + """ + return cls(time_cost=0.0, error=None, tool_config={}) + + @classmethod + def error_instance(cls, error: str) -> "ToolInvokeMeta": + """ + Get an instance of ToolInvokeMeta with error + """ + return cls(time_cost=0.0, error=error, tool_config={}) + + def to_dict(self) -> dict: + return { + "time_cost": self.time_cost, + "error": self.error, + "tool_config": self.tool_config, + } + + +class ToolLabel(BaseModel): + """ + Tool label + """ + + name: str = Field(..., description="The name of the tool") + label: I18nObject = Field(..., description="The label of the tool") + icon: str = Field(..., description="The icon of the tool") + + +class ToolInvokeFrom(Enum): + """ + Enum class for tool invoke + """ + + WORKFLOW = "workflow" + AGENT = "agent" + PLUGIN = "plugin" + + +class ToolSelector(BaseModel): + dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY + + class Parameter(BaseModel): + name: str = Field(..., description="The name of the parameter") + type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter") + required: bool = Field(..., description="Whether the parameter is required") + description: str = Field(..., description="The description of the parameter") + default: Optional[Union[int, float, str]] = None + options: Optional[list[PluginParameterOption]] = None + + provider_id: str = Field(..., description="The id of the provider") + tool_name: str = Field(..., description="The name of the tool") + tool_description: str = Field(..., description="The description of the tool") + tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") + tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm") + + def to_plugin_parameter(self) -> dict[str, Any]: + return self.model_dump() diff --git a/api/core/datasource/entities/values.py b/api/core/datasource/entities/values.py new file mode 100644 index 0000000000..f460df7e25 --- /dev/null +++ b/api/core/datasource/entities/values.py @@ -0,0 +1,111 @@ +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum + +ICONS = { + ToolLabelEnum.SEARCH: """ + +""", # noqa: E501 + ToolLabelEnum.IMAGE: """ + +""", # noqa: E501 + ToolLabelEnum.VIDEOS: """ + +""", # noqa: E501 + ToolLabelEnum.WEATHER: """ + +""", # noqa: E501 + ToolLabelEnum.FINANCE: """ + +""", # noqa: E501 + ToolLabelEnum.DESIGN: """ + +""", # noqa: E501 + ToolLabelEnum.TRAVEL: """ + +""", # noqa: E501 + ToolLabelEnum.SOCIAL: """ + +""", # noqa: E501 + ToolLabelEnum.NEWS: """ + +""", # noqa: E501 + ToolLabelEnum.MEDICAL: """ + +""", # noqa: E501 + ToolLabelEnum.PRODUCTIVITY: """ + +""", # noqa: E501 + 
ToolLabelEnum.EDUCATION: """ + +""", # noqa: E501 + ToolLabelEnum.BUSINESS: """ + +""", # noqa: E501 + ToolLabelEnum.ENTERTAINMENT: """ + +""", # noqa: E501 + ToolLabelEnum.UTILITIES: """ + +""", # noqa: E501 + ToolLabelEnum.OTHER: """ + +""", # noqa: E501 +} + +default_tool_label_dict = { + ToolLabelEnum.SEARCH: ToolLabel( + name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH] + ), + ToolLabelEnum.IMAGE: ToolLabel( + name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE] + ), + ToolLabelEnum.VIDEOS: ToolLabel( + name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS] + ), + ToolLabelEnum.WEATHER: ToolLabel( + name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER] + ), + ToolLabelEnum.FINANCE: ToolLabel( + name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE] + ), + ToolLabelEnum.DESIGN: ToolLabel( + name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN] + ), + ToolLabelEnum.TRAVEL: ToolLabel( + name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL] + ), + ToolLabelEnum.SOCIAL: ToolLabel( + name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL] + ), + ToolLabelEnum.NEWS: ToolLabel( + name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS] + ), + ToolLabelEnum.MEDICAL: ToolLabel( + name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL] + ), + ToolLabelEnum.PRODUCTIVITY: ToolLabel( + name="productivity", + label=I18nObject(en_US="Productivity", zh_Hans="生产力"), + icon=ICONS[ToolLabelEnum.PRODUCTIVITY], + ), + ToolLabelEnum.EDUCATION: ToolLabel( + name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION] + ), + ToolLabelEnum.BUSINESS: ToolLabel( + name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS] + ), + ToolLabelEnum.ENTERTAINMENT: ToolLabel( + name="entertainment", + label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"), + icon=ICONS[ToolLabelEnum.ENTERTAINMENT], + ), + ToolLabelEnum.UTILITIES: ToolLabel( + name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES] + ), + ToolLabelEnum.OTHER: ToolLabel( + name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER] + ), +} + +default_tool_labels = [v for k, v in default_tool_label_dict.items()] +default_tool_label_name_list = [label.name for label in default_tool_labels] diff --git a/api/core/datasource/errors.py b/api/core/datasource/errors.py new file mode 100644 index 0000000000..c5f9ca4774 --- /dev/null +++ b/api/core/datasource/errors.py @@ -0,0 +1,37 @@ +from core.tools.entities.tool_entities import ToolInvokeMeta + + +class ToolProviderNotFoundError(ValueError): + pass + + +class ToolNotFoundError(ValueError): + pass + + +class ToolParameterValidationError(ValueError): + pass + + +class ToolProviderCredentialValidationError(ValueError): + pass + + +class ToolNotSupportedError(ValueError): + pass + + +class ToolInvokeError(ValueError): + pass + + +class ToolApiSchemaError(ValueError): + pass + + +class ToolEngineInvokeError(Exception): + meta: ToolInvokeMeta + + def __init__(self, meta, **kwargs): + self.meta = meta + super().__init__(**kwargs) diff --git 
a/api/core/datasource/plugin_tool/provider.py b/api/core/datasource/plugin_tool/provider.py new file mode 100644 index 0000000000..3616e426b9 --- /dev/null +++ b/api/core/datasource/plugin_tool/provider.py @@ -0,0 +1,79 @@ +from typing import Any + +from core.plugin.manager.tool import PluginToolManager +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.plugin_tool.tool import PluginTool + + +class PluginToolProviderController(BuiltinToolProviderController): + entity: ToolProviderEntityWithPlugin + tenant_id: str + plugin_id: str + plugin_unique_identifier: str + + def __init__( + self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + ) -> None: + self.entity = entity + self.tenant_id = tenant_id + self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def provider_type(self) -> ToolProviderType: + """ + returns the type of the provider + + :return: type of the provider + """ + return ToolProviderType.PLUGIN + + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + """ + validate the credentials of the provider + """ + manager = PluginToolManager() + if not manager.validate_provider_credentials( + tenant_id=self.tenant_id, + user_id=user_id, + provider=self.entity.identity.name, + credentials=credentials, + ): + raise ToolProviderCredentialValidationError("Invalid credentials") + + def get_tool(self, tool_name: str) -> PluginTool: # type: ignore + """ + return tool with given name + """ + tool_entity = next( + (tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None + ) + + if not tool_entity: + raise ValueError(f"Tool with name {tool_name} not found") + + return PluginTool( + entity=tool_entity, + runtime=ToolRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) + + def get_tools(self) -> list[PluginTool]: # type: ignore + """ + get all tools + """ + return [ + PluginTool( + entity=tool_entity, + runtime=ToolRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) + for tool_entity in self.entity.tools + ] diff --git a/api/core/datasource/plugin_tool/tool.py b/api/core/datasource/plugin_tool/tool.py new file mode 100644 index 0000000000..f31a9a0d3e --- /dev/null +++ b/api/core/datasource/plugin_tool/tool.py @@ -0,0 +1,89 @@ +from collections.abc import Generator +from typing import Any, Optional + +from core.plugin.manager.tool import PluginToolManager +from core.plugin.utils.converter import convert_parameters_to_plugin_format +from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType + + +class PluginTool(Tool): + tenant_id: str + icon: str + plugin_unique_identifier: str + runtime_parameters: Optional[list[ToolParameter]] + + def __init__( + self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = 
tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier + self.runtime_parameters = None + + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.PLUGIN + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage, None, None]: + manager = PluginToolManager() + + tool_parameters = convert_parameters_to_plugin_format(tool_parameters) + + yield from manager.invoke( + tenant_id=self.tenant_id, + user_id=user_id, + tool_provider=self.entity.identity.provider, + tool_name=self.entity.identity.name, + credentials=self.runtime.credentials, + tool_parameters=tool_parameters, + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + ) + + def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool": + return PluginTool( + entity=self.entity, + runtime=runtime, + tenant_id=self.tenant_id, + icon=self.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) + + def get_runtime_parameters( + self, + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> list[ToolParameter]: + """ + get the runtime parameters + """ + if not self.entity.has_runtime_parameters: + return self.entity.parameters + + if self.runtime_parameters is not None: + return self.runtime_parameters + + manager = PluginToolManager() + self.runtime_parameters = manager.get_runtime_parameters( + tenant_id=self.tenant_id, + user_id="", + provider=self.entity.identity.provider, + tool=self.entity.identity.name, + credentials=self.runtime.credentials, + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + ) + + return self.runtime_parameters diff --git a/api/core/datasource/tool_engine.py b/api/core/datasource/tool_engine.py new file mode 100644 index 0000000000..ad0c62537c --- /dev/null +++ b/api/core/datasource/tool_engine.py @@ -0,0 +1,357 @@ +import json +from collections.abc import Generator, Iterable +from copy import deepcopy +from datetime import UTC, datetime +from mimetypes import guess_type +from typing import Any, Optional, Union, cast + +from yarl import URL + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.file import FileType +from core.file.models import FileTransferMethod +from core.ops.ops_trace_manager import TraceQueueManager +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ( + ToolInvokeMessage, + ToolInvokeMessageBinary, + ToolInvokeMeta, + ToolParameter, +) +from core.tools.errors import ( + ToolEngineInvokeError, + ToolInvokeError, + ToolNotFoundError, + ToolNotSupportedError, + ToolParameterValidationError, + ToolProviderCredentialValidationError, + ToolProviderNotFoundError, +) +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.tools.workflow_as_tool.tool import WorkflowTool +from extensions.ext_database import db +from models.enums import CreatedByRole +from models.model import Message, MessageFile + + +class ToolEngine: + """ + Tool runtime engine take care of the tool executions. 
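+    It wraps tool invocation with callback notifications, invocation metadata and error handling.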
+ """ + + @staticmethod + def agent_invoke( + tool: Tool, + tool_parameters: Union[str, dict], + user_id: str, + tenant_id: str, + message: Message, + invoke_from: InvokeFrom, + agent_tool_callback: DifyAgentCallbackHandler, + trace_manager: Optional[TraceQueueManager] = None, + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> tuple[str, list[str], ToolInvokeMeta]: + """ + Agent invokes the tool with the given arguments. + """ + # check if arguments is a string + if isinstance(tool_parameters, str): + # check if this tool has only one parameter + parameters = [ + parameter + for parameter in tool.get_runtime_parameters() + if parameter.form == ToolParameter.ToolParameterForm.LLM + ] + if parameters and len(parameters) == 1: + tool_parameters = {parameters[0].name: tool_parameters} + else: + try: + tool_parameters = json.loads(tool_parameters) + except Exception: + pass + if not isinstance(tool_parameters, dict): + raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") + + try: + # hit the callback handler + agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) + + messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id) + invocation_meta_dict: dict[str, ToolInvokeMeta] = {} + + def message_callback( + invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None] + ): + for message in messages: + if isinstance(message, ToolInvokeMeta): + invocation_meta_dict["meta"] = message + else: + yield message + + messages = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=message_callback(invocation_meta_dict, messages), + user_id=user_id, + tenant_id=tenant_id, + conversation_id=message.conversation_id, + ) + + message_list = list(messages) + + # extract binary data from tool invoke message + binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list) + # create message file + message_files = ToolEngine._create_message_files( + tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id + ) + + plain_text = ToolEngine._convert_tool_response_to_str(message_list) + + meta = invocation_meta_dict["meta"] + + # hit the callback handler + agent_tool_callback.on_tool_end( + tool_name=tool.entity.identity.name, + tool_inputs=tool_parameters, + tool_outputs=plain_text, + message_id=message.id, + trace_manager=trace_manager, + ) + + # transform tool invoke message to get LLM friendly message + return plain_text, message_files, meta + except ToolProviderCredentialValidationError as e: + error_response = "Please check your tool provider credentials" + agent_tool_callback.on_tool_error(e) + except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e: + error_response = f"there is not a tool named {tool.entity.identity.name}" + agent_tool_callback.on_tool_error(e) + except ToolParameterValidationError as e: + error_response = f"tool parameters validation error: {e}, please check your tool parameters" + agent_tool_callback.on_tool_error(e) + except ToolInvokeError as e: + error_response = f"tool invoke error: {e}" + agent_tool_callback.on_tool_error(e) + except ToolEngineInvokeError as e: + meta = e.meta + error_response = f"tool invoke error: {meta.error}" + agent_tool_callback.on_tool_error(e) + return error_response, [], meta + except Exception as e: + error_response = f"unknown error: {e}" + 
agent_tool_callback.on_tool_error(e) + + return error_response, [], ToolInvokeMeta.error_instance(error_response) + + @staticmethod + def x( + tool: Tool, + tool_parameters: dict[str, Any], + user_id: str, + workflow_tool_callback: DifyWorkflowCallbackHandler, + workflow_call_depth: int, + thread_pool_id: Optional[str] = None, + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage, None, None]: + """ + Workflow invokes the tool with the given arguments. + """ + try: + # hit the callback handler + workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) + + if isinstance(tool, WorkflowTool): + tool.workflow_call_depth = workflow_call_depth + 1 + tool.thread_pool_id = thread_pool_id + + if tool.runtime and tool.runtime.runtime_parameters: + tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} + + response = tool.invoke( + user_id=user_id, + tool_parameters=tool_parameters, + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + ) + + # hit the callback handler + response = workflow_tool_callback.on_tool_execution( + tool_name=tool.entity.identity.name, + tool_inputs=tool_parameters, + tool_outputs=response, + ) + + return response + except Exception as e: + workflow_tool_callback.on_tool_error(e) + raise e + + @staticmethod + def _invoke( + tool: Tool, + tool_parameters: dict, + user_id: str, + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]: + """ + Invoke the tool with the given arguments. + """ + started_at = datetime.now(UTC) + meta = ToolInvokeMeta( + time_cost=0.0, + error=None, + tool_config={ + "tool_name": tool.entity.identity.name, + "tool_provider": tool.entity.identity.provider, + "tool_provider_type": tool.tool_provider_type().value, + "tool_parameters": deepcopy(tool.runtime.runtime_parameters), + "tool_icon": tool.entity.identity.icon, + }, + ) + try: + yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id) + except Exception as e: + meta.error = str(e) + raise ToolEngineInvokeError(meta) + finally: + ended_at = datetime.now(UTC) + meta.time_cost = (ended_at - started_at).total_seconds() + yield meta + + @staticmethod + def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: + """ + Handle tool response + """ + result = "" + for response in tool_response: + if response.type == ToolInvokeMessage.MessageType.TEXT: + result += cast(ToolInvokeMessage.TextMessage, response.message).text + elif response.type == ToolInvokeMessage.MessageType.LINK: + result += ( + f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}." + + " please tell user to check it." + ) + elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: + result += ( + "image has been created and sent to user already, " + + "you do not need to create it, just tell the user to check it now." 
+ ) + elif response.type == ToolInvokeMessage.MessageType.JSON: + result = json.dumps( + cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False + ) + else: + result += str(response.message) + + return result + + @staticmethod + def _extract_tool_response_binary_and_text( + tool_response: list[ToolInvokeMessage], + ) -> Generator[ToolInvokeMessageBinary, None, None]: + """ + Extract tool response binary + """ + for response in tool_response: + if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: + mimetype = None + if not response.meta: + raise ValueError("missing meta data") + if response.meta.get("mime_type"): + mimetype = response.meta.get("mime_type") + else: + try: + url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text) + extension = url.suffix + guess_type_result, _ = guess_type(f"a{extension}") + if guess_type_result: + mimetype = guess_type_result + except Exception: + pass + + if not mimetype: + mimetype = "image/jpeg" + + yield ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "image/jpeg"), + url=cast(ToolInvokeMessage.TextMessage, response.message).text, + ) + elif response.type == ToolInvokeMessage.MessageType.BLOB: + if not response.meta: + raise ValueError("missing meta data") + + yield ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "application/octet-stream"), + url=cast(ToolInvokeMessage.TextMessage, response.message).text, + ) + elif response.type == ToolInvokeMessage.MessageType.LINK: + # check if there is a mime type in meta + if response.meta and "mime_type" in response.meta: + yield ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "application/octet-stream") + if response.meta + else "application/octet-stream", + url=cast(ToolInvokeMessage.TextMessage, response.message).text, + ) + + @staticmethod + def _create_message_files( + tool_messages: Iterable[ToolInvokeMessageBinary], + agent_message: Message, + invoke_from: InvokeFrom, + user_id: str, + ) -> list[str]: + """ + Create message file + + :return: message file ids + """ + result = [] + + for message in tool_messages: + if "image" in message.mimetype: + file_type = FileType.IMAGE + elif "video" in message.mimetype: + file_type = FileType.VIDEO + elif "audio" in message.mimetype: + file_type = FileType.AUDIO + elif "text" in message.mimetype or "pdf" in message.mimetype: + file_type = FileType.DOCUMENT + else: + file_type = FileType.CUSTOM + + # extract tool file id from url + tool_file_id = message.url.split("/")[-1].split(".")[0] + message_file = MessageFile( + message_id=agent_message.id, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + belongs_to="assistant", + url=message.url, + upload_file_id=tool_file_id, + created_by_role=( + CreatedByRole.ACCOUNT + if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatedByRole.END_USER + ), + created_by=user_id, + ) + + db.session.add(message_file) + db.session.commit() + db.session.refresh(message_file) + + result.append(message_file.id) + + db.session.close() + + return result diff --git a/api/core/datasource/tool_file_manager.py b/api/core/datasource/tool_file_manager.py new file mode 100644 index 0000000000..7e8d4280d4 --- /dev/null +++ b/api/core/datasource/tool_file_manager.py @@ -0,0 +1,234 @@ +import base64 +import hashlib +import hmac +import logging +import os +import time +from mimetypes import guess_extension, guess_type +from typing import Optional, Union +from uuid import 
uuid4 + +import httpx + +from configs import dify_config +from core.helper import ssrf_proxy +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import MessageFile +from models.tools import ToolFile + +logger = logging.getLogger(__name__) + + +class ToolFileManager: + @staticmethod + def sign_file(tool_file_id: str, extension: str) -> str: + """ + sign file to get a temporary url + """ + base_url = dify_config.FILES_URL + file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + @staticmethod + def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + """ + verify signature + """ + data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + @staticmethod + def create_file_by_raw( + *, + user_id: str, + tenant_id: str, + conversation_id: Optional[str], + file_binary: bytes, + mimetype: str, + filename: Optional[str] = None, + ) -> ToolFile: + extension = guess_extension(mimetype) or ".bin" + unique_name = uuid4().hex + unique_filename = f"{unique_name}{extension}" + # default just as before + present_filename = unique_filename + if filename is not None: + has_extension = len(filename.split(".")) > 1 + # Add extension flexibly + present_filename = filename if has_extension else f"{filename}{extension}" + filepath = f"tools/{tenant_id}/{unique_filename}" + storage.save(filepath, file_binary) + + tool_file = ToolFile( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_key=filepath, + mimetype=mimetype, + name=present_filename, + size=len(file_binary), + ) + + db.session.add(tool_file) + db.session.commit() + db.session.refresh(tool_file) + + return tool_file + + @staticmethod + def create_file_by_url( + user_id: str, + tenant_id: str, + file_url: str, + conversation_id: Optional[str] = None, + ) -> ToolFile: + # try to download image + try: + response = ssrf_proxy.get(file_url) + response.raise_for_status() + blob = response.content + except httpx.TimeoutException: + raise ValueError(f"timeout when downloading file from {file_url}") + + mimetype = ( + guess_type(file_url)[0] + or response.headers.get("Content-Type", "").split(";")[0].strip() + or "application/octet-stream" + ) + extension = guess_extension(mimetype) or ".bin" + unique_name = uuid4().hex + filename = f"{unique_name}{extension}" + filepath = f"tools/{tenant_id}/{filename}" + storage.save(filepath, blob) + + tool_file = ToolFile( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_key=filepath, + mimetype=mimetype, + original_url=file_url, + name=filename, + size=len(blob), + ) + + 
db.session.add(tool_file) + db.session.commit() + + return tool_file + + @staticmethod + def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: + """ + get file binary + + :param id: the id of the file + + :return: the binary of the file, mime type + """ + tool_file: ToolFile | None = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == id, + ) + .first() + ) + + if not tool_file: + return None + + blob = storage.load_once(tool_file.file_key) + + return blob, tool_file.mimetype + + @staticmethod + def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]: + """ + get file binary + + :param id: the id of the file + + :return: the binary of the file, mime type + """ + message_file: MessageFile | None = ( + db.session.query(MessageFile) + .filter( + MessageFile.id == id, + ) + .first() + ) + + # Check if message_file is not None + if message_file is not None: + # get tool file id + if message_file.url is not None: + tool_file_id = message_file.url.split("/")[-1] + # trim extension + tool_file_id = tool_file_id.split(".")[0] + else: + tool_file_id = None + else: + tool_file_id = None + + tool_file: ToolFile | None = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() + ) + + if not tool_file: + return None + + blob = storage.load_once(tool_file.file_key) + + return blob, tool_file.mimetype + + @staticmethod + def get_file_generator_by_tool_file_id(tool_file_id: str): + """ + get file binary + + :param tool_file_id: the id of the tool file + + :return: the binary of the file, mime type + """ + tool_file: ToolFile | None = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() + ) + + if not tool_file: + return None, None + + stream = storage.load_stream(tool_file.file_key) + + return stream, tool_file + + +# init tool_file_parser +from core.file.tool_file_parser import tool_file_manager + +tool_file_manager["manager"] = ToolFileManager diff --git a/api/core/datasource/tool_label_manager.py b/api/core/datasource/tool_label_manager.py new file mode 100644 index 0000000000..4787d7d79c --- /dev/null +++ b/api/core/datasource/tool_label_manager.py @@ -0,0 +1,101 @@ +from core.tools.__base.tool_provider import ToolProviderController +from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.tools.custom_tool.provider import ApiToolProviderController +from core.tools.entities.values import default_tool_label_name_list +from core.tools.workflow_as_tool.provider import WorkflowToolProviderController +from extensions.ext_database import db +from models.tools import ToolLabelBinding + + +class ToolLabelManager: + @classmethod + def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]: + """ + Filter tool labels + """ + tool_labels = [label for label in tool_labels if label in default_tool_label_name_list] + return list(set(tool_labels)) + + @classmethod + def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]): + """ + Update tool labels + """ + labels = cls.filter_tool_labels(labels) + + if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): + provider_id = controller.provider_id + else: + raise ValueError("Unsupported tool type") + + # delete old labels + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() + + # insert new labels + for label in labels: + db.session.add( + ToolLabelBinding( + tool_id=provider_id, + tool_type=controller.provider_type.value, + 
label_name=label, + ) + ) + + db.session.commit() + + @classmethod + def get_tool_labels(cls, controller: ToolProviderController) -> list[str]: + """ + Get tool labels + """ + if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): + provider_id = controller.provider_id + elif isinstance(controller, BuiltinToolProviderController): + return controller.tool_labels + else: + raise ValueError("Unsupported tool type") + + labels = ( + db.session.query(ToolLabelBinding.label_name) + .filter( + ToolLabelBinding.tool_id == provider_id, + ToolLabelBinding.tool_type == controller.provider_type.value, + ) + .all() + ) + + return [label.label_name for label in labels] + + @classmethod + def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: + """ + Get tools labels + + :param tool_providers: list of tool providers + + :return: dict of tool labels + :key: tool id + :value: list of tool labels + """ + if not tool_providers: + return {} + + for controller in tool_providers: + if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): + raise ValueError("Unsupported tool type") + + provider_ids = [] + for controller in tool_providers: + assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) + provider_ids.append(controller.provider_id) + + labels: list[ToolLabelBinding] = ( + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() + ) + + tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} + + for label in labels: + tool_labels[label.tool_id].append(label.label_name) + + return tool_labels diff --git a/api/core/datasource/tool_manager.py b/api/core/datasource/tool_manager.py new file mode 100644 index 0000000000..f2d0b74f7c --- /dev/null +++ b/api/core/datasource/tool_manager.py @@ -0,0 +1,870 @@ +import json +import logging +import mimetypes +from collections.abc import Generator +from os import listdir, path +from threading import Lock +from typing import TYPE_CHECKING, Any, Union, cast + +from yarl import URL + +import contexts +from core.plugin.entities.plugin import ToolProviderID +from core.plugin.manager.tool import PluginToolManager +from core.tools.__base.tool_provider import ToolProviderController +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.plugin_tool.provider import PluginToolProviderController +from core.tools.plugin_tool.tool import PluginTool +from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + +if TYPE_CHECKING: + from core.workflow.nodes.tool.entities import ToolEntity + + +from configs import dify_config +from core.agent.entities import AgentToolEntity +from core.app.entities.app_invoke_entities import InvokeFrom +from core.helper.module_import_helper import load_single_subclass_from_source +from core.helper.position_helper import is_filtered +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.__base.tool import Tool +from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.custom_tool.provider import ApiToolProviderController +from core.tools.custom_tool.tool import ApiTool +from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral +from core.tools.entities.common_entities import I18nObject +from 
core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolInvokeFrom, + ToolParameter, + ToolProviderType, +) +from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError +from core.tools.tool_label_manager import ToolLabelManager +from core.tools.utils.configuration import ( + ProviderConfigEncrypter, + ToolParameterConfigurationManager, +) +from core.tools.workflow_as_tool.tool import WorkflowTool +from extensions.ext_database import db +from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider +from services.tools.tools_transform_service import ToolTransformService + +logger = logging.getLogger(__name__) + + +class ToolManager: + _builtin_provider_lock = Lock() + _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} + _builtin_providers_loaded = False + _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} + + @classmethod + def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: + """ + get the hardcoded provider + """ + if len(cls._hardcoded_providers) == 0: + # init the builtin providers + cls.load_hardcoded_providers_cache() + + return cls._hardcoded_providers[provider] + + @classmethod + def get_builtin_provider( + cls, provider: str, tenant_id: str + ) -> BuiltinToolProviderController | PluginToolProviderController: + """ + get the builtin provider + + :param provider: the name of the provider + :param tenant_id: the id of the tenant + :return: the provider + """ + # split provider to + + if len(cls._hardcoded_providers) == 0: + # init the builtin providers + cls.load_hardcoded_providers_cache() + + if provider not in cls._hardcoded_providers: + # get plugin provider + plugin_provider = cls.get_plugin_provider(provider, tenant_id) + if plugin_provider: + return plugin_provider + + return cls._hardcoded_providers[provider] + + @classmethod + def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController: + """ + get the plugin provider + """ + # check if context is set + try: + contexts.plugin_tool_providers.get() + except LookupError: + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(Lock()) + + with contexts.plugin_tool_providers_lock.get(): + plugin_tool_providers = contexts.plugin_tool_providers.get() + if provider in plugin_tool_providers: + return plugin_tool_providers[provider] + + manager = PluginToolManager() + provider_entity = manager.fetch_tool_provider(tenant_id, provider) + if not provider_entity: + raise ToolProviderNotFoundError(f"plugin provider {provider} not found") + + controller = PluginToolProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + + plugin_tool_providers[provider] = controller + + return controller + + @classmethod + def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: + """ + get the builtin tool + + :param provider: the name of the provider + :param tool_name: the name of the tool + :param tenant_id: the id of the tenant + :return: the provider, the tool + """ + provider_controller = cls.get_builtin_provider(provider, tenant_id) + tool = provider_controller.get_tool(tool_name) + if tool is None: + raise ToolNotFoundError(f"tool {tool_name} not found") + + return tool + + @classmethod + def get_tool_runtime( + cls, + provider_type: ToolProviderType, + provider_id: str, + tool_name: str, + 
tenant_id: str, + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, + ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]: + """ + get the tool runtime + + :param provider_type: the type of the provider + :param provider_id: the id of the provider + :param tool_name: the name of the tool + :param tenant_id: the tenant id + :param invoke_from: invoke from + :param tool_invoke_from: the tool invoke from + + :return: the tool + """ + if provider_type == ToolProviderType.BUILT_IN: + # check if the builtin tool need credentials + provider_controller = cls.get_builtin_provider(provider_id, tenant_id) + + builtin_tool = provider_controller.get_tool(tool_name) + if not builtin_tool: + raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") + + if not provider_controller.need_credentials: + return cast( + BuiltinTool, + builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ), + ) + + if isinstance(provider_controller, PluginToolProviderController): + provider_id_entity = ToolProviderID(provider_id) + # get credentials + builtin_provider: BuiltinToolProvider | None = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == str(provider_id_entity)) + | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + ) + .first() + ) + + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + else: + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) + .first() + ) + + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + + # decrypt the credentials + credentials = builtin_provider.credentials + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) + + decrypted_credentials = tool_configuration.decrypt(credentials) + + return cast( + BuiltinTool, + builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials=decrypted_credentials, + runtime_parameters={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ), + ) + + elif provider_type == ToolProviderType.API: + api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) + + # decrypt the credentials + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], + provider_type=api_provider.provider_type.value, + provider_identity=api_provider.entity.identity.name, + ) + decrypted_credentials = tool_configuration.decrypt(credentials) + + return cast( + ApiTool, + api_provider.get_tool(tool_name).fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials=decrypted_credentials, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ), + ) + elif provider_type == ToolProviderType.WORKFLOW: + workflow_provider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == 
provider_id) + .first() + ) + + if workflow_provider is None: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) + controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id) + if controller_tools is None or len(controller_tools) == 0: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + return cast( + WorkflowTool, + controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ), + ) + elif provider_type == ToolProviderType.APP: + raise NotImplementedError("app provider not implemented") + elif provider_type == ToolProviderType.PLUGIN: + return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + else: + raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") + + @classmethod + def get_agent_tool_runtime( + cls, + tenant_id: str, + app_id: str, + agent_tool: AgentToolEntity, + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + ) -> Tool: + """ + get the agent tool runtime + """ + tool_entity = cls.get_tool_runtime( + provider_type=agent_tool.provider_type, + provider_id=agent_tool.provider_id, + tool_name=agent_tool.tool_name, + tenant_id=tenant_id, + invoke_from=invoke_from, + tool_invoke_from=ToolInvokeFrom.AGENT, + ) + runtime_parameters = {} + parameters = tool_entity.get_merged_runtime_parameters() + for parameter in parameters: + # check file types + if ( + parameter.type + in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + } + and parameter.required + ): + raise ValueError(f"file type parameter {parameter.name} not supported in agent") + + if parameter.form == ToolParameter.ToolParameterForm.FORM: + # save tool parameter to tool entity memory + value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name)) + runtime_parameters[parameter.name] = value + + # decrypt runtime parameters + encryption_manager = ToolParameterConfigurationManager( + tenant_id=tenant_id, + tool_runtime=tool_entity, + provider_name=agent_tool.provider_id, + provider_type=agent_tool.provider_type, + identity_id=f"AGENT.{app_id}", + ) + runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None: + raise ValueError("runtime not found or runtime parameters not found") + + tool_entity.runtime.runtime_parameters.update(runtime_parameters) + return tool_entity + + @classmethod + def get_workflow_tool_runtime( + cls, + tenant_id: str, + app_id: str, + node_id: str, + workflow_tool: "ToolEntity", + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + ) -> Tool: + """ + get the workflow tool runtime + """ + tool_runtime = cls.get_tool_runtime( + provider_type=workflow_tool.provider_type, + provider_id=workflow_tool.provider_id, + tool_name=workflow_tool.tool_name, + tenant_id=tenant_id, + invoke_from=invoke_from, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + ) + runtime_parameters = {} + parameters = tool_runtime.get_merged_runtime_parameters() + + for parameter in parameters: + # save tool parameter to tool entity memory + if parameter.form == ToolParameter.ToolParameterForm.FORM: + value = 
parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name)) + runtime_parameters[parameter.name] = value + + # decrypt runtime parameters + encryption_manager = ToolParameterConfigurationManager( + tenant_id=tenant_id, + tool_runtime=tool_runtime, + provider_name=workflow_tool.provider_id, + provider_type=workflow_tool.provider_type, + identity_id=f"WORKFLOW.{app_id}.{node_id}", + ) + + if runtime_parameters: + runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + + tool_runtime.runtime.runtime_parameters.update(runtime_parameters) + return tool_runtime + + @classmethod + def get_tool_runtime_from_plugin( + cls, + tool_type: ToolProviderType, + tenant_id: str, + provider: str, + tool_name: str, + tool_parameters: dict[str, Any], + ) -> Tool: + """ + get tool runtime from plugin + """ + tool_entity = cls.get_tool_runtime( + provider_type=tool_type, + provider_id=provider, + tool_name=tool_name, + tenant_id=tenant_id, + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.PLUGIN, + ) + runtime_parameters = {} + parameters = tool_entity.get_merged_runtime_parameters() + for parameter in parameters: + if parameter.form == ToolParameter.ToolParameterForm.FORM: + # save tool parameter to tool entity memory + value = parameter.init_frontend_parameter(tool_parameters.get(parameter.name)) + runtime_parameters[parameter.name] = value + + tool_entity.runtime.runtime_parameters.update(runtime_parameters) + return tool_entity + + @classmethod + def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]: + """ + get the absolute path of the icon of the hardcoded provider + + :param provider: the name of the provider + :return: the absolute path of the icon, the mime type of the icon + """ + # get provider + provider_controller = cls.get_hardcoded_provider(provider) + + absolute_path = path.join( + path.dirname(path.realpath(__file__)), + "builtin_tool", + "providers", + provider, + "_assets", + provider_controller.entity.identity.icon, + ) + # check if the icon exists + if not path.exists(absolute_path): + raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found") + + # get the mime type + mime_type, _ = mimetypes.guess_type(absolute_path) + mime_type = mime_type or "application/octet-stream" + + return absolute_path, mime_type + + @classmethod + def list_hardcoded_providers(cls): + # use cache first + if cls._builtin_providers_loaded: + yield from list(cls._hardcoded_providers.values()) + return + + with cls._builtin_provider_lock: + if cls._builtin_providers_loaded: + yield from list(cls._hardcoded_providers.values()) + return + + yield from cls._list_hardcoded_providers() + + @classmethod + def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]: + """ + list all the plugin providers + """ + manager = PluginToolManager() + provider_entities = manager.fetch_tool_providers(tenant_id) + return [ + PluginToolProviderController( + entity=provider.declaration, + plugin_id=provider.plugin_id, + plugin_unique_identifier=provider.plugin_unique_identifier, + tenant_id=tenant_id, + ) + for provider in provider_entities + ] + + @classmethod + def list_builtin_providers( + cls, tenant_id: str + ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]: + """ + list all the builtin providers + """ + yield from cls.list_hardcoded_providers() + # get plugin providers + yield from cls.list_plugin_providers(tenant_id) + + @classmethod + def 
_list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: + """ + list all the builtin providers + """ + for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")): + if provider_path.startswith("__"): + continue + + if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)): + if provider_path.startswith("__"): + continue + + # init provider + try: + provider_class = load_single_subclass_from_source( + module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}", + script_path=path.join( + path.dirname(path.realpath(__file__)), + "builtin_tool", + "providers", + provider_path, + f"{provider_path}.py", + ), + parent_type=BuiltinToolProviderController, + ) + provider: BuiltinToolProviderController = provider_class() + cls._hardcoded_providers[provider.entity.identity.name] = provider + for tool in provider.get_tools(): + cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label + yield provider + + except Exception: + logger.exception(f"load builtin provider {provider}") + continue + # set builtin providers loaded + cls._builtin_providers_loaded = True + + @classmethod + def load_hardcoded_providers_cache(cls): + for _ in cls.list_hardcoded_providers(): + pass + + @classmethod + def clear_hardcoded_providers_cache(cls): + cls._hardcoded_providers = {} + cls._builtin_providers_loaded = False + + @classmethod + def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: + """ + get the tool label + + :param tool_name: the name of the tool + + :return: the label of the tool + """ + if len(cls._builtin_tools_labels) == 0: + # init the builtin providers + cls.load_hardcoded_providers_cache() + + if tool_name not in cls._builtin_tools_labels: + return None + + return cls._builtin_tools_labels[tool_name] + + @classmethod + def list_providers_from_api( + cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral + ) -> list[ToolProviderApiEntity]: + result_providers: dict[str, ToolProviderApiEntity] = {} + + filters = [] + if not typ: + filters.extend(["builtin", "api", "workflow"]) + else: + filters.append(typ) + + with db.session.no_autoflush: + if "builtin" in filters: + # get builtin providers + builtin_providers = cls.list_builtin_providers(tenant_id) + + # get db builtin providers + db_builtin_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() + ) + + # rewrite db_builtin_providers + for db_provider in db_builtin_providers: + tool_provider_id = str(ToolProviderID(db_provider.provider)) + db_provider.provider = tool_provider_id + + def find_db_builtin_provider(provider): + return next((x for x in db_builtin_providers if x.provider == provider), None) + + # append builtin providers + for provider in builtin_providers: + # handle include, exclude + if is_filtered( + include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET), + exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET), + data=provider, + name_func=lambda x: x.identity.name, + ): + continue + + user_provider = ToolTransformService.builtin_provider_to_user_provider( + provider_controller=provider, + db_provider=find_db_builtin_provider(provider.entity.identity.name), + decrypt_credentials=False, + ) + + if isinstance(provider, PluginToolProviderController): + result_providers[f"plugin_provider.{user_provider.name}"] = user_provider + else: + 
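+ # hardcoded builtin providers are keyed separately from plugin-backed providers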
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider + + # get db api providers + + if "api" in filters: + db_api_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() + ) + + api_provider_controllers: list[dict[str, Any]] = [ + {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} + for provider in db_api_providers + ] + + # get labels + labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) + + for api_provider_controller in api_provider_controllers: + user_provider = ToolTransformService.api_provider_to_user_provider( + provider_controller=api_provider_controller["controller"], + db_provider=api_provider_controller["provider"], + decrypt_credentials=False, + labels=labels.get(api_provider_controller["controller"].provider_id, []), + ) + result_providers[f"api_provider.{user_provider.name}"] = user_provider + + if "workflow" in filters: + # get workflow providers + workflow_providers: list[WorkflowToolProvider] = ( + db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + ) + + workflow_provider_controllers: list[WorkflowToolProviderController] = [] + for provider in workflow_providers: + try: + workflow_provider_controllers.append( + ToolTransformService.workflow_provider_to_controller(db_provider=provider) + ) + except Exception: + # app has been deleted + pass + + labels = ToolLabelManager.get_tools_labels( + [cast(ToolProviderController, controller) for controller in workflow_provider_controllers] + ) + + for provider_controller in workflow_provider_controllers: + user_provider = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=provider_controller, + labels=labels.get(provider_controller.provider_id, []), + ) + result_providers[f"workflow_provider.{user_provider.name}"] = user_provider + + return BuiltinToolProviderSort.sort(list(result_providers.values())) + + @classmethod + def get_api_provider_controller( + cls, tenant_id: str, provider_id: str + ) -> tuple[ApiToolProviderController, dict[str, Any]]: + """ + get the api provider + + :param tenant_id: the id of the tenant + :param provider_id: the id of the provider + + :return: the provider controller, the credentials + """ + provider: ApiToolProvider | None = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.id == provider_id, + ApiToolProvider.tenant_id == tenant_id, + ) + .first() + ) + + if provider is None: + raise ToolProviderNotFoundError(f"api provider {provider_id} not found") + + controller = ApiToolProviderController.from_db( + provider, + ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, + ) + controller.load_bundled_tools(provider.tools) + + return controller, provider.credentials + + @classmethod + def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: + """ + get api provider + """ + """ + get tool provider + """ + provider_name = provider + provider_obj: ApiToolProvider | None = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ) + .first() + ) + + if provider_obj is None: + raise ValueError(f"you have not added provider {provider_name}") + + try: + credentials = json.loads(provider_obj.credentials_str) or {} + except Exception: + credentials = {} + + # package tool provider controller + controller = 
ApiToolProviderController.from_db( + provider_obj, + ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, + ) + # init tool configuration + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], + provider_type=controller.provider_type.value, + provider_identity=controller.entity.identity.name, + ) + + decrypted_credentials = tool_configuration.decrypt(credentials) + masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + + try: + icon = json.loads(provider_obj.icon) + except Exception: + icon = {"background": "#252525", "content": "\ud83d\ude01"} + + # add tool labels + labels = ToolLabelManager.get_tool_labels(controller) + + return cast( + dict, + jsonable_encoder( + { + "schema_type": provider_obj.schema_type, + "schema": provider_obj.schema, + "tools": provider_obj.tools, + "icon": icon, + "description": provider_obj.description, + "credentials": masked_credentials, + "privacy_policy": provider_obj.privacy_policy, + "custom_disclaimer": provider_obj.custom_disclaimer, + "labels": labels, + } + ), + ) + + @classmethod + def generate_builtin_tool_icon_url(cls, provider_id: str) -> str: + return str( + URL(dify_config.CONSOLE_API_URL or "/") + / "console" + / "api" + / "workspaces" + / "current" + / "tool-provider" + / "builtin" + / provider_id + / "icon" + ) + + @classmethod + def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str: + return str( + URL(dify_config.CONSOLE_API_URL or "/") + / "console" + / "api" + / "workspaces" + / "current" + / "plugin" + / "icon" + % {"tenant_id": tenant_id, "filename": filename} + ) + + @classmethod + def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict: + try: + workflow_provider: WorkflowToolProvider | None = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .first() + ) + + if workflow_provider is None: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + icon: dict = json.loads(workflow_provider.icon) + return icon + except Exception: + return {"background": "#252525", "content": "\ud83d\ude01"} + + @classmethod + def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict: + try: + api_provider: ApiToolProvider | None = ( + db.session.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) + .first() + ) + + if api_provider is None: + raise ToolProviderNotFoundError(f"api provider {provider_id} not found") + + icon: dict = json.loads(api_provider.icon) + return icon + except Exception: + return {"background": "#252525", "content": "\ud83d\ude01"} + + @classmethod + def get_tool_icon( + cls, + tenant_id: str, + provider_type: ToolProviderType, + provider_id: str, + ) -> Union[str, dict]: + """ + get the tool icon + + :param tenant_id: the id of the tenant + :param provider_type: the type of the provider + :param provider_id: the id of the provider + :return: + """ + provider_type = provider_type + provider_id = provider_id + if provider_type == ToolProviderType.BUILT_IN: + provider = ToolManager.get_builtin_provider(provider_id, tenant_id) + if isinstance(provider, PluginToolProviderController): + try: + return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) + except Exception: + return {"background": 
"#252525", "content": "\ud83d\ude01"} + return cls.generate_builtin_tool_icon_url(provider_id) + elif provider_type == ToolProviderType.API: + return cls.generate_api_tool_icon_url(tenant_id, provider_id) + elif provider_type == ToolProviderType.WORKFLOW: + return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) + elif provider_type == ToolProviderType.PLUGIN: + provider = ToolManager.get_builtin_provider(provider_id, tenant_id) + if isinstance(provider, PluginToolProviderController): + try: + return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) + except Exception: + return {"background": "#252525", "content": "\ud83d\ude01"} + raise ValueError(f"plugin provider {provider_id} not found") + else: + raise ValueError(f"provider type {provider_type} not found") + + +ToolManager.load_hardcoded_providers_cache() diff --git a/api/core/datasource/utils/__init__.py b/api/core/datasource/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/datasource/utils/configuration.py b/api/core/datasource/utils/configuration.py new file mode 100644 index 0000000000..6a5fba65bd --- /dev/null +++ b/api/core/datasource/utils/configuration.py @@ -0,0 +1,265 @@ +from copy import deepcopy +from typing import Any + +from pydantic import BaseModel + +from core.entities.provider_entities import BasicProviderConfig +from core.helper import encrypter +from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType +from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ( + ToolParameter, + ToolProviderType, +) + + +class ProviderConfigEncrypter(BaseModel): + tenant_id: str + config: list[BasicProviderConfig] + provider_type: str + provider_identity: str + + def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: + """ + deep copy data + """ + return deepcopy(data) + + def encrypt(self, data: dict[str, str]) -> dict[str, str]: + """ + encrypt tool credentials with tenant id + + return a deep copy of credentials with encrypted values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") + data[field_name] = encrypted + + return data + + def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: + """ + mask tool credentials + + return a deep copy of credentials with masked values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + if len(data[field_name]) > 6: + data[field_name] = ( + data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] + ) + else: + data[field_name] = "*" * len(data[field_name]) + + return data + + def decrypt(self, data: dict[str, str]) -> dict[str, str]: + """ + decrypt tool credentials with tenant id + + return a deep copy of credentials with decrypted values + """ + cache = ToolProviderCredentialsCache( + 
tenant_id=self.tenant_id, + identity_id=f"{self.provider_type}.{self.provider_identity}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, + ) + cached_credentials = cache.get() + if cached_credentials: + return cached_credentials + data = self._deep_copy(data) + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + try: + # if the value is None or empty string, skip decrypt + if not data[field_name]: + continue + + data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) + except Exception: + pass + + cache.set(data) + return data + + def delete_tool_credentials_cache(self): + cache = ToolProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=f"{self.provider_type}.{self.provider_identity}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, + ) + cache.delete() + + +class ToolParameterConfigurationManager: + """ + Tool parameter configuration manager + """ + + tenant_id: str + tool_runtime: Tool + provider_name: str + provider_type: ToolProviderType + identity_id: str + + def __init__( + self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str + ) -> None: + self.tenant_id = tenant_id + self.tool_runtime = tool_runtime + self.provider_name = provider_name + self.provider_type = provider_type + self.identity_id = identity_id + + def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + deep copy parameters + """ + return deepcopy(parameters) + + def _merge_parameters(self) -> list[ToolParameter]: + """ + merge parameters + """ + # get tool parameters + tool_parameters = self.tool_runtime.entity.parameters or [] + # get tool runtime parameters + runtime_parameters = self.tool_runtime.get_runtime_parameters() + # override parameters + current_parameters = tool_parameters.copy() + for runtime_parameter in runtime_parameters: + found = False + for index, parameter in enumerate(current_parameters): + if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: + current_parameters[index] = runtime_parameter + found = True + break + + if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + current_parameters.append(runtime_parameter) + + return current_parameters + + def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + mask tool parameters + + return a deep copy of parameters with masked values + """ + parameters = self._deep_copy(parameters) + + # override parameters + current_parameters = self._merge_parameters() + + for parameter in current_parameters: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): + if parameter.name in parameters: + if len(parameters[parameter.name]) > 6: + parameters[parameter.name] = ( + parameters[parameter.name][:2] + + "*" * (len(parameters[parameter.name]) - 4) + + parameters[parameter.name][-2:] + ) + else: + parameters[parameter.name] = "*" * len(parameters[parameter.name]) + + return parameters + + def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + encrypt tool parameters with tenant id + + return a deep copy of parameters with encrypted values + """ + # override parameters + current_parameters = 
self._merge_parameters() + + parameters = self._deep_copy(parameters) + + for parameter in current_parameters: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): + if parameter.name in parameters: + encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) + parameters[parameter.name] = encrypted + + return parameters + + def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + decrypt tool parameters with tenant id + + return a deep copy of parameters with decrypted values + """ + + cache = ToolParameterCache( + tenant_id=self.tenant_id, + provider=f"{self.provider_type.value}.{self.provider_name}", + tool_name=self.tool_runtime.entity.identity.name, + cache_type=ToolParameterCacheType.PARAMETER, + identity_id=self.identity_id, + ) + cached_parameters = cache.get() + if cached_parameters: + return cached_parameters + + # override parameters + current_parameters = self._merge_parameters() + has_secret_input = False + + for parameter in current_parameters: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): + if parameter.name in parameters: + try: + has_secret_input = True + parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) + except Exception: + pass + + if has_secret_input: + cache.set(parameters) + + return parameters + + def delete_tool_parameters_cache(self): + cache = ToolParameterCache( + tenant_id=self.tenant_id, + provider=f"{self.provider_type.value}.{self.provider_name}", + tool_name=self.tool_runtime.entity.identity.name, + cache_type=ToolParameterCacheType.PARAMETER, + identity_id=self.identity_id, + ) + cache.delete() diff --git a/api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py new file mode 100644 index 0000000000..032274b87e --- /dev/null +++ b/api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -0,0 +1,199 @@ +import threading +from typing import Any + +from flask import Flask, current_app +from pydantic import BaseModel, Field + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.models.document import Document as RagDocument +from core.rag.rerank.rerank_model import RerankModelRunner +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from extensions.ext_database import db +from models.dataset import Dataset, Document, DocumentSegment + +default_retrieval_model: dict[str, Any] = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, +} + + +class DatasetMultiRetrieverToolInput(BaseModel): + query: str = Field(..., description="dataset multi retriever and rerank") + + +class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): + """Tool for querying multi dataset.""" + + name: str = "dataset_" + args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput + 
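+ # from_dataset() below completes the name with the tenant id; the description text is what the calling agent sees for this tool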
description: str = "dataset multi retriever and rerank. " + dataset_ids: list[str] + reranking_provider_name: str + reranking_model_name: str + + @classmethod + def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): + return cls( + name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs + ) + + def _run(self, query: str) -> str: + threads = [] + all_documents: list[RagDocument] = [] + for dataset_id in self.dataset_ids: + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "dataset_id": dataset_id, + "query": query, + "all_documents": all_documents, + "hit_callbacks": self.hit_callbacks, + }, + ) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + # do rerank for searched documents + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + provider=self.reranking_provider_name, + model_type=ModelType.RERANK, + model=self.reranking_model_name, + ) + + rerank_runner = RerankModelRunner(rerank_model_instance) + all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k) + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(all_documents) + + document_score_list = {} + for item in all_documents: + if item.metadata and item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] + + document_context_list = [] + index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id.in_(self.dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") + else: + document_context_list.append(segment.get_sign_content()) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), + "doc_metadata": document.doc_metadata, + } + + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash + if segment.answer: + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + else: + source["content"] = segment.content + context_list.append(source) + 
resource_number += 1 + + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join(document_context_list)) + return "" + + raise RuntimeError("not segments found") + + def _retriever( + self, + flask_app: Flask, + dataset_id: str, + query: str, + all_documents: list, + hit_callbacks: list[DatasetIndexToolCallbackHandler], + ): + with flask_app.app_context(): + dataset = ( + db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() + ) + + if not dataset: + return [] + + for hit_callback in hit_callbacks: + hit_callback.on_query(query, dataset.id) + + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model or default_retrieval_model + + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", + dataset_id=dataset.id, + query=query, + top_k=retrieval_model.get("top_k") or 2, + ) + if documents: + all_documents.extend(documents) + else: + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model["search_method"], + dataset_id=dataset.id, + query=query, + top_k=retrieval_model.get("top_k") or 2, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights", None), + ) + + all_documents.extend(documents) diff --git a/api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py new file mode 100644 index 0000000000..a4d2de3b1c --- /dev/null +++ b/api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py @@ -0,0 +1,33 @@ +from abc import abstractmethod +from typing import Any, Optional + +from msal_extensions.persistence import ABC # type: ignore +from pydantic import BaseModel, ConfigDict + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler + + +class DatasetRetrieverBaseTool(BaseModel, ABC): + """Tool for querying a Dataset.""" + + name: str = "dataset" + description: str = "use this to retrieve a dataset. " + tenant_id: str + top_k: int = 2 + score_threshold: Optional[float] = None + hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + return_resource: bool + retriever_from: str + model_config = ConfigDict(arbitrary_types_allowed=True) + + @abstractmethod + def _run( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Use the tool. 
+ + Add run_manager: Optional[CallbackManagerForToolRun] = None + to child implementations to enable tracing, + """ diff --git a/api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py new file mode 100644 index 0000000000..63260cfac3 --- /dev/null +++ b/api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py @@ -0,0 +1,202 @@ +from typing import Any + +from pydantic import BaseModel, Field + +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.context_entities import DocumentContext +from core.rag.models.document import Document as RetrievalDocument +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from extensions.ext_database import db +from models.dataset import Dataset +from models.dataset import Document as DatasetDocument +from services.external_knowledge_service import ExternalDatasetService + +default_retrieval_model = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "reranking_mode": "reranking_model", + "top_k": 2, + "score_threshold_enabled": False, +} + + +class DatasetRetrieverToolInput(BaseModel): + query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.") + + +class DatasetRetrieverTool(DatasetRetrieverBaseTool): + """Tool for querying a Dataset.""" + + name: str = "dataset" + args_schema: type[BaseModel] = DatasetRetrieverToolInput + description: str = "use this to retrieve a dataset. " + dataset_id: str + + @classmethod + def from_dataset(cls, dataset: Dataset, **kwargs): + description = dataset.description + if not description: + description = "useful for when you want to answer queries about the " + dataset.name + + description = description.replace("\n", "").replace("\r", "") + return cls( + name=f"dataset_{dataset.id.replace('-', '_')}", + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + description=description, + **kwargs, + ) + + def _run(self, query: str) -> str: + dataset = ( + db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() + ) + + if not dataset: + return "" + for hit_callback in self.hit_callbacks: + hit_callback.on_query(query, dataset.id) + if dataset.provider == "external": + results = [] + external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + query=query, + external_retrieval_parameters=dataset.retrieval_model, + ) + for external_document in external_documents: + document = RetrievalDocument( + page_content=external_document.get("content"), + metadata=external_document.get("metadata"), + provider="external", + ) + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset.id + document.metadata["dataset_name"] = dataset.name + results.append(document) + # deal with external documents + context_list = [] + for position, item in enumerate(results, start=1): + if item.metadata is not None: + source = { + "position": position, + "dataset_id": item.metadata.get("dataset_id"), + "dataset_name": item.metadata.get("dataset_name"), + "document_name": 
item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": self.retriever_from, + "score": item.metadata.get("score"), + "title": item.metadata.get("title"), + "content": item.page_content, + } + context_list.append(source) + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join([item.page_content for item in results])) + else: + # get retrieval model , if the model is not setting , using default + retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + ) + return str("\n".join([document.page_content for document in documents])) + else: + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model") + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights"), + ) + else: + documents = [] + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(documents) + document_score_list = {} + if dataset.indexing_technique != "economy": + for item in documents: + if item.metadata is not None and item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] + document_context_list = [] + records = RetrievalService.format_retrieval_documents(documents) + if records: + for record in records: + segment = record.segment + if segment.answer: + document_context_list.append( + DocumentContext( + content=f"question:{segment.get_sign_content()} answer:{segment.answer}", + score=record.score, + ) + ) + else: + document_context_list.append( + DocumentContext( + content=segment.get_sign_content(), + score=record.score, + ) + ) + retrieval_resource_list = [] + if self.return_resource: + for record in records: + segment = record.segment + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = DatasetDocument.query.filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ).first() + if dataset and document: + source = { + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, # type: ignore + "document_name": document.name, # type: ignore + "data_source_type": document.data_source_type, # type: ignore + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": record.score or 0.0, + "doc_metadata": document.doc_metadata, # type: ignore + } + + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash + if segment.answer: + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + else: + source["content"] = segment.content + retrieval_resource_list.append(source) + + if self.return_resource and retrieval_resource_list: + retrieval_resource_list = sorted( + retrieval_resource_list, + 
key=lambda x: x.get("score") or 0.0, + reverse=True, + ) + for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore + item["position"] = position # type: ignore + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(retrieval_resource_list) + if document_context_list: + document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) + return str("\n".join([document_context.content for document_context in document_context_list])) + return "" diff --git a/api/core/datasource/utils/dataset_retriever_tool.py b/api/core/datasource/utils/dataset_retriever_tool.py new file mode 100644 index 0000000000..b73dec4ebc --- /dev/null +++ b/api/core/datasource/utils/dataset_retriever_tool.py @@ -0,0 +1,134 @@ +from collections.abc import Generator +from typing import Any, Optional + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import InvokeFrom +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolDescription, + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) +from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool + + +class DatasetRetrieverTool(Tool): + retrieval_tool: DatasetRetrieverBaseTool + + def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None: + super().__init__(entity, runtime) + self.retrieval_tool = retrieval_tool + + @staticmethod + def get_dataset_tools( + tenant_id: str, + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity | None, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler, + ) -> list["DatasetRetrieverTool"]: + """ + get dataset tool + """ + # check if retrieve_config is valid + if dataset_ids is None or len(dataset_ids) == 0: + return [] + if retrieve_config is None: + return [] + + feature = DatasetRetrieval() + + # save original retrieve strategy, and set retrieve strategy to SINGLE + # Agent only support SINGLE mode + original_retriever_mode = retrieve_config.retrieve_strategy + retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE + retrieval_tools = feature.to_dataset_retriever_tool( + tenant_id=tenant_id, + dataset_ids=dataset_ids, + retrieve_config=retrieve_config, + return_resource=return_resource, + invoke_from=invoke_from, + hit_callback=hit_callback, + ) + if retrieval_tools is None or len(retrieval_tools) == 0: + return [] + + # restore retrieve strategy + retrieve_config.retrieve_strategy = original_retriever_mode + + # convert retrieval tools to Tools + tools = [] + for retrieval_tool in retrieval_tools: + tool = DatasetRetrieverTool( + retrieval_tool=retrieval_tool, + entity=ToolEntity( + identity=ToolIdentity( + provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") + ), + parameters=[], + description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description), + ), + runtime=ToolRuntime(tenant_id=tenant_id), + ) + + tools.append(tool) + + return tools + + def get_runtime_parameters( + 
self, + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> list[ToolParameter]: + return [ + ToolParameter( + name="query", + label=I18nObject(en_US="", zh_Hans=""), + human_description=I18nObject(en_US="", zh_Hans=""), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Query for the dataset to be used to retrieve the dataset.", + required=True, + default="", + placeholder=I18nObject(en_US="", zh_Hans=""), + ), + ] + + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.DATASET_RETRIEVAL + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage, None, None]: + """ + invoke dataset retriever tool + """ + query = tool_parameters.get("query") + if not query: + yield self.create_text_message(text="please input query") + else: + # invoke dataset retriever tool + result = self.retrieval_tool._run(query=query) + yield self.create_text_message(text=result) + + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str | None: + """ + validate the credentials for dataset retriever tool + """ + pass diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py new file mode 100644 index 0000000000..6fd0c201e3 --- /dev/null +++ b/api/core/datasource/utils/message_transformer.py @@ -0,0 +1,121 @@ +import logging +from collections.abc import Generator +from mimetypes import guess_extension +from typing import Optional + +from core.file import File, FileTransferMethod, FileType +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool_file_manager import ToolFileManager + +logger = logging.getLogger(__name__) + + +class ToolFileMessageTransformer: + @classmethod + def transform_tool_invoke_messages( + cls, + messages: Generator[ToolInvokeMessage, None, None], + user_id: str, + tenant_id: str, + conversation_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage, None, None]: + """ + Transform tool message and handle file download + """ + for message in messages: + if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}: + yield message + elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance( + message.message, ToolInvokeMessage.TextMessage + ): + # try to download image + try: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + file = ToolFileManager.create_file_by_url( + user_id=user_id, + tenant_id=tenant_id, + file_url=message.message.text, + conversation_id=conversation_id, + ) + + url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}" + + yield ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=ToolInvokeMessage.TextMessage(text=url), + meta=message.meta.copy() if message.meta is not None else {}, + ) + except Exception as e: + yield ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.TEXT, + message=ToolInvokeMessage.TextMessage( + text=f"Failed to download image: {message.message.text}: {e}" + ), + meta=message.meta.copy() if message.meta is not None else {}, + ) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + # get mime type and save blob to storage + meta = message.meta or {} + + mimetype = 
meta.get("mime_type", "application/octet-stream") + # get filename from meta + filename = meta.get("file_name", None) + # if message is str, encode it to bytes + + if not isinstance(message.message, ToolInvokeMessage.BlobMessage): + raise ValueError("unexpected message type") + + # FIXME: should do a type check here. + assert isinstance(message.message.blob, bytes) + file = ToolFileManager.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_binary=message.message.blob, + mimetype=mimetype, + filename=filename, + ) + + url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype)) + + # check if file is image + if "image" in mimetype: + yield ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=ToolInvokeMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + else: + yield ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BINARY_LINK, + message=ToolInvokeMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + elif message.type == ToolInvokeMessage.MessageType.FILE: + meta = message.meta or {} + file = meta.get("file", None) + if isinstance(file, File): + if file.transfer_method == FileTransferMethod.TOOL_FILE: + assert file.related_id is not None + url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) + if file.type == FileType.IMAGE: + yield ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=ToolInvokeMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + else: + yield ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=ToolInvokeMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + else: + yield message + else: + yield message + + @classmethod + def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str: + return f"/files/tools/{tool_file_id}{extension or '.bin'}" diff --git a/api/core/datasource/utils/model_invocation_utils.py b/api/core/datasource/utils/model_invocation_utils.py new file mode 100644 index 0000000000..3f59b3f472 --- /dev/null +++ b/api/core/datasource/utils/model_invocation_utils.py @@ -0,0 +1,169 @@ +""" +For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. + +Therefore, a model manager is needed to list/invoke/validate models. 
+""" + +import json +from typing import Optional, cast + +from core.model_manager import ModelManager +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.message_entities import PromptMessage +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder +from extensions.ext_database import db +from models.tools import ToolModelInvoke + + +class InvokeModelError(Exception): + pass + + +class ModelInvocationUtils: + @staticmethod + def get_max_llm_context_tokens( + tenant_id: str, + ) -> int: + """ + get max llm context tokens of the model + """ + model_manager = ModelManager() + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + ) + + if not model_instance: + raise InvokeModelError("Model not found") + + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + + if not schema: + raise InvokeModelError("No model schema found") + + max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) + if max_tokens is None: + return 2048 + + return max_tokens + + @staticmethod + def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + """ + calculate tokens from prompt messages and model parameters + """ + + # get model instance + model_manager = ModelManager() + model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) + + if not model_instance: + raise InvokeModelError("Model not found") + + # get tokens + tokens = model_instance.get_llm_num_tokens(prompt_messages) + + return tokens + + @staticmethod + def invoke( + user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] + ) -> LLMResult: + """ + invoke model with parameters in user's own context + + :param user_id: user id + :param tenant_id: tenant id, the tenant id of the creator of the tool + :param tool_type: tool type + :param tool_name: tool name + :param prompt_messages: prompt messages + :return: AssistantPromptMessage + """ + + # get model manager + model_manager = ModelManager() + # get model instance + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + ) + + # get prompt tokens + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + model_parameters = { + "temperature": 0.8, + "top_p": 0.8, + } + + # create tool model invoke + tool_model_invoke = ToolModelInvoke( + user_id=user_id, + tenant_id=tenant_id, + provider=model_instance.provider, + tool_type=tool_type, + tool_name=tool_name, + model_parameters=json.dumps(model_parameters), + prompt_messages=json.dumps(jsonable_encoder(prompt_messages)), + model_response="", + prompt_tokens=prompt_tokens, + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + provider_response_latency=0, + total_price=0, + currency="USD", + ) + + db.session.add(tool_model_invoke) + db.session.commit() + + try: + response: LLMResult = cast( + LLMResult, + model_instance.invoke_llm( + 
prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=[], + stop=[], + stream=False, + user=user_id, + callbacks=[], + ), + ) + except InvokeRateLimitError as e: + raise InvokeModelError(f"Invoke rate limit error: {e}") + except InvokeBadRequestError as e: + raise InvokeModelError(f"Invoke bad request error: {e}") + except InvokeConnectionError as e: + raise InvokeModelError(f"Invoke connection error: {e}") + except InvokeAuthorizationError as e: + raise InvokeModelError("Invoke authorization error") + except InvokeServerUnavailableError as e: + raise InvokeModelError(f"Invoke server unavailable error: {e}") + except Exception as e: + raise InvokeModelError(f"Invoke error: {e}") + + # update tool model invoke + tool_model_invoke.model_response = response.message.content + if response.usage: + tool_model_invoke.answer_tokens = response.usage.completion_tokens + tool_model_invoke.answer_unit_price = response.usage.completion_unit_price + tool_model_invoke.answer_price_unit = response.usage.completion_price_unit + tool_model_invoke.provider_response_latency = response.usage.latency + tool_model_invoke.total_price = response.usage.total_price + tool_model_invoke.currency = response.usage.currency + + db.session.commit() + + return response diff --git a/api/core/datasource/utils/parser.py b/api/core/datasource/utils/parser.py new file mode 100644 index 0000000000..f72291783a --- /dev/null +++ b/api/core/datasource/utils/parser.py @@ -0,0 +1,389 @@ +import re +import uuid +from json import dumps as json_dumps +from json import loads as json_loads +from json.decoder import JSONDecodeError +from typing import Optional + +from flask import request +from requests import get +from yaml import YAMLError, safe_load # type: ignore + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter +from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError + + +class ApiBasedToolSchemaParser: + @staticmethod + def parse_openapi_to_tool_bundle( + openapi: dict, extra_info: dict | None = None, warning: dict | None = None + ) -> list[ApiToolBundle]: + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + # set description to extra_info + extra_info["description"] = openapi["info"].get("description", "") + + if len(openapi["servers"]) == 0: + raise ToolProviderNotFoundError("No server found in the openapi yaml.") + + server_url = openapi["servers"][0]["url"] + request_env = request.headers.get("X-Request-Env") + if request_env: + matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env] + server_url = matched_servers[0] if matched_servers else server_url + + # list all interfaces + interfaces = [] + for path, path_item in openapi["paths"].items(): + methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"] + for method in methods: + if method in path_item: + interfaces.append( + { + "path": path, + "method": method, + "operation": path_item[method], + } + ) + + # get all parameters + bundles = [] + for interface in interfaces: + # convert parameters + parameters = [] + if "parameters" in interface["operation"]: + for parameter in interface["operation"]["parameters"]: + tool_parameter = ToolParameter( + name=parameter["name"], + label=I18nObject(en_US=parameter["name"], 
zh_Hans=parameter["name"]), + human_description=I18nObject( + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") + ), + type=ToolParameter.ToolParameterType.STRING, + required=parameter.get("required", False), + form=ToolParameter.ToolParameterForm.LLM, + llm_description=parameter.get("description"), + default=parameter["schema"]["default"] + if "schema" in parameter and "default" in parameter["schema"] + else None, + placeholder=I18nObject( + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") + ), + ) + + # check if there is a type + typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter) + if typ: + tool_parameter.type = typ + + parameters.append(tool_parameter) + # create tool bundle + # check if there is a request body + if "requestBody" in interface["operation"]: + request_body = interface["operation"]["requestBody"] + if "content" in request_body: + for content_type, content in request_body["content"].items(): + # if there is a reference, get the reference and overwrite the content + if "schema" not in content: + continue + + if "$ref" in content["schema"]: + # get the reference + root = openapi + reference = content["schema"]["$ref"].split("/")[1:] + for ref in reference: + root = root[ref] + # overwrite the content + interface["operation"]["requestBody"]["content"][content_type]["schema"] = root + + # parse body parameters + if "schema" in interface["operation"]["requestBody"]["content"][content_type]: + body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] + required = body_schema.get("required", []) + properties = body_schema.get("properties", {}) + for name, property in properties.items(): + tool = ToolParameter( + name=name, + label=I18nObject(en_US=name, zh_Hans=name), + human_description=I18nObject( + en_US=property.get("description", ""), zh_Hans=property.get("description", "") + ), + type=ToolParameter.ToolParameterType.STRING, + required=name in required, + form=ToolParameter.ToolParameterForm.LLM, + llm_description=property.get("description", ""), + default=property.get("default", None), + placeholder=I18nObject( + en_US=property.get("description", ""), zh_Hans=property.get("description", "") + ), + ) + + # check if there is a type + typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) + if typ: + tool.type = typ + + parameters.append(tool) + + # check if parameters is duplicated + parameters_count = {} + for parameter in parameters: + if parameter.name not in parameters_count: + parameters_count[parameter.name] = 0 + parameters_count[parameter.name] += 1 + for name, count in parameters_count.items(): + if count > 1: + warning["duplicated_parameter"] = f"Parameter {name} is duplicated." 
+ + # check if there is a operation id, use $path_$method as operation id if not + if "operationId" not in interface["operation"]: + # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ + path = interface["path"] + if interface["path"].startswith("/"): + path = interface["path"][1:] + # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ + path = re.sub(r"[^a-zA-Z0-9_-]", "", path) + if not path: + path = str(uuid.uuid4()) + + interface["operation"]["operationId"] = f"{path}_{interface['method']}" + + bundles.append( + ApiToolBundle( + server_url=server_url + interface["path"], + method=interface["method"], + summary=interface["operation"]["description"] + if "description" in interface["operation"] + else interface["operation"].get("summary", None), + operation_id=interface["operation"]["operationId"], + parameters=parameters, + author="", + icon=None, + openapi=interface["operation"], + ) + ) + + return bundles + + @staticmethod + def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]: + parameter = parameter or {} + typ: Optional[str] = None + if parameter.get("format") == "binary": + return ToolParameter.ToolParameterType.FILE + + if "type" in parameter: + typ = parameter["type"] + elif "schema" in parameter and "type" in parameter["schema"]: + typ = parameter["schema"]["type"] + + if typ in {"integer", "number"}: + return ToolParameter.ToolParameterType.NUMBER + elif typ == "boolean": + return ToolParameter.ToolParameterType.BOOLEAN + elif typ == "string": + return ToolParameter.ToolParameterType.STRING + elif typ == "array": + items = parameter.get("items") or parameter.get("schema", {}).get("items") + return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None + else: + return None + + @staticmethod + def parse_openapi_yaml_to_tool_bundle( + yaml: str, extra_info: dict | None = None, warning: dict | None = None + ) -> list[ApiToolBundle]: + """ + parse openapi yaml to tool bundle + + :param yaml: the yaml string + :param extra_info: the extra info + :param warning: the warning message + :return: the tool bundle + """ + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + openapi: dict = safe_load(yaml) + if openapi is None: + raise ToolApiSchemaError("Invalid openapi yaml.") + return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) + + @staticmethod + def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict: + warning = warning or {} + """ + parse swagger to openapi + + :param swagger: the swagger dict + :return: the openapi dict + """ + # convert swagger to openapi + info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"}) + + servers = swagger.get("servers", []) + + if len(servers) == 0: + raise ToolApiSchemaError("No server found in the swagger yaml.") + + openapi = { + "openapi": "3.0.0", + "info": { + "title": info.get("title", "Swagger"), + "description": info.get("description", "Swagger"), + "version": info.get("version", "1.0.0"), + }, + "servers": swagger["servers"], + "paths": {}, + "components": {"schemas": {}}, + } + + # check paths + if "paths" not in swagger or len(swagger["paths"]) == 0: + raise ToolApiSchemaError("No paths found in the swagger yaml.") + + # convert paths + for path, path_item in 
swagger["paths"].items(): + openapi["paths"][path] = {} + for method, operation in path_item.items(): + if "operationId" not in operation: + raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.") + + if ("summary" not in operation or len(operation["summary"]) == 0) and ( + "description" not in operation or len(operation["description"]) == 0 + ): + if warning is not None: + warning["missing_summary"] = f"No summary or description found in operation {method} {path}." + + openapi["paths"][path][method] = { + "operationId": operation["operationId"], + "summary": operation.get("summary", ""), + "description": operation.get("description", ""), + "parameters": operation.get("parameters", []), + "responses": operation.get("responses", {}), + } + + if "requestBody" in operation: + openapi["paths"][path][method]["requestBody"] = operation["requestBody"] + + # convert definitions + for name, definition in swagger["definitions"].items(): + openapi["components"]["schemas"][name] = definition + + return openapi + + @staticmethod + def parse_openai_plugin_json_to_tool_bundle( + json: str, extra_info: dict | None = None, warning: dict | None = None + ) -> list[ApiToolBundle]: + """ + parse openapi plugin yaml to tool bundle + + :param json: the json string + :param extra_info: the extra info + :param warning: the warning message + :return: the tool bundle + """ + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + try: + openai_plugin = json_loads(json) + api = openai_plugin["api"] + api_url = api["url"] + api_type = api["type"] + except JSONDecodeError: + raise ToolProviderNotFoundError("Invalid openai plugin json.") + + if api_type != "openapi": + raise ToolNotSupportedError("Only openapi is supported now.") + + # get openapi yaml + response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5) + + if response.status_code != 200: + raise ToolProviderNotFoundError("cannot get openapi yaml from url.") + + return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( + response.text, extra_info=extra_info, warning=warning + ) + + @staticmethod + def auto_parse_to_tool_bundle( + content: str, extra_info: dict | None = None, warning: dict | None = None + ) -> tuple[list[ApiToolBundle], str]: + """ + auto parse to tool bundle + + :param content: the content + :param extra_info: the extra info + :param warning: the warning message + :return: tools bundle, schema_type + """ + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + content = content.strip() + loaded_content = None + json_error = None + yaml_error = None + + try: + loaded_content = json_loads(content) + except JSONDecodeError as e: + json_error = e + + if loaded_content is None: + try: + loaded_content = safe_load(content) + except YAMLError as e: + yaml_error = e + if loaded_content is None: + raise ToolApiSchemaError( + f"Invalid api schema, schema is neither json nor yaml. 
json error: {str(json_error)}," + f" yaml error: {str(yaml_error)}" + ) + + swagger_error = None + openapi_error = None + openapi_plugin_error = None + schema_type = None + + try: + openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + loaded_content, extra_info=extra_info, warning=warning + ) + schema_type = ApiProviderSchemaType.OPENAPI.value + return openapi, schema_type + except ToolApiSchemaError as e: + openapi_error = e + + # openai parse error, fallback to swagger + try: + converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( + loaded_content, extra_info=extra_info, warning=warning + ) + schema_type = ApiProviderSchemaType.SWAGGER.value + return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + converted_swagger, extra_info=extra_info, warning=warning + ), schema_type + except ToolApiSchemaError as e: + swagger_error = e + + # swagger parse error, fallback to openai plugin + try: + openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + json_dumps(loaded_content), extra_info=extra_info, warning=warning + ) + return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value + except ToolNotSupportedError as e: + # maybe it's not plugin at all + openapi_plugin_error = e + + raise ToolApiSchemaError( + f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}," + f" openapi plugin error: {str(openapi_plugin_error)}" + ) diff --git a/api/core/datasource/utils/rag_web_reader.py b/api/core/datasource/utils/rag_web_reader.py new file mode 100644 index 0000000000..22c47fa814 --- /dev/null +++ b/api/core/datasource/utils/rag_web_reader.py @@ -0,0 +1,17 @@ +import re + + +def get_image_upload_file_ids(content): + pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" + matches = re.findall(pattern, content) + image_upload_file_ids = [] + for match in matches: + if match[1] == "file-preview": + content_pattern = r"files/([^/]+)/file-preview" + else: + content_pattern = r"files/([^/]+)/image-preview" + content_match = re.search(content_pattern, match[0]) + if content_match: + image_upload_file_id = content_match.group(1) + image_upload_file_ids.append(image_upload_file_id) + return image_upload_file_ids diff --git a/api/core/datasource/utils/text_processing_utils.py b/api/core/datasource/utils/text_processing_utils.py new file mode 100644 index 0000000000..105823f896 --- /dev/null +++ b/api/core/datasource/utils/text_processing_utils.py @@ -0,0 +1,17 @@ +import re + + +def remove_leading_symbols(text: str) -> str: + """ + Remove leading punctuation or symbols from the given text. + + Args: + text (str): The input text to process. + + Returns: + str: The text with leading punctuation or symbols removed. 
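+ Example: remove_leading_symbols(".,;Hello world") returns "Hello world".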
+ """ + # Match Unicode ranges for punctuation and symbols + # FIXME this pattern is confused quick fix for #11868 maybe refactor it later + pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+" + return re.sub(pattern, "", text) diff --git a/api/core/datasource/utils/uuid_utils.py b/api/core/datasource/utils/uuid_utils.py new file mode 100644 index 0000000000..3046c08c89 --- /dev/null +++ b/api/core/datasource/utils/uuid_utils.py @@ -0,0 +1,9 @@ +import uuid + + +def is_valid_uuid(uuid_str: str) -> bool: + try: + uuid.UUID(uuid_str) + return True + except Exception: + return False diff --git a/api/core/datasource/utils/web_reader_tool.py b/api/core/datasource/utils/web_reader_tool.py new file mode 100644 index 0000000000..d42fd99fce --- /dev/null +++ b/api/core/datasource/utils/web_reader_tool.py @@ -0,0 +1,375 @@ +import hashlib +import json +import mimetypes +import os +import re +import site +import subprocess +import tempfile +import unicodedata +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Literal, Optional, cast +from urllib.parse import unquote + +import chardet +import cloudscraper # type: ignore +from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore +from regex import regex # type: ignore + +from core.helper import ssrf_proxy +from core.rag.extractor import extract_processor +from core.rag.extractor.extract_processor import ExtractProcessor + +FULL_TEMPLATE = """ +TITLE: {title} +AUTHORS: {authors} +PUBLISH DATE: {publish_date} +TOP_IMAGE_URL: {top_image} +TEXT: + +{text} +""" + + +def page_result(text: str, cursor: int, max_length: int) -> str: + """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" + return text[cursor : cursor + max_length] + + +def get_url(url: str, user_agent: Optional[str] = None) -> str: + """Fetch URL and return the contents as a string.""" + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" + " Chrome/91.0.4472.124 Safari/537.36" + } + if user_agent: + headers["User-Agent"] = user_agent + + main_content_type = None + supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] + response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10)) + + if response.status_code == 200: + # check content-type + content_type = response.headers.get("Content-Type") + if content_type: + main_content_type = response.headers.get("Content-Type").split(";")[0].strip() + else: + content_disposition = response.headers.get("Content-Disposition", "") + filename_match = re.search(r'filename="([^"]+)"', content_disposition) + if filename_match: + filename = unquote(filename_match.group(1)) + extension = re.search(r"\.(\w+)$", filename) + if extension: + main_content_type = mimetypes.guess_type(filename)[0] + + if main_content_type not in supported_content_types: + return "Unsupported content-type [{}] of URL.".format(main_content_type) + + if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: + return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) + + response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) + elif response.status_code == 403: + scraper = cloudscraper.create_scraper() + scraper.perform_request = ssrf_proxy.make_request + response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) + + if response.status_code != 200: + 
return "URL returned status code {}.".format(response.status_code) + + # Detect encoding using chardet + detected_encoding = chardet.detect(response.content) + encoding = detected_encoding["encoding"] + if encoding: + try: + content = response.content.decode(encoding) + except (UnicodeDecodeError, TypeError): + content = response.text + else: + content = response.text + + a = extract_using_readabilipy(content) + + if not a["plain_text"] or not a["plain_text"].strip(): + return "" + + res = FULL_TEMPLATE.format( + title=a["title"], + authors=a["byline"], + publish_date=a["date"], + top_image="", + text=a["plain_text"] or "", + ) + + return res + + +def extract_using_readabilipy(html): + with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html: + f_html.write(html) + f_html.close() + html_path = f_html.name + + # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file + article_json_path = html_path + ".json" + jsdir = os.path.join(find_module_path("readabilipy"), "javascript") + with chdir(jsdir): + subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) + + # Read output of call to Readability.parse() from JSON file and return as Python dictionary + input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8")) + + # Deleting files after processing + os.unlink(article_json_path) + os.unlink(html_path) + + article_json: dict[str, Any] = { + "title": None, + "byline": None, + "date": None, + "content": None, + "plain_content": None, + "plain_text": None, + } + # Populate article fields from readability fields where present + if input_json: + if input_json.get("title"): + article_json["title"] = input_json["title"] + if input_json.get("byline"): + article_json["byline"] = input_json["byline"] + if input_json.get("date"): + article_json["date"] = input_json["date"] + if input_json.get("content"): + article_json["content"] = input_json["content"] + article_json["plain_content"] = plain_content(article_json["content"], False, False) + article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) + if input_json.get("textContent"): + article_json["plain_text"] = input_json["textContent"] + article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"]) + + return article_json + + +def find_module_path(module_name): + for package_path in site.getsitepackages(): + potential_path = os.path.join(package_path, module_name) + if os.path.exists(potential_path): + return potential_path + + return None + + +@contextmanager +def chdir(path): + """Change directory in context and return to original on exit""" + # From https://stackoverflow.com/a/37996581, couldn't find a built-in + original_path = os.getcwd() + os.chdir(path) + try: + yield + finally: + os.chdir(original_path) + + +def extract_text_blocks_as_plain_text(paragraph_html): + # Load article as DOM + soup = BeautifulSoup(paragraph_html, "html.parser") + # Select all lists + list_elements = soup.find_all(["ul", "ol"]) + # Prefix text in all list items with "* " and make lists paragraphs + for list_element in list_elements: + plain_items = "".join( + list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")])) + ) + list_element.string = plain_items + list_element.name = "p" + # Select all text blocks + text_blocks = [s.parent for s in soup.find_all(string=True)] + text_blocks = [plain_text_leaf_node(block) for block in text_blocks] + # Drop empty paragraphs 
+ text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks)) + return text_blocks + + +def plain_text_leaf_node(element): + # Extract all text, stripped of any child HTML elements and normalize it + plain_text = normalize_text(element.get_text()) + if plain_text != "" and element.name == "li": + plain_text = "* {}, ".format(plain_text) + if plain_text == "": + plain_text = None + if "data-node-index" in element.attrs: + plain = {"node_index": element["data-node-index"], "text": plain_text} + else: + plain = {"text": plain_text} + return plain + + +def plain_content(readability_content, content_digests, node_indexes): + # Load article as DOM + soup = BeautifulSoup(readability_content, "html.parser") + # Make all elements plain + elements = plain_elements(soup.contents, content_digests, node_indexes) + if node_indexes: + # Add node index attributes to nodes + elements = [add_node_indexes(element) for element in elements] + # Replace article contents with plain elements + soup.contents = elements + return str(soup) + + +def plain_elements(elements, content_digests, node_indexes): + # Get plain content versions of all elements + elements = [plain_element(element, content_digests, node_indexes) for element in elements] + if content_digests: + # Add content digest attribute to nodes + elements = [add_content_digest(element) for element in elements] + return elements + + +def plain_element(element, content_digests, node_indexes): + # For lists, we make each item plain text + if is_leaf(element): + # For leaf node elements, extract the text content, discarding any HTML tags + # 1. Get element contents as text + plain_text = element.get_text() + # 2. Normalize the extracted text string to a canonical representation + plain_text = normalize_text(plain_text) + # 3. Update element content to be plain text + element.string = plain_text + elif is_text(element): + if is_non_printing(element): + # The simplified HTML may have come from Readability.js so might + # have non-printing text (e.g. Comment or CData). In this case, we + # keep the structure, but ensure that the string is empty. 
+ element = type(element)("") + else: + plain_text = element.string + plain_text = normalize_text(plain_text) + element = type(element)(plain_text) + else: + # If not a leaf node or leaf type call recursively on child nodes, replacing + element.contents = plain_elements(element.contents, content_digests, node_indexes) + return element + + +def add_node_indexes(element, node_index="0"): + # Can't add attributes to string types + if is_text(element): + return element + # Add index to current element + element["data-node-index"] = node_index + # Add index to child elements + for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1): + # Can't add attributes to leaf string types + child_index = "{stem}.{local}".format(stem=node_index, local=local_idx) + add_node_indexes(child, node_index=child_index) + return element + + +def normalize_text(text): + """Normalize unicode and whitespace.""" + # Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them + text = strip_control_characters(text) + text = normalize_unicode(text) + text = normalize_whitespace(text) + return text + + +def strip_control_characters(text): + """Strip out unicode control characters which might break the parsing.""" + # Unicode control characters + # [Cc]: Other, Control [includes new lines] + # [Cf]: Other, Format + # [Cn]: Other, Not Assigned + # [Co]: Other, Private Use + # [Cs]: Other, Surrogate + control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"} + retained_chars = ["\t", "\n", "\r", "\f"] + + # Remove non-printing control characters + return "".join( + [ + "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char + for char in text + ] + ) + + +def normalize_unicode(text): + """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible.""" + normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC" + text = unicodedata.normalize(normal_form, text) + return text + + +def normalize_whitespace(text): + """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed.""" + text = regex.sub(r"\s+", " ", text) + # Remove leading and trailing whitespace + text = text.strip() + return text + + +def is_leaf(element): + return element.name in {"p", "li"} + + +def is_text(element): + return isinstance(element, NavigableString) + + +def is_non_printing(element): + return any(isinstance(element, _e) for _e in [Comment, CData]) + + +def add_content_digest(element): + if not is_text(element): + element["data-content-digest"] = content_digest(element) + return element + + +def content_digest(element): + digest: Any + if is_text(element): + # Hash + trimmed_string = element.string.strip() + if trimmed_string == "": + digest = "" + else: + digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest() + else: + contents = element.contents + num_contents = len(contents) + if num_contents == 0: + # No hash when no child elements exist + digest = "" + elif num_contents == 1: + # If single child, use digest of child + digest = content_digest(contents[0]) + else: + # Build content digest from the "non-empty" digests of child nodes + digest = hashlib.sha256() + child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents])) + for child in child_digests: + digest.update(child.encode("utf-8")) + digest = digest.hexdigest() + return digest + + +def get_image_upload_file_ids(content): + 
pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" + matches = re.findall(pattern, content) + image_upload_file_ids = [] + for match in matches: + if match[1] == "file-preview": + content_pattern = r"files/([^/]+)/file-preview" + else: + content_pattern = r"files/([^/]+)/image-preview" + content_match = re.search(content_pattern, match[0]) + if content_match: + image_upload_file_id = content_match.group(1) + image_upload_file_ids.append(image_upload_file_id) + return image_upload_file_ids diff --git a/api/core/datasource/utils/workflow_configuration_sync.py b/api/core/datasource/utils/workflow_configuration_sync.py new file mode 100644 index 0000000000..d16d6fc576 --- /dev/null +++ b/api/core/datasource/utils/workflow_configuration_sync.py @@ -0,0 +1,43 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from core.app.app_config.entities import VariableEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration + + +class WorkflowToolConfigurationUtils: + @classmethod + def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]): + for configuration in configurations: + WorkflowToolParameterConfiguration.model_validate(configuration) + + @classmethod + def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: + """ + get workflow graph variables + """ + nodes = graph.get("nodes", []) + start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None) + + if not start_node: + return [] + + return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])] + + @classmethod + def check_is_synced( + cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] + ): + """ + check is synced + + raise ValueError if not synced + """ + variable_names = [variable.variable for variable in variables] + + if len(tool_configurations) != len(variables): + raise ValueError("parameter configuration mismatch, please republish the tool to update") + + for parameter in tool_configurations: + if parameter.name not in variable_names: + raise ValueError("parameter configuration mismatch, please republish the tool to update") diff --git a/api/core/datasource/utils/yaml_utils.py b/api/core/datasource/utils/yaml_utils.py new file mode 100644 index 0000000000..ee7ca11e05 --- /dev/null +++ b/api/core/datasource/utils/yaml_utils.py @@ -0,0 +1,35 @@ +import logging +from pathlib import Path +from typing import Any + +import yaml # type: ignore +from yaml import YAMLError + +logger = logging.getLogger(__name__) + + +def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any: + """ + Safe loading a YAML file + :param file_path: the path of the YAML file + :param ignore_error: + if True, return default_value if error occurs and the error will be logged in debug level + if False, raise error if error occurs + :param default_value: the value returned when errors ignored + :return: an object of the YAML content + """ + if not file_path or not Path(file_path).exists(): + if ignore_error: + return default_value + else: + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, encoding="utf-8") as yaml_file: + try: + yaml_content = yaml.safe_load(yaml_file) + return yaml_content or default_value + except Exception as e: + if ignore_error: + return default_value + else: + raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from 
e diff --git a/api/core/workflow/constants.py b/api/core/workflow/constants.py index e3fe17c284..e5deafc32f 100644 --- a/api/core/workflow/constants.py +++ b/api/core/workflow/constants.py @@ -1,3 +1,4 @@ SYSTEM_VARIABLE_NODE_ID = "sys" ENVIRONMENT_VARIABLE_NODE_ID = "env" CONVERSATION_VARIABLE_NODE_ID = "conversation" +PIPELINE_VARIABLE_NODE_ID = "pipeline" \ No newline at end of file diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 82fd6cdc30..ecd2cfeabc 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -17,6 +17,7 @@ class NodeRunMetadataKey(StrEnum): TOTAL_PRICE = "total_price" CURRENCY = "currency" TOOL_INFO = "tool_info" + DATASOURCE_INFO = "datasource_info" AGENT_LOG = "agent_log" ITERATION_ID = "iteration_id" ITERATION_INDEX = "iteration_index" diff --git a/api/core/workflow/nodes/datasource/__init__.py b/api/core/workflow/nodes/datasource/__init__.py new file mode 100644 index 0000000000..cee9e5a895 --- /dev/null +++ b/api/core/workflow/nodes/datasource/__init__.py @@ -0,0 +1,3 @@ +from .tool_node import ToolNode + +__all__ = ["DatasourceNode"] diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py new file mode 100644 index 0000000000..1752ba36fa --- /dev/null +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -0,0 +1,406 @@ +from collections.abc import Generator, Mapping, Sequence +from typing import Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.file import File, FileTransferMethod +from core.plugin.manager.exc import PluginDaemonClientSideError +from core.plugin.manager.plugin import PluginInstallationManager +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.variables.segments import ArrayAnySegment +from core.variables.variables import ArrayAnyVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import AgentLogEvent +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.utils.variable_template_parser import VariableTemplateParser +from extensions.ext_database import db +from factories import file_factory +from models import ToolFile +from models.workflow import WorkflowNodeExecutionStatus +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +from .entities import DatasourceNodeData +from .exc import ( + ToolFileError, + ToolNodeError, + ToolParameterError, +) + + +class DatasourceNode(BaseNode[DatasourceNodeData]): + """ + Datasource Node + """ + + _node_data_cls = DatasourceNodeData + _node_type = NodeType.DATASOURCE + + def _run(self) -> Generator: + """ + Run the datasource node + """ + + node_data = cast(DatasourceNodeData, self.node_data) + + # fetch datasource icon + datasource_info = { + "provider_type": node_data.provider_type.value, + "provider_id": 
node_data.provider_id, + "plugin_unique_identifier": node_data.plugin_unique_identifier, + } + + # get datasource runtime + try: + from core.tools.tool_manager import ToolManager + + tool_runtime = ToolManager.get_workflow_tool_runtime( + self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from + ) + except ToolNodeError as e: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to get datasource runtime: {str(e)}", + error_type=type(e).__name__, + ) + ) + return + + # get parameters + tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] + parameters = self._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=self.node_data, + ) + parameters_for_log = self._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=self.node_data, + for_log=True, + ) + + # get conversation id + conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + + try: + message_stream = ToolEngine.generic_invoke( + tool=tool_runtime, + tool_parameters=parameters, + user_id=self.user_id, + workflow_tool_callback=DifyWorkflowCallbackHandler(), + workflow_call_depth=self.workflow_call_depth, + thread_pool_id=self.thread_pool_id, + app_id=self.app_id, + conversation_id=conversation_id.text if conversation_id else None, + ) + except ToolNodeError as e: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to invoke tool: {str(e)}", + error_type=type(e).__name__, + ) + ) + return + + try: + # convert tool messages + yield from self._transform_message(message_stream, tool_info, parameters_for_log) + except (PluginDaemonClientSideError, ToolInvokeError) as e: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to transform tool message: {str(e)}", + error_type=type(e).__name__, + ) + ) + + def _generate_parameters( + self, + *, + tool_parameters: Sequence[ToolParameter], + variable_pool: VariablePool, + node_data: ToolNodeData, + for_log: bool = False, + ) -> dict[str, Any]: + """ + Generate parameters based on the given tool parameters, variable pool, and node data. + + Args: + tool_parameters (Sequence[ToolParameter]): The list of tool parameters. + variable_pool (VariablePool): The variable pool containing the variables. + node_data (ToolNodeData): The data associated with the tool node. + + Returns: + Mapping[str, Any]: A dictionary containing the generated parameters. 
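+ Parameters missing from the runtime schema are set to None; an unknown input type raises ToolParameterError.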
+ + """ + tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} + + result: dict[str, Any] = {} + for parameter_name in node_data.tool_parameters: + parameter = tool_parameters_dictionary.get(parameter_name) + if not parameter: + result[parameter_name] = None + continue + tool_input = node_data.tool_parameters[parameter_name] + if tool_input.type == "variable": + variable = variable_pool.get(tool_input.value) + if variable is None: + raise ToolParameterError(f"Variable {tool_input.value} does not exist") + parameter_value = variable.value + elif tool_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(tool_input.value)) + parameter_value = segment_group.log if for_log else segment_group.text + else: + raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") + result[parameter_name] = parameter_value + + return result + + def _fetch_files(self, variable_pool: VariablePool) -> list[File]: + variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) + assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) + return list(variable.value) if variable else [] + + def _transform_message( + self, + messages: Generator[ToolInvokeMessage, None, None], + tool_info: Mapping[str, Any], + parameters_for_log: dict[str, Any], + ) -> Generator: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json: list[dict] = [] + + agent_logs: list[AgentLogEvent] = [] + agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {} + + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + ToolInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + else: + transfer_method = FileTransferMethod.TOOL_FILE + + tool_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileError(f"Tool file {tool_file_id} does not exist") + + mapping = { + "tool_file_id": tool_file_id, + "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + files.append(file) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert message.meta + + tool_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileError(f"tool file {tool_file_id} not exists") + + mapping = { + "tool_file_id": tool_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + ) + 
elif message.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + text += message.message.text + yield RunStreamChunkEvent( + chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] + ) + elif message.type == ToolInvokeMessage.MessageType.JSON: + assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + if self.node_type == NodeType.AGENT: + msg_metadata = message.message.json_object.pop("execution_metadata", {}) + agent_execution_metadata = { + key: value + for key, value in msg_metadata.items() + if key in NodeRunMetadataKey.__members__.values() + } + json.append(message.message.json_object) + elif message.type == ToolInvokeMessage.MessageType.LINK: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) + elif message.type == ToolInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield RunStreamChunkEvent( + chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] + ) + else: + variables[variable_name] = variable_value + elif message.type == ToolInvokeMessage.MessageType.FILE: + assert message.meta is not None + files.append(message.meta["file"]) + elif message.type == ToolInvokeMessage.MessageType.LOG: + assert isinstance(message.message, ToolInvokeMessage.LogMessage) + if message.message.metadata: + icon = tool_info.get("icon", "") + dict_metadata = dict(message.message.metadata) + if dict_metadata.get("provider"): + manager = PluginInstallationManager() + plugins = manager.list_plugins(self.tenant_id) + try: + current_plugin = next( + plugin + for plugin in plugins + if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] + ) + icon = current_plugin.declaration.icon + except StopIteration: + pass + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + self.user_id, + self.tenant_id, + ) + if provider.name == dict_metadata["provider"] + ) + icon = builtin_tool.icon + except StopIteration: + pass + + dict_metadata["icon"] = icon + message.message.metadata = dict_metadata + agent_log = AgentLogEvent( + id=message.message.id, + node_execution_id=self.id, + parent_id=message.message.parent_id, + error=message.message.error, + status=message.message.status.value, + data=message.message.data, + label=message.message.label, + metadata=message.message.metadata, + node_id=self.node_id, + ) + + # check if the agent log is already in the list + for log in agent_logs: + if log.id == agent_log.id: + # update the log + log.data = agent_log.data + log.status = agent_log.status + log.error = agent_log.error + log.label = agent_log.label + log.metadata = agent_log.metadata + break + else: + agent_logs.append(agent_log) + + yield agent_log + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"text": text, "files": files, "json": json, **variables}, + metadata={ 
+ **agent_execution_metadata, + NodeRunMetadataKey.TOOL_INFO: tool_info, + NodeRunMetadataKey.AGENT_LOG: agent_logs, + }, + inputs=parameters_for_log, + ) + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ToolNodeData, + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + result = {} + for parameter_name in node_data.tool_parameters: + input = node_data.tool_parameters[parameter_name] + if input.type == "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + elif input.type == "variable": + result[parameter_name] = input.value + elif input.type == "constant": + pass + + result = {node_id + "." + key: value for key, value in result.items()} + + return result diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py new file mode 100644 index 0000000000..66e8adc431 --- /dev/null +++ b/api/core/workflow/nodes/datasource/entities.py @@ -0,0 +1,56 @@ +from typing import Any, Literal, Union + +from pydantic import BaseModel, field_validator +from pydantic_core.core_schema import ValidationInfo + +from core.tools.entities.tool_entities import ToolProviderType +from core.workflow.nodes.base.entities import BaseNodeData + + +class DatasourceEntity(BaseModel): + provider_id: str + provider_type: ToolProviderType + provider_name: str # redundancy + tool_name: str + tool_label: str # redundancy + tool_configurations: dict[str, Any] + plugin_unique_identifier: str | None = None # redundancy + + @field_validator("tool_configurations", mode="before") + @classmethod + def validate_tool_configurations(cls, value, values: ValidationInfo): + if not isinstance(value, dict): + raise ValueError("tool_configurations must be a dictionary") + + for key in values.data.get("tool_configurations", {}): + value = values.data.get("tool_configurations", {}).get(key) + if not isinstance(value, str | int | float | bool): + raise ValueError(f"{key} must be a string") + + return value + + +class DatasourceNodeData(BaseNodeData, DatasourceEntity): + class DatasourceInput(BaseModel): + # TODO: check this type + value: Union[Any, list[str]] + type: Literal["mixed", "variable", "constant"] + + @field_validator("type", mode="before") + @classmethod + def check_type(cls, value, validation_info: ValidationInfo): + typ = value + value = validation_info.data.get("value") + if typ == "mixed" and not isinstance(value, str): + raise ValueError("value must be a string") + elif typ == "variable": + if not isinstance(value, list): + raise ValueError("value must be a list") + for val in value: + if not isinstance(val, str): + raise ValueError("value must be a list of strings") + elif typ == "constant" and not isinstance(value, str | int | float | bool): + raise ValueError("value must be a string, int, float, or bool") + return typ + + datasource_parameters: dict[str, DatasourceInput] diff --git a/api/core/workflow/nodes/datasource/exc.py b/api/core/workflow/nodes/datasource/exc.py new file mode 100644 index 0000000000..7212e8bfc0 --- /dev/null +++ b/api/core/workflow/nodes/datasource/exc.py @@ -0,0 +1,16 @@ +class ToolNodeError(ValueError): + """Base exception for tool node errors.""" + + pass + 
+ +class ToolParameterError(ToolNodeError): + """Exception raised for errors in tool parameters.""" + + pass + + +class ToolFileError(ToolNodeError): + """Exception raised for errors related to tool files.""" + + pass diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 73b43eeaf7..673d0ba049 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -13,6 +13,7 @@ class NodeType(StrEnum): QUESTION_CLASSIFIER = "question-classifier" HTTP_REQUEST = "http-request" TOOL = "tool" + DATASOURCE = "datasource" VARIABLE_AGGREGATOR = "variable-aggregator" LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. LOOP = "loop" diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 6f0cc3f6d2..08af6b0014 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -73,7 +73,7 @@ class ToolNode(BaseNode[ToolNodeData]): metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, error=f"Failed to get tool runtime: {str(e)}", error_type=type(e).__name__, - ) + ) ) return diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index bbca8448ec..bb6b366a81 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -36,7 +36,11 @@ from core.variables.variables import ( StringVariable, Variable, ) -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from core.workflow.constants import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + PIPELINE_VARIABLE_NODE_ID, +) class InvalidSelectorError(ValueError): @@ -74,6 +78,10 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va raise VariableError("missing name") return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]]) +def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: + if not mapping.get("name"): + raise VariableError("missing name") + return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["name"]]) def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: """ diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 971e99c259..1bf70da9d9 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -40,6 +40,13 @@ conversation_variable_fields = { "description": fields.String, } +pipeline_variable_fields = { + "id": fields.String, + "name": fields.String, + "value_type": fields.String(attribute="value_type.value"), + "value": fields.Raw, +} + workflow_fields = { "id": fields.String, "graph": fields.Raw(attribute="graph_dict"), @@ -55,6 +62,10 @@ workflow_fields = { "tool_published": fields.Boolean, "environment_variables": fields.List(EnvironmentVariableField()), "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)), + "pipeline_variables": fields.Dict( + keys=fields.String, + values=fields.List(fields.Nested(pipeline_variable_fields)), + ), } workflow_partial_fields = { diff --git a/api/models/workflow.py b/api/models/workflow.py index c85f335f37..2e9f6f0315 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -130,6 +130,9 @@ class Workflow(Base): _conversation_variables: Mapped[str] = mapped_column( "conversation_variables", db.Text, nullable=False, 
server_default="{}" ) + _pipeline_variables: Mapped[str] = mapped_column( + "conversation_variables", db.Text, nullable=False, server_default="{}" + ) @classmethod def new( @@ -343,6 +346,24 @@ class Workflow(Base): ensure_ascii=False, ) + @property + def pipeline_variables(self) -> dict[str, Sequence[Variable]]: + # TODO: find some way to init `self._conversation_variables` when instance created. + if self._pipeline_variables is None: + self._pipeline_variables = "{}" + + variables_dict: dict[str, Any] = json.loads(self._pipeline_variables) + results = {} + for k, v in variables_dict.items(): + results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()] + return results + + @pipeline_variables.setter + def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None: + self._pipeline_variables = json.dumps( + {k: {item.name: item.model_dump() for item in v} for k, v in values.items()}, + ensure_ascii=False, + ) class WorkflowRunStatus(StrEnum): """ diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index c2c9c56e9d..422f24d521 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -16,12 +16,13 @@ from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.workflow_entry import WorkflowEntry -from extensions.db import db +from extensions.ext_database import db from models.account import Account from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore from models.workflow import Workflow, WorkflowNodeExecution, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.errors.app import WorkflowHashNotEqualError +from services.errors.workflow_service import DraftWorkflowDeletionError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory @@ -186,6 +187,7 @@ class RagPipelineService: account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], + pipeline_variables: dict[str, Sequence[Variable]], ) -> Workflow: """ Sync draft workflow @@ -212,6 +214,7 @@ class RagPipelineService: created_by=account.id, environment_variables=environment_variables, conversation_variables=conversation_variables, + pipeline_variables=pipeline_variables, ) db.session.add(workflow) # update draft workflow if found @@ -222,7 +225,7 @@ class RagPipelineService: workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables - + workflow.pipeline_variables = pipeline_variables # commit db session changes db.session.commit() @@ -342,6 +345,41 @@ class RagPipelineService: db.session.commit() return workflow_node_execution + + def run_datasource_workflow_node( + self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account + ) -> WorkflowNodeExecution: + """ + Run published workflow datasource + """ + # fetch published workflow by app_model + published_workflow = self.get_published_workflow(pipeline=pipeline) + if not published_workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + start_at = time.perf_counter() + + workflow_node_execution = 
+            getter=lambda: WorkflowEntry.single_step_run(
+                workflow=published_workflow,
+                node_id=node_id,
+                user_inputs=user_inputs,
+                user_id=account.id,
+            ),
+            start_at=start_at,
+            tenant_id=pipeline.tenant_id,
+            node_id=node_id,
+        )
+
+        workflow_node_execution.app_id = pipeline.id
+        workflow_node_execution.created_by = account.id
+        workflow_node_execution.workflow_id = published_workflow.id
+
+        db.session.add(workflow_node_execution)
+        db.session.commit()
+
+        return workflow_node_execution
 
     def run_free_workflow_node(
         self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
@@ -573,3 +611,21 @@ class RagPipelineService:
             session.delete(workflow)
 
         return True
+
+    def get_second_step_parameters(self, pipeline: Pipeline, datasource_provider: str) -> list:
+        """
+        Get second step parameters of rag pipeline
+        """
+
+        workflow = self.get_published_workflow(pipeline=pipeline)
+        if not workflow:
+            raise ValueError("Workflow not initialized")
+
+        # get second step node
+        pipeline_variables = workflow.pipeline_variables
+        if not pipeline_variables:
+            return []
+        # get datasource provider
+        datasource_provider_variables = pipeline_variables.get(datasource_provider, [])
+        shared_variables = pipeline_variables.get("shared", [])
+        return datasource_provider_variables + shared_variables
\ No newline at end of file

From c7f4b4192027b74e5673dbdaf3e24deda6529d42 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Tue, 22 Apr 2025 16:08:58 +0800
Subject: [PATCH 005/155] r2

---
 .../provider.py                               |  39 ++--
 .../{plugin_tool => datasource_tool}/tool.py  |  50 +++-
 .../datasource/entities/agent_entities.py     |   0
 api/core/datasource/entities/api_entities.py  |  72 ------
 ...ool_entities.py => datasource_entities.py} |  85 ++++---
 api/core/plugin/manager/datasource.py         | 217 ++++++++++++++++++
 api/core/plugin/utils/converter.py            |   7 +
 7 files changed, 321 insertions(+), 149 deletions(-)
 rename api/core/datasource/{plugin_tool => datasource_tool}/provider.py (59%)
 rename api/core/datasource/{plugin_tool => datasource_tool}/tool.py (59%)
 delete mode 100644 api/core/datasource/entities/agent_entities.py
 delete mode 100644 api/core/datasource/entities/api_entities.py
 rename api/core/datasource/entities/{tool_entities.py => datasource_entities.py} (84%)
 create mode 100644 api/core/plugin/manager/datasource.py

diff --git a/api/core/datasource/plugin_tool/provider.py b/api/core/datasource/datasource_tool/provider.py
similarity index 59%
rename from api/core/datasource/plugin_tool/provider.py
rename to api/core/datasource/datasource_tool/provider.py
index 3616e426b9..3104728947 100644
--- a/api/core/datasource/plugin_tool/provider.py
+++ b/api/core/datasource/datasource_tool/provider.py
@@ -1,21 +1,22 @@
 from typing import Any
 
+from core.datasource.datasource_tool.tool import DatasourceTool
+from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
 from core.plugin.manager.tool import PluginToolManager
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.builtin_tool.provider import BuiltinToolProviderController
 from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
 from core.tools.errors import ToolProviderCredentialValidationError
-from core.tools.plugin_tool.tool import PluginTool
 
 
-class PluginToolProviderController(BuiltinToolProviderController):
-    entity: ToolProviderEntityWithPlugin
+class 
DatasourceToolProviderController(BuiltinToolProviderController): + entity: DatasourceProviderEntityWithPlugin tenant_id: str plugin_id: str plugin_unique_identifier: str def __init__( - self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str ) -> None: self.entity = entity self.tenant_id = tenant_id @@ -23,13 +24,13 @@ class PluginToolProviderController(BuiltinToolProviderController): self.plugin_unique_identifier = plugin_unique_identifier @property - def provider_type(self) -> ToolProviderType: + def provider_type(self) -> DatasourceProviderType: """ returns the type of the provider :return: type of the provider """ - return ToolProviderType.PLUGIN + return DatasourceProviderType.RAG_PIPELINE def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: """ @@ -44,36 +45,36 @@ class PluginToolProviderController(BuiltinToolProviderController): ): raise ToolProviderCredentialValidationError("Invalid credentials") - def get_tool(self, tool_name: str) -> PluginTool: # type: ignore + def get_datasource(self, datasource_name: str) -> DatasourceTool: # type: ignore """ - return tool with given name + return datasource with given name """ - tool_entity = next( - (tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None + datasource_entity = next( + (datasource_entity for datasource_entity in self.entity.datasources if datasource_entity.identity.name == datasource_name), None ) - if not tool_entity: - raise ValueError(f"Tool with name {tool_name} not found") + if not datasource_entity: + raise ValueError(f"Datasource with name {datasource_name} not found") - return PluginTool( - entity=tool_entity, + return DatasourceTool( + entity=datasource_entity, runtime=ToolRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, icon=self.entity.identity.icon, plugin_unique_identifier=self.plugin_unique_identifier, ) - def get_tools(self) -> list[PluginTool]: # type: ignore + def get_datasources(self) -> list[DatasourceTool]: # type: ignore """ - get all tools + get all datasources """ return [ - PluginTool( - entity=tool_entity, + DatasourceTool( + entity=datasource_entity, runtime=ToolRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, icon=self.entity.identity.icon, plugin_unique_identifier=self.plugin_unique_identifier, ) - for tool_entity in self.entity.tools + for datasource_entity in self.entity.datasources ] diff --git a/api/core/datasource/plugin_tool/tool.py b/api/core/datasource/datasource_tool/tool.py similarity index 59% rename from api/core/datasource/plugin_tool/tool.py rename to api/core/datasource/datasource_tool/tool.py index f31a9a0d3e..b69b2368a4 100644 --- a/api/core/datasource/plugin_tool/tool.py +++ b/api/core/datasource/datasource_tool/tool.py @@ -1,6 +1,7 @@ from collections.abc import Generator from typing import Any, Optional +from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceParameter, DatasourceProviderType from core.plugin.manager.tool import PluginToolManager from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.__base.tool import Tool @@ -8,14 +9,14 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType -class PluginTool(Tool): +class DatasourceTool(Tool): 
tenant_id: str icon: str plugin_unique_identifier: str - runtime_parameters: Optional[list[ToolParameter]] + runtime_parameters: Optional[list[DatasourceParameter]] def __init__( - self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str + self, entity: DatasourceEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str ) -> None: super().__init__(entity, runtime) self.tenant_id = tenant_id @@ -23,20 +24,44 @@ class PluginTool(Tool): self.plugin_unique_identifier = plugin_unique_identifier self.runtime_parameters = None - def tool_provider_type(self) -> ToolProviderType: - return ToolProviderType.PLUGIN + def datasource_provider_type(self) -> DatasourceProviderType: + return DatasourceProviderType.RAG_PIPELINE - def _invoke( + def _invoke_first_step( self, user_id: str, - tool_parameters: dict[str, Any], + datasource_parameters: dict[str, Any], conversation_id: Optional[str] = None, - app_id: Optional[str] = None, + rag_pipeline_id: Optional[str] = None, message_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: manager = PluginToolManager() - tool_parameters = convert_parameters_to_plugin_format(tool_parameters) + datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) + + yield from manager.invoke_first_step( + tenant_id=self.tenant_id, + user_id=user_id, + tool_provider=self.entity.identity.provider, + tool_name=self.entity.identity.name, + credentials=self.runtime.credentials, + tool_parameters=tool_parameters, + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + ) + + def _invoke_second_step( + self, + user_id: str, + datasource_parameters: dict[str, Any], + conversation_id: Optional[str] = None, + rag_pipeline_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage, None, None]: + manager = PluginToolManager() + + datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) yield from manager.invoke( tenant_id=self.tenant_id, @@ -50,8 +75,9 @@ class PluginTool(Tool): message_id=message_id, ) - def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool": - return PluginTool( + + def fork_tool_runtime(self, runtime: ToolRuntime) -> "DatasourceTool": + return DatasourceTool( entity=self.entity, runtime=runtime, tenant_id=self.tenant_id, @@ -64,7 +90,7 @@ class PluginTool(Tool): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, - ) -> list[ToolParameter]: + ) -> list[DatasourceParameter]: """ get the runtime parameters """ diff --git a/api/core/datasource/entities/agent_entities.py b/api/core/datasource/entities/agent_entities.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py deleted file mode 100644 index b96c994cff..0000000000 --- a/api/core/datasource/entities/api_entities.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import Literal, Optional - -from pydantic import BaseModel, Field, field_validator - -from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.__base.tool import ToolParameter -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType - - -class ToolApiEntity(BaseModel): - author: str - name: str # identifier - label: I18nObject # label - description: I18nObject - parameters: Optional[list[ToolParameter]] 
= None - labels: list[str] = Field(default_factory=list) - output_schema: Optional[dict] = None - - -ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] - - -class ToolProviderApiEntity(BaseModel): - id: str - author: str - name: str # identifier - description: I18nObject - icon: str | dict - label: I18nObject # label - type: ToolProviderType - masked_credentials: Optional[dict] = None - original_credentials: Optional[dict] = None - is_team_authorization: bool = False - allow_delete: bool = True - plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool") - plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") - tools: list[ToolApiEntity] = Field(default_factory=list) - labels: list[str] = Field(default_factory=list) - - @field_validator("tools", mode="before") - @classmethod - def convert_none_to_empty_list(cls, v): - return v if v is not None else [] - - def to_dict(self) -> dict: - # ------------- - # overwrite tool parameter types for temp fix - tools = jsonable_encoder(self.tools) - for tool in tools: - if tool.get("parameters"): - for parameter in tool.get("parameters"): - if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: - parameter["type"] = "files" - # ------------- - - return { - "id": self.id, - "author": self.author, - "name": self.name, - "plugin_id": self.plugin_id, - "plugin_unique_identifier": self.plugin_unique_identifier, - "description": self.description.to_dict(), - "icon": self.icon, - "label": self.label.to_dict(), - "type": self.type.value, - "team_credentials": self.masked_credentials, - "is_team_authorization": self.is_team_authorization, - "allow_delete": self.allow_delete, - "tools": tools, - "labels": self.labels, - } diff --git a/api/core/datasource/entities/tool_entities.py b/api/core/datasource/entities/datasource_entities.py similarity index 84% rename from api/core/datasource/entities/tool_entities.py rename to api/core/datasource/entities/datasource_entities.py index d756763137..39c28c0d7d 100644 --- a/api/core/datasource/entities/tool_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -38,20 +38,15 @@ class ToolLabelEnum(Enum): OTHER = "other" -class ToolProviderType(enum.StrEnum): +class DatasourceProviderType(enum.StrEnum): """ - Enum class for tool provider + Enum class for datasource provider """ - PLUGIN = "plugin" - BUILT_IN = "builtin" - WORKFLOW = "workflow" - API = "api" - APP = "app" - DATASET_RETRIEVAL = "dataset-retrieval" + RAG_PIPELINE = "rag_pipeline" @classmethod - def value_of(cls, value: str) -> "ToolProviderType": + def value_of(cls, value: str) -> "DatasourceProviderType": """ Get value of given mode. 
@@ -211,12 +206,12 @@ class ToolInvokeMessageBinary(BaseModel): file_var: Optional[dict[str, Any]] = None -class ToolParameter(PluginParameter): +class DatasourceParameter(PluginParameter): """ Overrides type """ - class ToolParameterType(enum.StrEnum): + class DatasourceParameterType(enum.StrEnum): """ removes TOOLS_SELECTOR from PluginParameterType """ @@ -240,14 +235,14 @@ class ToolParameter(PluginParameter): def cast_value(self, value: Any): return cast_parameter_value(self, value) - class ToolParameterForm(Enum): + class DatasourceParameterForm(Enum): SCHEMA = "schema" # should be set while adding tool FORM = "form" # should be set before invoking tool LLM = "llm" # will be set by LLM - type: ToolParameterType = Field(..., description="The type of the parameter") + type: DatasourceParameterType = Field(..., description="The type of the parameter") human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") - form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") + form: DatasourceParameterForm = Field(..., description="The form of the parameter, schema/form/llm") llm_description: Optional[str] = None @classmethod @@ -255,12 +250,12 @@ class ToolParameter(PluginParameter): cls, name: str, llm_description: str, - typ: ToolParameterType, + typ: DatasourceParameterType, required: bool, options: Optional[list[str]] = None, - ) -> "ToolParameter": + ) -> "DatasourceParameter": """ - get a simple tool parameter + get a simple datasource parameter :param name: the name of the parameter :param llm_description: the description presented to the LLM @@ -306,7 +301,7 @@ class ToolProviderIdentity(BaseModel): ) -class ToolIdentity(BaseModel): +class DatasourceIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") label: I18nObject = Field(..., description="The label of the tool") @@ -314,15 +309,15 @@ class ToolIdentity(BaseModel): icon: Optional[str] = None -class ToolDescription(BaseModel): +class DatasourceDescription(BaseModel): human: I18nObject = Field(..., description="The description presented to the user") llm: str = Field(..., description="The description presented to the LLM") -class ToolEntity(BaseModel): - identity: ToolIdentity - parameters: list[ToolParameter] = Field(default_factory=list) - description: Optional[ToolDescription] = None +class DatasourceEntity(BaseModel): + identity: DatasourceIdentity + parameters: list[DatasourceParameter] = Field(default_factory=list) + description: Optional[DatasourceDescription] = None output_schema: Optional[dict] = None has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") @@ -331,7 +326,7 @@ class ToolEntity(BaseModel): @field_validator("parameters", mode="before") @classmethod - def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: + def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]: return v or [] @@ -341,8 +336,8 @@ class ToolProviderEntity(BaseModel): credentials_schema: list[ProviderConfig] = Field(default_factory=list) -class ToolProviderEntityWithPlugin(ToolProviderEntity): - tools: list[ToolEntity] = Field(default_factory=list) +class DatasourceProviderEntityWithPlugin(ToolProviderEntity): + datasources: list[DatasourceEntity] = Field(default_factory=list) class WorkflowToolParameterConfiguration(BaseModel): @@ -352,12 +347,12 @@ 
class WorkflowToolParameterConfiguration(BaseModel): name: str = Field(..., description="The name of the parameter") description: str = Field(..., description="The description of the parameter") - form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter") + form: DatasourceParameter.DatasourceParameterForm = Field(..., description="The form of the parameter") -class ToolInvokeMeta(BaseModel): +class DatasourceInvokeMeta(BaseModel): """ - Tool invoke meta + Datasource invoke meta """ time_cost: float = Field(..., description="The time cost of the tool invoke") @@ -365,16 +360,16 @@ class ToolInvokeMeta(BaseModel): tool_config: Optional[dict] = None @classmethod - def empty(cls) -> "ToolInvokeMeta": + def empty(cls) -> "DatasourceInvokeMeta": """ - Get an empty instance of ToolInvokeMeta + Get an empty instance of DatasourceInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) @classmethod - def error_instance(cls, error: str) -> "ToolInvokeMeta": + def error_instance(cls, error: str) -> "DatasourceInvokeMeta": """ - Get an instance of ToolInvokeMeta with error + Get an instance of DatasourceInvokeMeta with error """ return cls(time_cost=0.0, error=error, tool_config={}) @@ -386,9 +381,9 @@ class ToolInvokeMeta(BaseModel): } -class ToolLabel(BaseModel): +class DatasourceLabel(BaseModel): """ - Tool label + Datasource label """ name: str = Field(..., description="The name of the tool") @@ -396,32 +391,30 @@ class ToolLabel(BaseModel): icon: str = Field(..., description="The icon of the tool") -class ToolInvokeFrom(Enum): +class DatasourceInvokeFrom(Enum): """ - Enum class for tool invoke + Enum class for datasource invoke """ - WORKFLOW = "workflow" - AGENT = "agent" - PLUGIN = "plugin" + RAG_PIPELINE = "rag_pipeline" -class ToolSelector(BaseModel): +class DatasourceSelector(BaseModel): dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY class Parameter(BaseModel): name: str = Field(..., description="The name of the parameter") - type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter") + type: DatasourceParameter.DatasourceParameterType = Field(..., description="The type of the parameter") required: bool = Field(..., description="Whether the parameter is required") description: str = Field(..., description="The description of the parameter") default: Optional[Union[int, float, str]] = None options: Optional[list[PluginParameterOption]] = None provider_id: str = Field(..., description="The id of the provider") - tool_name: str = Field(..., description="The name of the tool") - tool_description: str = Field(..., description="The description of the tool") - tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") - tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm") + datasource_name: str = Field(..., description="The name of the datasource") + datasource_description: str = Field(..., description="The description of the datasource") + datasource_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") + datasource_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm") def to_plugin_parameter(self) -> dict[str, Any]: return self.model_dump() diff --git a/api/core/plugin/manager/datasource.py b/api/core/plugin/manager/datasource.py new file mode 100644 index 0000000000..5a6f557e4b --- /dev/null +++ b/api/core/plugin/manager/datasource.py @@ -0,0 +1,217 @@ +from collections.abc import 
Generator +from typing import Any, Optional + +from pydantic import BaseModel + +from core.plugin.entities.plugin import GenericProviderID, ToolProviderID +from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity +from core.plugin.manager.base import BasePluginManager +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter + + +class PluginDatasourceManager(BasePluginManager): + def fetch_datasource_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]: + """ + Fetch datasource providers for the given tenant. + """ + + def transformer(json_response: dict[str, Any]) -> dict: + for provider in json_response.get("data", []): + declaration = provider.get("declaration", {}) or {} + provider_name = declaration.get("identity", {}).get("name") + for tool in declaration.get("tools", []): + tool["identity"]["provider"] = provider_name + + return json_response + + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/datasources", + list[PluginToolProviderEntity], + params={"page": 1, "page_size": 256}, + transformer=transformer, + ) + + for provider in response: + provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" + + # override the provider name for each tool to plugin_id/provider_name + for tool in provider.declaration.tools: + tool.identity.provider = provider.declaration.identity.name + + return response + + def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: + """ + Fetch datasource provider for the given tenant and plugin. + """ + tool_provider_id = ToolProviderID(provider) + + def transformer(json_response: dict[str, Any]) -> dict: + data = json_response.get("data") + if data: + for datasource in data.get("declaration", {}).get("datasources", []): + datasource["identity"]["provider"] = tool_provider_id.provider_name + + return json_response + + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/datasources", + PluginToolProviderEntity, + params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id}, + transformer=transformer, + ) + + response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" + + # override the provider name for each tool to plugin_id/provider_name + for tool in response.declaration.tools: + tool.identity.provider = response.declaration.identity.name + + return response + + def invoke_first_step( + self, + tenant_id: str, + user_id: str, + datasource_provider: str, + datasource_name: str, + credentials: dict[str, Any], + datasource_parameters: dict[str, Any], + ) -> Generator[ToolInvokeMessage, None, None]: + """ + Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
+ """ + + datasource_provider_id = GenericProviderID(datasource_provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/invoke_first_step", + ToolInvokeMessage, + data={ + "user_id": user_id, + "data": { + "provider": datasource_provider_id.provider_name, + "datasource": datasource_name, + "credentials": credentials, + "datasource_parameters": datasource_parameters, + }, + }, + headers={ + "X-Plugin-ID": datasource_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + return response + + def invoke_second_step( + self, + tenant_id: str, + user_id: str, + datasource_provider: str, + datasource_name: str, + credentials: dict[str, Any], + datasource_parameters: dict[str, Any], + ) -> Generator[ToolInvokeMessage, None, None]: + """ + Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. + """ + + datasource_provider_id = GenericProviderID(datasource_provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/invoke_second_step", + ToolInvokeMessage, + data={ + "user_id": user_id, + "data": { + "provider": datasource_provider_id.provider_name, + "datasource": datasource_name, + "credentials": credentials, + "datasource_parameters": datasource_parameters, + }, + }, + headers={ + "X-Plugin-ID": datasource_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + return response + + def validate_provider_credentials( + self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] + ) -> bool: + """ + validate the credentials of the provider + """ + tool_provider_id = GenericProviderID(provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/tool/validate_credentials", + PluginBasicBooleanResponse, + data={ + "user_id": user_id, + "data": { + "provider": tool_provider_id.provider_name, + "credentials": credentials, + }, + }, + headers={ + "X-Plugin-ID": tool_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp.result + + return False + + def get_runtime_parameters( + self, + tenant_id: str, + user_id: str, + provider: str, + credentials: dict[str, Any], + datasource: str, + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> list[ToolParameter]: + """ + get the runtime parameters of the datasource + """ + datasource_provider_id = GenericProviderID(provider) + + class RuntimeParametersResponse(BaseModel): + parameters: list[ToolParameter] + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/get_runtime_parameters", + RuntimeParametersResponse, + data={ + "user_id": user_id, + "conversation_id": conversation_id, + "app_id": app_id, + "message_id": message_id, + "data": { + "provider": datasource_provider_id.provider_name, + "datasource": datasource, + "credentials": credentials, + }, + }, + headers={ + "X-Plugin-ID": datasource_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp.parameters + + return [] diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 6876285b31..49bf7c308a 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,5 +1,6 @@ from typing import Any +from 
core.datasource.entities.datasource_entities import DatasourceSelector from core.file.models import File from core.tools.entities.tool_entities import ToolSelector @@ -18,4 +19,10 @@ def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, parameters[parameter_name] = [] for p in parameter: parameters[parameter_name].append(p.to_plugin_parameter()) + elif isinstance(parameter, DatasourceSelector): + parameters[parameter_name] = parameter.to_plugin_parameter() + elif isinstance(parameter, list) and all(isinstance(p, DatasourceSelector) for p in parameter): + parameters[parameter_name] = [] + for p in parameter: + parameters[parameter_name].append(p.to_plugin_parameter()) return parameters From b9ab1555fbf102b7955459340f6dbc4a7510a1f0 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 24 Apr 2025 15:42:30 +0800 Subject: [PATCH 006/155] r2 --- .../__base/{tool.py => datasource.py} | 25 +- ...{tool_runtime.py => datasource_runtime.py} | 13 +- .../{tool_engine.py => datasource_engine.py} | 4 +- api/core/datasource/datasource_tool/tool.py | 48 +-- .../dataset_multi_retriever_tool.py | 199 ---------- .../dataset_retriever_base_tool.py | 33 -- .../dataset_retriever_tool.py | 202 ---------- .../utils/dataset_retriever_tool.py | 134 ------- .../utils/model_invocation_utils.py | 169 -------- api/core/datasource/utils/rag_web_reader.py | 17 - api/core/datasource/utils/web_reader_tool.py | 375 ------------------ api/installed_plugins.jsonl | 1 + 12 files changed, 41 insertions(+), 1179 deletions(-) rename api/core/datasource/__base/{tool.py => datasource.py} (90%) rename api/core/datasource/__base/{tool_runtime.py => datasource_runtime.py} (66%) rename api/core/datasource/{tool_engine.py => datasource_engine.py} (99%) delete mode 100644 api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py delete mode 100644 api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py delete mode 100644 api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py delete mode 100644 api/core/datasource/utils/dataset_retriever_tool.py delete mode 100644 api/core/datasource/utils/model_invocation_utils.py delete mode 100644 api/core/datasource/utils/rag_web_reader.py delete mode 100644 api/core/datasource/utils/web_reader_tool.py create mode 100644 api/installed_plugins.jsonl diff --git a/api/core/datasource/__base/tool.py b/api/core/datasource/__base/datasource.py similarity index 90% rename from api/core/datasource/__base/tool.py rename to api/core/datasource/__base/datasource.py index 35e16b5c8f..3a67b56e32 100644 --- a/api/core/datasource/__base/tool.py +++ b/api/core/datasource/__base/datasource.py @@ -6,31 +6,30 @@ from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: from models.model import File -from core.tools.__base.tool_runtime import ToolRuntime +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType from core.tools.entities.tool_entities import ( - ToolEntity, ToolInvokeMessage, ToolParameter, - ToolProviderType, ) -class Tool(ABC): +class Datasource(ABC): """ - The base class of a tool + The base class of a datasource """ - entity: ToolEntity - runtime: ToolRuntime + entity: DatasourceEntity + runtime: DatasourceRuntime - def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None: + def __init__(self, entity: DatasourceEntity, runtime: DatasourceRuntime) -> None: self.entity = entity 
self.runtime = runtime - def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool": + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "Datasource": """ - fork a new tool with metadata - :return: the new tool + fork a new datasource with metadata + :return: the new datasource """ return self.__class__( entity=self.entity.model_copy(), @@ -38,9 +37,9 @@ class Tool(ABC): ) @abstractmethod - def tool_provider_type(self) -> ToolProviderType: + def datasource_provider_type(self) -> DatasourceProviderType: """ - get the tool provider type + get the datasource provider type :return: the tool provider type """ diff --git a/api/core/datasource/__base/tool_runtime.py b/api/core/datasource/__base/datasource_runtime.py similarity index 66% rename from api/core/datasource/__base/tool_runtime.py rename to api/core/datasource/__base/datasource_runtime.py index c9e157cb77..51ff1fc6c1 100644 --- a/api/core/datasource/__base/tool_runtime.py +++ b/api/core/datasource/__base/datasource_runtime.py @@ -4,12 +4,13 @@ from openai import BaseModel from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom +from core.datasource.entities.datasource_entities import DatasourceInvokeFrom from core.tools.entities.tool_entities import ToolInvokeFrom -class ToolRuntime(BaseModel): +class DatasourceRuntime(BaseModel): """ - Meta data of a tool call processing + Meta data of a datasource call processing """ tenant_id: str @@ -20,17 +21,17 @@ class ToolRuntime(BaseModel): runtime_parameters: dict[str, Any] = Field(default_factory=dict) -class FakeToolRuntime(ToolRuntime): +class FakeDatasourceRuntime(DatasourceRuntime): """ - Fake tool runtime for testing + Fake datasource runtime for testing """ def __init__(self): super().__init__( tenant_id="fake_tenant_id", - tool_id="fake_tool_id", + datasource_id="fake_datasource_id", invoke_from=InvokeFrom.DEBUGGER, - tool_invoke_from=ToolInvokeFrom.AGENT, + datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE, credentials={}, runtime_parameters={}, ) diff --git a/api/core/datasource/tool_engine.py b/api/core/datasource/datasource_engine.py similarity index 99% rename from api/core/datasource/tool_engine.py rename to api/core/datasource/datasource_engine.py index ad0c62537c..423f78a787 100644 --- a/api/core/datasource/tool_engine.py +++ b/api/core/datasource/datasource_engine.py @@ -36,9 +36,9 @@ from models.enums import CreatedByRole from models.model import Message, MessageFile -class ToolEngine: +class DatasourceEngine: """ - Tool runtime engine take care of the tool executions. + Datasource runtime engine take care of the datasource executions. 
""" @staticmethod diff --git a/api/core/datasource/datasource_tool/tool.py b/api/core/datasource/datasource_tool/tool.py index b69b2368a4..1c8572c2c5 100644 --- a/api/core/datasource/datasource_tool/tool.py +++ b/api/core/datasource/datasource_tool/tool.py @@ -1,7 +1,9 @@ from collections.abc import Generator from typing import Any, Optional +from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceParameter, DatasourceProviderType +from core.plugin.manager.datasource import PluginDatasourceManager from core.plugin.manager.tool import PluginToolManager from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.__base.tool import Tool @@ -9,7 +11,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType -class DatasourceTool(Tool): +class DatasourcePlugin(Datasource): tenant_id: str icon: str plugin_unique_identifier: str @@ -31,53 +33,45 @@ class DatasourceTool(Tool): self, user_id: str, datasource_parameters: dict[str, Any], - conversation_id: Optional[str] = None, rag_pipeline_id: Optional[str] = None, - message_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: - manager = PluginToolManager() + manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) yield from manager.invoke_first_step( tenant_id=self.tenant_id, user_id=user_id, - tool_provider=self.entity.identity.provider, - tool_name=self.entity.identity.name, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, credentials=self.runtime.credentials, - tool_parameters=tool_parameters, - conversation_id=conversation_id, - app_id=app_id, - message_id=message_id, + datasource_parameters=datasource_parameters, + rag_pipeline_id=rag_pipeline_id, ) def _invoke_second_step( self, user_id: str, datasource_parameters: dict[str, Any], - conversation_id: Optional[str] = None, rag_pipeline_id: Optional[str] = None, - message_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: - manager = PluginToolManager() + manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) yield from manager.invoke( tenant_id=self.tenant_id, user_id=user_id, - tool_provider=self.entity.identity.provider, - tool_name=self.entity.identity.name, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, credentials=self.runtime.credentials, - tool_parameters=tool_parameters, - conversation_id=conversation_id, - app_id=app_id, - message_id=message_id, + datasource_parameters=datasource_parameters, + rag_pipeline_id=rag_pipeline_id, ) - def fork_tool_runtime(self, runtime: ToolRuntime) -> "DatasourceTool": - return DatasourceTool( + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": + return DatasourcePlugin( entity=self.entity, runtime=runtime, tenant_id=self.tenant_id, @@ -87,9 +81,7 @@ class DatasourceTool(Tool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + rag_pipeline_id: Optional[str] = None, ) -> list[DatasourceParameter]: """ get the runtime parameters @@ -100,16 +92,14 @@ class DatasourceTool(Tool): if self.runtime_parameters is not None: return 
self.runtime_parameters - manager = PluginToolManager() + manager = PluginDatasourceManager() self.runtime_parameters = manager.get_runtime_parameters( tenant_id=self.tenant_id, user_id="", provider=self.entity.identity.provider, - tool=self.entity.identity.name, + datasource=self.entity.identity.name, credentials=self.runtime.credentials, - conversation_id=conversation_id, - app_id=app_id, - message_id=message_id, + rag_pipeline_id=rag_pipeline_id, ) return self.runtime_parameters diff --git a/api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py deleted file mode 100644 index 032274b87e..0000000000 --- a/api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ /dev/null @@ -1,199 +0,0 @@ -import threading -from typing import Any - -from flask import Flask, current_app -from pydantic import BaseModel, Field - -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.rag.datasource.retrieval_service import RetrievalService -from core.rag.models.document import Document as RagDocument -from core.rag.rerank.rerank_model import RerankModelRunner -from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool -from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment - -default_retrieval_model: dict[str, Any] = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, -} - - -class DatasetMultiRetrieverToolInput(BaseModel): - query: str = Field(..., description="dataset multi retriever and rerank") - - -class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): - """Tool for querying multi dataset.""" - - name: str = "dataset_" - args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput - description: str = "dataset multi retriever and rerank. 
" - dataset_ids: list[str] - reranking_provider_name: str - reranking_model_name: str - - @classmethod - def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): - return cls( - name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs - ) - - def _run(self, query: str) -> str: - threads = [] - all_documents: list[RagDocument] = [] - for dataset_id in self.dataset_ids: - retrieval_thread = threading.Thread( - target=self._retriever, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "dataset_id": dataset_id, - "query": query, - "all_documents": all_documents, - "hit_callbacks": self.hit_callbacks, - }, - ) - threads.append(retrieval_thread) - retrieval_thread.start() - for thread in threads: - thread.join() - # do rerank for searched documents - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - provider=self.reranking_provider_name, - model_type=ModelType.RERANK, - model=self.reranking_model_name, - ) - - rerank_runner = RerankModelRunner(rerank_model_instance) - all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k) - - for hit_callback in self.hit_callbacks: - hit_callback.on_tool_end(all_documents) - - document_score_list = {} - for item in all_documents: - if item.metadata and item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - - document_context_list = [] - index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id.in_(self.dataset_ids), - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - for segment in sorted_segments: - if segment.answer: - document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") - else: - document_context_list.append(segment.get_sign_content()) - if self.return_resource: - context_list = [] - resource_number = 1 - for segment in sorted_segments: - dataset = Dataset.query.filter_by(id=segment.dataset_id).first() - document = Document.query.filter( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() - if dataset and document: - source = { - "position": resource_number, - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, - "segment_id": segment.id, - "retriever_from": self.retriever_from, - "score": document_score_list.get(segment.index_node_id, None), - "doc_metadata": document.doc_metadata, - } - - if self.retriever_from == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash - if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" - else: - source["content"] = segment.content - context_list.append(source) - resource_number += 1 - - for hit_callback in 
self.hit_callbacks: - hit_callback.return_retriever_resource_info(context_list) - - return str("\n".join(document_context_list)) - return "" - - raise RuntimeError("not segments found") - - def _retriever( - self, - flask_app: Flask, - dataset_id: str, - query: str, - all_documents: list, - hit_callbacks: list[DatasetIndexToolCallbackHandler], - ): - with flask_app.app_context(): - dataset = ( - db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() - ) - - if not dataset: - return [] - - for hit_callback in hit_callbacks: - hit_callback.on_query(query, dataset.id) - - # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model or default_retrieval_model - - if dataset.indexing_technique == "economy": - # use keyword table query - documents = RetrievalService.retrieve( - retrieval_method="keyword_search", - dataset_id=dataset.id, - query=query, - top_k=retrieval_model.get("top_k") or 2, - ) - if documents: - all_documents.extend(documents) - else: - if self.top_k > 0: - # retrieval source - documents = RetrievalService.retrieve( - retrieval_method=retrieval_model["search_method"], - dataset_id=dataset.id, - query=query, - top_k=retrieval_model.get("top_k") or 2, - score_threshold=retrieval_model.get("score_threshold", 0.0) - if retrieval_model["score_threshold_enabled"] - else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) - if retrieval_model["reranking_enable"] - else None, - reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), - ) - - all_documents.extend(documents) diff --git a/api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py deleted file mode 100644 index a4d2de3b1c..0000000000 --- a/api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py +++ /dev/null @@ -1,33 +0,0 @@ -from abc import abstractmethod -from typing import Any, Optional - -from msal_extensions.persistence import ABC # type: ignore -from pydantic import BaseModel, ConfigDict - -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler - - -class DatasetRetrieverBaseTool(BaseModel, ABC): - """Tool for querying a Dataset.""" - - name: str = "dataset" - description: str = "use this to retrieve a dataset. " - tenant_id: str - top_k: int = 2 - score_threshold: Optional[float] = None - hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] - return_resource: bool - retriever_from: str - model_config = ConfigDict(arbitrary_types_allowed=True) - - @abstractmethod - def _run( - self, - *args: Any, - **kwargs: Any, - ) -> Any: - """Use the tool. 
- - Add run_manager: Optional[CallbackManagerForToolRun] = None - to child implementations to enable tracing, - """ diff --git a/api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py deleted file mode 100644 index 63260cfac3..0000000000 --- a/api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py +++ /dev/null @@ -1,202 +0,0 @@ -from typing import Any - -from pydantic import BaseModel, Field - -from core.rag.datasource.retrieval_service import RetrievalService -from core.rag.entities.context_entities import DocumentContext -from core.rag.models.document import Document as RetrievalDocument -from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool -from extensions.ext_database import db -from models.dataset import Dataset -from models.dataset import Document as DatasetDocument -from services.external_knowledge_service import ExternalDatasetService - -default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "reranking_mode": "reranking_model", - "top_k": 2, - "score_threshold_enabled": False, -} - - -class DatasetRetrieverToolInput(BaseModel): - query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.") - - -class DatasetRetrieverTool(DatasetRetrieverBaseTool): - """Tool for querying a Dataset.""" - - name: str = "dataset" - args_schema: type[BaseModel] = DatasetRetrieverToolInput - description: str = "use this to retrieve a dataset. " - dataset_id: str - - @classmethod - def from_dataset(cls, dataset: Dataset, **kwargs): - description = dataset.description - if not description: - description = "useful for when you want to answer queries about the " + dataset.name - - description = description.replace("\n", "").replace("\r", "") - return cls( - name=f"dataset_{dataset.id.replace('-', '_')}", - tenant_id=dataset.tenant_id, - dataset_id=dataset.id, - description=description, - **kwargs, - ) - - def _run(self, query: str) -> str: - dataset = ( - db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() - ) - - if not dataset: - return "" - for hit_callback in self.hit_callbacks: - hit_callback.on_query(query, dataset.id) - if dataset.provider == "external": - results = [] - external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( - tenant_id=dataset.tenant_id, - dataset_id=dataset.id, - query=query, - external_retrieval_parameters=dataset.retrieval_model, - ) - for external_document in external_documents: - document = RetrievalDocument( - page_content=external_document.get("content"), - metadata=external_document.get("metadata"), - provider="external", - ) - if document.metadata is not None: - document.metadata["score"] = external_document.get("score") - document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset.id - document.metadata["dataset_name"] = dataset.name - results.append(document) - # deal with external documents - context_list = [] - for position, item in enumerate(results, start=1): - if item.metadata is not None: - source = { - "position": position, - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_name": 
item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": self.retriever_from, - "score": item.metadata.get("score"), - "title": item.metadata.get("title"), - "content": item.page_content, - } - context_list.append(source) - for hit_callback in self.hit_callbacks: - hit_callback.return_retriever_resource_info(context_list) - - return str("\n".join([item.page_content for item in results])) - else: - # get retrieval model , if the model is not setting , using default - retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model - if dataset.indexing_technique == "economy": - # use keyword table query - documents = RetrievalService.retrieve( - retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k - ) - return str("\n".join([document.page_content for document in documents])) - else: - if self.top_k > 0: - # retrieval source - documents = RetrievalService.retrieve( - retrieval_method=retrieval_model.get("search_method", "semantic_search"), - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get("score_threshold", 0.0) - if retrieval_model["score_threshold_enabled"] - else 0.0, - reranking_model=retrieval_model.get("reranking_model") - if retrieval_model["reranking_enable"] - else None, - reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights"), - ) - else: - documents = [] - for hit_callback in self.hit_callbacks: - hit_callback.on_tool_end(documents) - document_score_list = {} - if dataset.indexing_technique != "economy": - for item in documents: - if item.metadata is not None and item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - document_context_list = [] - records = RetrievalService.format_retrieval_documents(documents) - if records: - for record in records: - segment = record.segment - if segment.answer: - document_context_list.append( - DocumentContext( - content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=record.score, - ) - ) - else: - document_context_list.append( - DocumentContext( - content=segment.get_sign_content(), - score=record.score, - ) - ) - retrieval_resource_list = [] - if self.return_resource: - for record in records: - segment = record.segment - dataset = Dataset.query.filter_by(id=segment.dataset_id).first() - document = DatasetDocument.query.filter( - DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).first() - if dataset and document: - source = { - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, # type: ignore - "document_name": document.name, # type: ignore - "data_source_type": document.data_source_type, # type: ignore - "segment_id": segment.id, - "retriever_from": self.retriever_from, - "score": record.score or 0.0, - "doc_metadata": document.doc_metadata, # type: ignore - } - - if self.retriever_from == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash - if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" - else: - source["content"] = segment.content - retrieval_resource_list.append(source) - - if self.return_resource and retrieval_resource_list: - retrieval_resource_list = sorted( - retrieval_resource_list, - 
key=lambda x: x.get("score") or 0.0, - reverse=True, - ) - for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore - item["position"] = position # type: ignore - for hit_callback in self.hit_callbacks: - hit_callback.return_retriever_resource_info(retrieval_resource_list) - if document_context_list: - document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) - return str("\n".join([document_context.content for document_context in document_context_list])) - return "" diff --git a/api/core/datasource/utils/dataset_retriever_tool.py b/api/core/datasource/utils/dataset_retriever_tool.py deleted file mode 100644 index b73dec4ebc..0000000000 --- a/api/core/datasource/utils/dataset_retriever_tool.py +++ /dev/null @@ -1,134 +0,0 @@ -from collections.abc import Generator -from typing import Any, Optional - -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import InvokeFrom -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.tools.__base.tool import Tool -from core.tools.__base.tool_runtime import ToolRuntime -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ToolDescription, - ToolEntity, - ToolIdentity, - ToolInvokeMessage, - ToolParameter, - ToolProviderType, -) -from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool - - -class DatasetRetrieverTool(Tool): - retrieval_tool: DatasetRetrieverBaseTool - - def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None: - super().__init__(entity, runtime) - self.retrieval_tool = retrieval_tool - - @staticmethod - def get_dataset_tools( - tenant_id: str, - dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity | None, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler, - ) -> list["DatasetRetrieverTool"]: - """ - get dataset tool - """ - # check if retrieve_config is valid - if dataset_ids is None or len(dataset_ids) == 0: - return [] - if retrieve_config is None: - return [] - - feature = DatasetRetrieval() - - # save original retrieve strategy, and set retrieve strategy to SINGLE - # Agent only support SINGLE mode - original_retriever_mode = retrieve_config.retrieve_strategy - retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE - retrieval_tools = feature.to_dataset_retriever_tool( - tenant_id=tenant_id, - dataset_ids=dataset_ids, - retrieve_config=retrieve_config, - return_resource=return_resource, - invoke_from=invoke_from, - hit_callback=hit_callback, - ) - if retrieval_tools is None or len(retrieval_tools) == 0: - return [] - - # restore retrieve strategy - retrieve_config.retrieve_strategy = original_retriever_mode - - # convert retrieval tools to Tools - tools = [] - for retrieval_tool in retrieval_tools: - tool = DatasetRetrieverTool( - retrieval_tool=retrieval_tool, - entity=ToolEntity( - identity=ToolIdentity( - provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") - ), - parameters=[], - description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description), - ), - runtime=ToolRuntime(tenant_id=tenant_id), - ) - - tools.append(tool) - - return tools - - def get_runtime_parameters( - 
self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> list[ToolParameter]: - return [ - ToolParameter( - name="query", - label=I18nObject(en_US="", zh_Hans=""), - human_description=I18nObject(en_US="", zh_Hans=""), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description="Query for the dataset to be used to retrieve the dataset.", - required=True, - default="", - placeholder=I18nObject(en_US="", zh_Hans=""), - ), - ] - - def tool_provider_type(self) -> ToolProviderType: - return ToolProviderType.DATASET_RETRIEVAL - - def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage, None, None]: - """ - invoke dataset retriever tool - """ - query = tool_parameters.get("query") - if not query: - yield self.create_text_message(text="please input query") - else: - # invoke dataset retriever tool - result = self.retrieval_tool._run(query=query) - yield self.create_text_message(text=result) - - def validate_credentials( - self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False - ) -> str | None: - """ - validate the credentials for dataset retriever tool - """ - pass diff --git a/api/core/datasource/utils/model_invocation_utils.py b/api/core/datasource/utils/model_invocation_utils.py deleted file mode 100644 index 3f59b3f472..0000000000 --- a/api/core/datasource/utils/model_invocation_utils.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. - -Therefore, a model manager is needed to list/invoke/validate models. 
-""" - -import json -from typing import Optional, cast - -from core.model_manager import ModelManager -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessage -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils.encoders import jsonable_encoder -from extensions.ext_database import db -from models.tools import ToolModelInvoke - - -class InvokeModelError(Exception): - pass - - -class ModelInvocationUtils: - @staticmethod - def get_max_llm_context_tokens( - tenant_id: str, - ) -> int: - """ - get max llm context tokens of the model - """ - model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM, - ) - - if not model_instance: - raise InvokeModelError("Model not found") - - llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) - - if not schema: - raise InvokeModelError("No model schema found") - - max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) - if max_tokens is None: - return 2048 - - return max_tokens - - @staticmethod - def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: - """ - calculate tokens from prompt messages and model parameters - """ - - # get model instance - model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) - - if not model_instance: - raise InvokeModelError("Model not found") - - # get tokens - tokens = model_instance.get_llm_num_tokens(prompt_messages) - - return tokens - - @staticmethod - def invoke( - user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] - ) -> LLMResult: - """ - invoke model with parameters in user's own context - - :param user_id: user id - :param tenant_id: tenant id, the tenant id of the creator of the tool - :param tool_type: tool type - :param tool_name: tool name - :param prompt_messages: prompt messages - :return: AssistantPromptMessage - """ - - # get model manager - model_manager = ModelManager() - # get model instance - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM, - ) - - # get prompt tokens - prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - model_parameters = { - "temperature": 0.8, - "top_p": 0.8, - } - - # create tool model invoke - tool_model_invoke = ToolModelInvoke( - user_id=user_id, - tenant_id=tenant_id, - provider=model_instance.provider, - tool_type=tool_type, - tool_name=tool_name, - model_parameters=json.dumps(model_parameters), - prompt_messages=json.dumps(jsonable_encoder(prompt_messages)), - model_response="", - prompt_tokens=prompt_tokens, - answer_tokens=0, - answer_unit_price=0, - answer_price_unit=0, - provider_response_latency=0, - total_price=0, - currency="USD", - ) - - db.session.add(tool_model_invoke) - db.session.commit() - - try: - response: LLMResult = cast( - LLMResult, - model_instance.invoke_llm( - 
prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=[], - stop=[], - stream=False, - user=user_id, - callbacks=[], - ), - ) - except InvokeRateLimitError as e: - raise InvokeModelError(f"Invoke rate limit error: {e}") - except InvokeBadRequestError as e: - raise InvokeModelError(f"Invoke bad request error: {e}") - except InvokeConnectionError as e: - raise InvokeModelError(f"Invoke connection error: {e}") - except InvokeAuthorizationError as e: - raise InvokeModelError("Invoke authorization error") - except InvokeServerUnavailableError as e: - raise InvokeModelError(f"Invoke server unavailable error: {e}") - except Exception as e: - raise InvokeModelError(f"Invoke error: {e}") - - # update tool model invoke - tool_model_invoke.model_response = response.message.content - if response.usage: - tool_model_invoke.answer_tokens = response.usage.completion_tokens - tool_model_invoke.answer_unit_price = response.usage.completion_unit_price - tool_model_invoke.answer_price_unit = response.usage.completion_price_unit - tool_model_invoke.provider_response_latency = response.usage.latency - tool_model_invoke.total_price = response.usage.total_price - tool_model_invoke.currency = response.usage.currency - - db.session.commit() - - return response diff --git a/api/core/datasource/utils/rag_web_reader.py b/api/core/datasource/utils/rag_web_reader.py deleted file mode 100644 index 22c47fa814..0000000000 --- a/api/core/datasource/utils/rag_web_reader.py +++ /dev/null @@ -1,17 +0,0 @@ -import re - - -def get_image_upload_file_ids(content): - pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" - matches = re.findall(pattern, content) - image_upload_file_ids = [] - for match in matches: - if match[1] == "file-preview": - content_pattern = r"files/([^/]+)/file-preview" - else: - content_pattern = r"files/([^/]+)/image-preview" - content_match = re.search(content_pattern, match[0]) - if content_match: - image_upload_file_id = content_match.group(1) - image_upload_file_ids.append(image_upload_file_id) - return image_upload_file_ids diff --git a/api/core/datasource/utils/web_reader_tool.py b/api/core/datasource/utils/web_reader_tool.py deleted file mode 100644 index d42fd99fce..0000000000 --- a/api/core/datasource/utils/web_reader_tool.py +++ /dev/null @@ -1,375 +0,0 @@ -import hashlib -import json -import mimetypes -import os -import re -import site -import subprocess -import tempfile -import unicodedata -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Literal, Optional, cast -from urllib.parse import unquote - -import chardet -import cloudscraper # type: ignore -from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore -from regex import regex # type: ignore - -from core.helper import ssrf_proxy -from core.rag.extractor import extract_processor -from core.rag.extractor.extract_processor import ExtractProcessor - -FULL_TEMPLATE = """ -TITLE: {title} -AUTHORS: {authors} -PUBLISH DATE: {publish_date} -TOP_IMAGE_URL: {top_image} -TEXT: - -{text} -""" - - -def page_result(text: str, cursor: int, max_length: int) -> str: - """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" - return text[cursor : cursor + max_length] - - -def get_url(url: str, user_agent: Optional[str] = None) -> str: - """Fetch URL and return the contents as a string.""" - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" - " 
Chrome/91.0.4472.124 Safari/537.36" - } - if user_agent: - headers["User-Agent"] = user_agent - - main_content_type = None - supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] - response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10)) - - if response.status_code == 200: - # check content-type - content_type = response.headers.get("Content-Type") - if content_type: - main_content_type = response.headers.get("Content-Type").split(";")[0].strip() - else: - content_disposition = response.headers.get("Content-Disposition", "") - filename_match = re.search(r'filename="([^"]+)"', content_disposition) - if filename_match: - filename = unquote(filename_match.group(1)) - extension = re.search(r"\.(\w+)$", filename) - if extension: - main_content_type = mimetypes.guess_type(filename)[0] - - if main_content_type not in supported_content_types: - return "Unsupported content-type [{}] of URL.".format(main_content_type) - - if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: - return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) - - response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) - elif response.status_code == 403: - scraper = cloudscraper.create_scraper() - scraper.perform_request = ssrf_proxy.make_request - response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) - - if response.status_code != 200: - return "URL returned status code {}.".format(response.status_code) - - # Detect encoding using chardet - detected_encoding = chardet.detect(response.content) - encoding = detected_encoding["encoding"] - if encoding: - try: - content = response.content.decode(encoding) - except (UnicodeDecodeError, TypeError): - content = response.text - else: - content = response.text - - a = extract_using_readabilipy(content) - - if not a["plain_text"] or not a["plain_text"].strip(): - return "" - - res = FULL_TEMPLATE.format( - title=a["title"], - authors=a["byline"], - publish_date=a["date"], - top_image="", - text=a["plain_text"] or "", - ) - - return res - - -def extract_using_readabilipy(html): - with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html: - f_html.write(html) - f_html.close() - html_path = f_html.name - - # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file - article_json_path = html_path + ".json" - jsdir = os.path.join(find_module_path("readabilipy"), "javascript") - with chdir(jsdir): - subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) - - # Read output of call to Readability.parse() from JSON file and return as Python dictionary - input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8")) - - # Deleting files after processing - os.unlink(article_json_path) - os.unlink(html_path) - - article_json: dict[str, Any] = { - "title": None, - "byline": None, - "date": None, - "content": None, - "plain_content": None, - "plain_text": None, - } - # Populate article fields from readability fields where present - if input_json: - if input_json.get("title"): - article_json["title"] = input_json["title"] - if input_json.get("byline"): - article_json["byline"] = input_json["byline"] - if input_json.get("date"): - article_json["date"] = input_json["date"] - if input_json.get("content"): - article_json["content"] = input_json["content"] - article_json["plain_content"] = plain_content(article_json["content"], False, 
False) - article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) - if input_json.get("textContent"): - article_json["plain_text"] = input_json["textContent"] - article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"]) - - return article_json - - -def find_module_path(module_name): - for package_path in site.getsitepackages(): - potential_path = os.path.join(package_path, module_name) - if os.path.exists(potential_path): - return potential_path - - return None - - -@contextmanager -def chdir(path): - """Change directory in context and return to original on exit""" - # From https://stackoverflow.com/a/37996581, couldn't find a built-in - original_path = os.getcwd() - os.chdir(path) - try: - yield - finally: - os.chdir(original_path) - - -def extract_text_blocks_as_plain_text(paragraph_html): - # Load article as DOM - soup = BeautifulSoup(paragraph_html, "html.parser") - # Select all lists - list_elements = soup.find_all(["ul", "ol"]) - # Prefix text in all list items with "* " and make lists paragraphs - for list_element in list_elements: - plain_items = "".join( - list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")])) - ) - list_element.string = plain_items - list_element.name = "p" - # Select all text blocks - text_blocks = [s.parent for s in soup.find_all(string=True)] - text_blocks = [plain_text_leaf_node(block) for block in text_blocks] - # Drop empty paragraphs - text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks)) - return text_blocks - - -def plain_text_leaf_node(element): - # Extract all text, stripped of any child HTML elements and normalize it - plain_text = normalize_text(element.get_text()) - if plain_text != "" and element.name == "li": - plain_text = "* {}, ".format(plain_text) - if plain_text == "": - plain_text = None - if "data-node-index" in element.attrs: - plain = {"node_index": element["data-node-index"], "text": plain_text} - else: - plain = {"text": plain_text} - return plain - - -def plain_content(readability_content, content_digests, node_indexes): - # Load article as DOM - soup = BeautifulSoup(readability_content, "html.parser") - # Make all elements plain - elements = plain_elements(soup.contents, content_digests, node_indexes) - if node_indexes: - # Add node index attributes to nodes - elements = [add_node_indexes(element) for element in elements] - # Replace article contents with plain elements - soup.contents = elements - return str(soup) - - -def plain_elements(elements, content_digests, node_indexes): - # Get plain content versions of all elements - elements = [plain_element(element, content_digests, node_indexes) for element in elements] - if content_digests: - # Add content digest attribute to nodes - elements = [add_content_digest(element) for element in elements] - return elements - - -def plain_element(element, content_digests, node_indexes): - # For lists, we make each item plain text - if is_leaf(element): - # For leaf node elements, extract the text content, discarding any HTML tags - # 1. Get element contents as text - plain_text = element.get_text() - # 2. Normalize the extracted text string to a canonical representation - plain_text = normalize_text(plain_text) - # 3. Update element content to be plain text - element.string = plain_text - elif is_text(element): - if is_non_printing(element): - # The simplified HTML may have come from Readability.js so might - # have non-printing text (e.g. Comment or CData). 
In this case, we - # keep the structure, but ensure that the string is empty. - element = type(element)("") - else: - plain_text = element.string - plain_text = normalize_text(plain_text) - element = type(element)(plain_text) - else: - # If not a leaf node or leaf type call recursively on child nodes, replacing - element.contents = plain_elements(element.contents, content_digests, node_indexes) - return element - - -def add_node_indexes(element, node_index="0"): - # Can't add attributes to string types - if is_text(element): - return element - # Add index to current element - element["data-node-index"] = node_index - # Add index to child elements - for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1): - # Can't add attributes to leaf string types - child_index = "{stem}.{local}".format(stem=node_index, local=local_idx) - add_node_indexes(child, node_index=child_index) - return element - - -def normalize_text(text): - """Normalize unicode and whitespace.""" - # Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them - text = strip_control_characters(text) - text = normalize_unicode(text) - text = normalize_whitespace(text) - return text - - -def strip_control_characters(text): - """Strip out unicode control characters which might break the parsing.""" - # Unicode control characters - # [Cc]: Other, Control [includes new lines] - # [Cf]: Other, Format - # [Cn]: Other, Not Assigned - # [Co]: Other, Private Use - # [Cs]: Other, Surrogate - control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"} - retained_chars = ["\t", "\n", "\r", "\f"] - - # Remove non-printing control characters - return "".join( - [ - "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char - for char in text - ] - ) - - -def normalize_unicode(text): - """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible.""" - normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC" - text = unicodedata.normalize(normal_form, text) - return text - - -def normalize_whitespace(text): - """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed.""" - text = regex.sub(r"\s+", " ", text) - # Remove leading and trailing whitespace - text = text.strip() - return text - - -def is_leaf(element): - return element.name in {"p", "li"} - - -def is_text(element): - return isinstance(element, NavigableString) - - -def is_non_printing(element): - return any(isinstance(element, _e) for _e in [Comment, CData]) - - -def add_content_digest(element): - if not is_text(element): - element["data-content-digest"] = content_digest(element) - return element - - -def content_digest(element): - digest: Any - if is_text(element): - # Hash - trimmed_string = element.string.strip() - if trimmed_string == "": - digest = "" - else: - digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest() - else: - contents = element.contents - num_contents = len(contents) - if num_contents == 0: - # No hash when no child elements exist - digest = "" - elif num_contents == 1: - # If single child, use digest of child - digest = content_digest(contents[0]) - else: - # Build content digest from the "non-empty" digests of child nodes - digest = hashlib.sha256() - child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents])) - for child in child_digests: - digest.update(child.encode("utf-8")) - digest = 
digest.hexdigest() - return digest - - -def get_image_upload_file_ids(content): - pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" - matches = re.findall(pattern, content) - image_upload_file_ids = [] - for match in matches: - if match[1] == "file-preview": - content_pattern = r"files/([^/]+)/file-preview" - else: - content_pattern = r"files/([^/]+)/image-preview" - content_match = re.search(content_pattern, match[0]) - if content_match: - image_upload_file_id = content_match.group(1) - image_upload_file_ids.append(image_upload_file_id) - return image_upload_file_ids diff --git a/api/installed_plugins.jsonl b/api/installed_plugins.jsonl new file mode 100644 index 0000000000..463e24ae64 --- /dev/null +++ b/api/installed_plugins.jsonl @@ -0,0 +1 @@ +{"not_installed": [], "plugin_install_failed": []} \ No newline at end of file From 9437145218f2445d32ae1db77695f2b0a456fd83 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 25 Apr 2025 13:42:57 +0800 Subject: [PATCH 007/155] r2 --- .../datasource/datasource_tool/provider.py | 65 +++++++++++++++++++ api/core/datasource/datasource_tool/tool.py | 23 ++++--- .../entities/datasource_entities.py | 4 +- 3 files changed, 81 insertions(+), 11 deletions(-) diff --git a/api/core/datasource/datasource_tool/provider.py b/api/core/datasource/datasource_tool/provider.py index 3104728947..820224eeaa 100644 --- a/api/core/datasource/datasource_tool/provider.py +++ b/api/core/datasource/datasource_tool/provider.py @@ -78,3 +78,68 @@ class DatasourceToolProviderController(BuiltinToolProviderController): ) for datasource_entity in self.entity.datasources ] + + def validate_credentials_format(self, credentials: dict[str, Any]) -> None: + """ + validate the format of the credentials of the provider and set the default value if needed + + :param credentials: the credentials of the tool + """ + credentials_schema = dict[str, ProviderConfig]() + if credentials_schema is None: + return + + for credential in self.entity.credentials_schema: + credentials_schema[credential.name] = credential + + credentials_need_to_validate: dict[str, ProviderConfig] = {} + for credential_name in credentials_schema: + credentials_need_to_validate[credential_name] = credentials_schema[credential_name] + + for credential_name in credentials: + if credential_name not in credentials_need_to_validate: + raise ToolProviderCredentialValidationError( + f"credential {credential_name} not found in provider {self.entity.identity.name}" + ) + + # check type + credential_schema = credentials_need_to_validate[credential_name] + if not credential_schema.required and credentials[credential_name] is None: + continue + + if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}: + if not isinstance(credentials[credential_name], str): + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + + elif credential_schema.type == ProviderConfig.Type.SELECT: + if not isinstance(credentials[credential_name], str): + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + + options = credential_schema.options + if not isinstance(options, list): + raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list") + + if credentials[credential_name] not in [x.value for x in options]: + raise ToolProviderCredentialValidationError( + f"credential {credential_name} should be one of {options}" + ) + + 
credentials_need_to_validate.pop(credential_name) + + for credential_name in credentials_need_to_validate: + credential_schema = credentials_need_to_validate[credential_name] + if credential_schema.required: + raise ToolProviderCredentialValidationError(f"credential {credential_name} is required") + + # the credential is not set currently, set the default value if needed + if credential_schema.default is not None: + default_value = credential_schema.default + # parse default value into the correct type + if credential_schema.type in { + ProviderConfig.Type.SECRET_INPUT, + ProviderConfig.Type.TEXT_INPUT, + ProviderConfig.Type.SELECT, + }: + default_value = str(default_value) + + credentials[credential_name] = default_value \ No newline at end of file diff --git a/api/core/datasource/datasource_tool/tool.py b/api/core/datasource/datasource_tool/tool.py index 1c8572c2c5..d55c28a9b9 100644 --- a/api/core/datasource/datasource_tool/tool.py +++ b/api/core/datasource/datasource_tool/tool.py @@ -1,14 +1,16 @@ from collections.abc import Generator from typing import Any, Optional +from core.datasource.__base.datasource import Datasource from core.datasource.__base.datasource_runtime import DatasourceRuntime -from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceParameter, DatasourceProviderType +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceInvokeMessage, + DatasourceParameter, + DatasourceProviderType, +) from core.plugin.manager.datasource import PluginDatasourceManager -from core.plugin.manager.tool import PluginToolManager from core.plugin.utils.converter import convert_parameters_to_plugin_format -from core.tools.__base.tool import Tool -from core.tools.__base.tool_runtime import ToolRuntime -from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType class DatasourcePlugin(Datasource): @@ -16,11 +18,14 @@ class DatasourcePlugin(Datasource): icon: str plugin_unique_identifier: str runtime_parameters: Optional[list[DatasourceParameter]] + entity: DatasourceEntity + runtime: DatasourceRuntime def __init__( - self, entity: DatasourceEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str + self, entity: DatasourceEntity, runtime: DatasourceRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str ) -> None: - super().__init__(entity, runtime) + self.entity = entity + self.runtime = runtime self.tenant_id = tenant_id self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier @@ -34,7 +39,7 @@ class DatasourcePlugin(Datasource): user_id: str, datasource_parameters: dict[str, Any], rag_pipeline_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage, None, None]: + ) -> Generator[DatasourceInvokeMessage, None, None]: manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) @@ -54,7 +59,7 @@ class DatasourcePlugin(Datasource): user_id: str, datasource_parameters: dict[str, Any], rag_pipeline_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage, None, None]: + ) -> Generator[DatasourceInvokeMessage, None, None]: manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 39c28c0d7d..de580b270e 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ 
b/api/core/datasource/entities/datasource_entities.py @@ -105,7 +105,7 @@ class ApiProviderAuthType(Enum): raise ValueError(f"invalid mode value {value}") -class ToolInvokeMessage(BaseModel): +class DatasourceInvokeMessage(BaseModel): class TextMessage(BaseModel): text: str @@ -200,7 +200,7 @@ class ToolInvokeMessage(BaseModel): return v -class ToolInvokeMessageBinary(BaseModel): +class DatasourceInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") file_var: Optional[dict[str, Any]] = None From 389f15f8e3076bbb82a2d06239fd8f40e7d2b4a4 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 25 Apr 2025 14:56:22 +0800 Subject: [PATCH 008/155] r2 --- api/contexts/__init__.py | 9 + api/core/datasource/__base/datasource.py | 221 ----- .../tool.py => __base/datasource_plugin.py} | 3 +- .../datasource_provider.py} | 18 +- api/core/datasource/__base/tool_provider.py | 109 --- .../datasource/datasource_file_manager.py | 244 +++++ api/core/datasource/datasource_manager.py | 95 ++ api/core/datasource/entities/constants.py | 2 +- .../entities/datasource_entities.py | 4 +- api/core/datasource/entities/file_entities.py | 1 - api/core/datasource/entities/tool_bundle.py | 29 - api/core/datasource/tool_manager.py | 870 ------------------ api/core/file/datasource_file_parser.py | 15 + 13 files changed, 376 insertions(+), 1244 deletions(-) delete mode 100644 api/core/datasource/__base/datasource.py rename api/core/datasource/{datasource_tool/tool.py => __base/datasource_plugin.py} (97%) rename api/core/datasource/{datasource_tool/provider.py => __base/datasource_provider.py} (92%) delete mode 100644 api/core/datasource/__base/tool_provider.py create mode 100644 api/core/datasource/datasource_file_manager.py create mode 100644 api/core/datasource/datasource_manager.py delete mode 100644 api/core/datasource/entities/file_entities.py delete mode 100644 api/core/datasource/entities/tool_bundle.py delete mode 100644 api/core/datasource/tool_manager.py create mode 100644 api/core/file/datasource_file_parser.py diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 127b8fe76d..7dac252201 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -3,6 +3,7 @@ from threading import Lock from typing import TYPE_CHECKING from contexts.wrapper import RecyclableContextVar +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController if TYPE_CHECKING: from core.model_runtime.entities.model_entities import AIModelEntity @@ -37,3 +38,11 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar( ContextVar("plugin_model_schemas") ) + +datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = RecyclableContextVar( + ContextVar("datasource_plugin_providers") +) + +datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( + ContextVar("datasource_plugin_providers_lock") +) diff --git a/api/core/datasource/__base/datasource.py b/api/core/datasource/__base/datasource.py deleted file mode 100644 index 3a67b56e32..0000000000 --- a/api/core/datasource/__base/datasource.py +++ /dev/null @@ -1,221 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Generator -from copy import deepcopy -from typing import TYPE_CHECKING, Any, Optional - -if TYPE_CHECKING: 
- from models.model import File - -from core.datasource.__base.datasource_runtime import DatasourceRuntime -from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType -from core.tools.entities.tool_entities import ( - ToolInvokeMessage, - ToolParameter, -) - - -class Datasource(ABC): - """ - The base class of a datasource - """ - - entity: DatasourceEntity - runtime: DatasourceRuntime - - def __init__(self, entity: DatasourceEntity, runtime: DatasourceRuntime) -> None: - self.entity = entity - self.runtime = runtime - - def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "Datasource": - """ - fork a new datasource with metadata - :return: the new datasource - """ - return self.__class__( - entity=self.entity.model_copy(), - runtime=runtime, - ) - - @abstractmethod - def datasource_provider_type(self) -> DatasourceProviderType: - """ - get the datasource provider type - - :return: the tool provider type - """ - - def invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage]: - if self.runtime and self.runtime.runtime_parameters: - tool_parameters.update(self.runtime.runtime_parameters) - - # try parse tool parameters into the correct type - tool_parameters = self._transform_tool_parameters_type(tool_parameters) - - result = self._invoke( - user_id=user_id, - tool_parameters=tool_parameters, - conversation_id=conversation_id, - app_id=app_id, - message_id=message_id, - ) - - if isinstance(result, ToolInvokeMessage): - - def single_generator() -> Generator[ToolInvokeMessage, None, None]: - yield result - - return single_generator() - elif isinstance(result, list): - - def generator() -> Generator[ToolInvokeMessage, None, None]: - yield from result - - return generator() - else: - return result - - def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: - """ - Transform tool parameters type - """ - # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials - result = deepcopy(tool_parameters) - for parameter in self.entity.parameters or []: - if parameter.name in tool_parameters: - result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) - - return result - - @abstractmethod - def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]: - pass - - def get_runtime_parameters( - self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> list[ToolParameter]: - """ - get the runtime parameters - - interface for developer to dynamic change the parameters of a tool depends on the variables pool - - :return: the runtime parameters - """ - return self.entity.parameters - - def get_merged_runtime_parameters( - self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> list[ToolParameter]: - """ - get merged runtime parameters - - :return: merged runtime parameters - """ - parameters = self.entity.parameters - parameters = parameters.copy() - user_parameters = self.get_runtime_parameters() or [] - user_parameters = user_parameters.copy() - - # override 
parameters - for parameter in user_parameters: - # check if parameter in tool parameters - for tool_parameter in parameters: - if tool_parameter.name == parameter.name: - # override parameter - tool_parameter.type = parameter.type - tool_parameter.form = parameter.form - tool_parameter.required = parameter.required - tool_parameter.default = parameter.default - tool_parameter.options = parameter.options - tool_parameter.llm_description = parameter.llm_description - break - else: - # add new parameter - parameters.append(parameter) - - return parameters - - def create_image_message( - self, - image: str, - ) -> ToolInvokeMessage: - """ - create an image message - - :param image: the url of the image - :return: the image message - """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image) - ) - - def create_file_message(self, file: "File") -> ToolInvokeMessage: - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.FILE, - message=ToolInvokeMessage.FileMessage(), - meta={"file": file}, - ) - - def create_link_message(self, link: str) -> ToolInvokeMessage: - """ - create a link message - - :param link: the url of the link - :return: the link message - """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link) - ) - - def create_text_message(self, text: str) -> ToolInvokeMessage: - """ - create a text message - - :param text: the text - :return: the text message - """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=ToolInvokeMessage.TextMessage(text=text), - ) - - def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage: - """ - create a blob message - - :param blob: the blob - :param meta: the meta info of blob object - :return: the blob message - """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.BLOB, - message=ToolInvokeMessage.BlobMessage(blob=blob), - meta=meta, - ) - - def create_json_message(self, object: dict) -> ToolInvokeMessage: - """ - create a json message - """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object) - ) diff --git a/api/core/datasource/datasource_tool/tool.py b/api/core/datasource/__base/datasource_plugin.py similarity index 97% rename from api/core/datasource/datasource_tool/tool.py rename to api/core/datasource/__base/datasource_plugin.py index d55c28a9b9..037c0f4630 100644 --- a/api/core/datasource/datasource_tool/tool.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -1,7 +1,6 @@ from collections.abc import Generator from typing import Any, Optional -from core.datasource.__base.datasource import Datasource from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, @@ -13,7 +12,7 @@ from core.plugin.manager.datasource import PluginDatasourceManager from core.plugin.utils.converter import convert_parameters_to_plugin_format -class DatasourcePlugin(Datasource): +class DatasourcePlugin: tenant_id: str icon: str plugin_unique_identifier: str diff --git a/api/core/datasource/datasource_tool/provider.py b/api/core/datasource/__base/datasource_provider.py similarity index 92% rename from api/core/datasource/datasource_tool/provider.py rename to api/core/datasource/__base/datasource_provider.py index 820224eeaa..ba66e2a3c4 100644 --- 
a/api/core/datasource/datasource_tool/provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -1,15 +1,15 @@ from typing import Any -from core.datasource.datasource_tool.tool import DatasourceTool +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.entities.provider_entities import ProviderConfig from core.plugin.manager.tool import PluginToolManager -from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController -from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType from core.tools.errors import ToolProviderCredentialValidationError -class DatasourceToolProviderController(BuiltinToolProviderController): +class DatasourcePluginProviderController(BuiltinToolProviderController): entity: DatasourceProviderEntityWithPlugin tenant_id: str plugin_id: str @@ -45,7 +45,7 @@ class DatasourceToolProviderController(BuiltinToolProviderController): ): raise ToolProviderCredentialValidationError("Invalid credentials") - def get_datasource(self, datasource_name: str) -> DatasourceTool: # type: ignore + def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore """ return datasource with given name """ @@ -56,9 +56,9 @@ class DatasourceToolProviderController(BuiltinToolProviderController): if not datasource_entity: raise ValueError(f"Datasource with name {datasource_name} not found") - return DatasourceTool( + return DatasourcePlugin( entity=datasource_entity, - runtime=ToolRuntime(tenant_id=self.tenant_id), + runtime=DatasourceRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, icon=self.entity.identity.icon, plugin_unique_identifier=self.plugin_unique_identifier, @@ -69,9 +69,9 @@ class DatasourceToolProviderController(BuiltinToolProviderController): get all datasources """ return [ - DatasourceTool( + DatasourcePlugin( entity=datasource_entity, - runtime=ToolRuntime(tenant_id=self.tenant_id), + runtime=DatasourceRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, icon=self.entity.identity.icon, plugin_unique_identifier=self.plugin_unique_identifier, diff --git a/api/core/datasource/__base/tool_provider.py b/api/core/datasource/__base/tool_provider.py deleted file mode 100644 index d096fc7df7..0000000000 --- a/api/core/datasource/__base/tool_provider.py +++ /dev/null @@ -1,109 +0,0 @@ -from abc import ABC, abstractmethod -from copy import deepcopy -from typing import Any - -from core.entities.provider_entities import ProviderConfig -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ( - ToolProviderEntity, - ToolProviderType, -) -from core.tools.errors import ToolProviderCredentialValidationError - - -class ToolProviderController(ABC): - entity: ToolProviderEntity - - def __init__(self, entity: ToolProviderEntity) -> None: - self.entity = entity - - def get_credentials_schema(self) -> list[ProviderConfig]: - """ - returns the credentials schema of the provider - - :return: the credentials schema - """ - return deepcopy(self.entity.credentials_schema) - - @abstractmethod - def get_tool(self, tool_name: str) -> Tool: - """ - returns a tool that the provider can provide - - :return: tool - """ - pass - - @property - def provider_type(self) -> ToolProviderType: - """ - returns the type of the 
provider - - :return: type of the provider - """ - return ToolProviderType.BUILT_IN - - def validate_credentials_format(self, credentials: dict[str, Any]) -> None: - """ - validate the format of the credentials of the provider and set the default value if needed - - :param credentials: the credentials of the tool - """ - credentials_schema = dict[str, ProviderConfig]() - if credentials_schema is None: - return - - for credential in self.entity.credentials_schema: - credentials_schema[credential.name] = credential - - credentials_need_to_validate: dict[str, ProviderConfig] = {} - for credential_name in credentials_schema: - credentials_need_to_validate[credential_name] = credentials_schema[credential_name] - - for credential_name in credentials: - if credential_name not in credentials_need_to_validate: - raise ToolProviderCredentialValidationError( - f"credential {credential_name} not found in provider {self.entity.identity.name}" - ) - - # check type - credential_schema = credentials_need_to_validate[credential_name] - if not credential_schema.required and credentials[credential_name] is None: - continue - - if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}: - if not isinstance(credentials[credential_name], str): - raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") - - elif credential_schema.type == ProviderConfig.Type.SELECT: - if not isinstance(credentials[credential_name], str): - raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") - - options = credential_schema.options - if not isinstance(options, list): - raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list") - - if credentials[credential_name] not in [x.value for x in options]: - raise ToolProviderCredentialValidationError( - f"credential {credential_name} should be one of {options}" - ) - - credentials_need_to_validate.pop(credential_name) - - for credential_name in credentials_need_to_validate: - credential_schema = credentials_need_to_validate[credential_name] - if credential_schema.required: - raise ToolProviderCredentialValidationError(f"credential {credential_name} is required") - - # the credential is not set currently, set the default value if needed - if credential_schema.default is not None: - default_value = credential_schema.default - # parse default value into the correct type - if credential_schema.type in { - ProviderConfig.Type.SECRET_INPUT, - ProviderConfig.Type.TEXT_INPUT, - ProviderConfig.Type.SELECT, - }: - default_value = str(default_value) - - credentials[credential_name] = default_value diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py new file mode 100644 index 0000000000..6704d4e73a --- /dev/null +++ b/api/core/datasource/datasource_file_manager.py @@ -0,0 +1,244 @@ +import base64 +import hashlib +import hmac +import logging +import os +import time +from mimetypes import guess_extension, guess_type +from typing import Optional, Union +from uuid import uuid4 + +import httpx + +from configs import dify_config +from core.helper import ssrf_proxy +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.enums import CreatedByRole +from models.model import MessageFile, UploadFile +from models.tools import ToolFile + +logger = logging.getLogger(__name__) + + +class DatasourceFileManager: + @staticmethod + def sign_file(datasource_file_id: str, 
extension: str) -> str: + """ + sign file to get a temporary url + """ + base_url = dify_config.FILES_URL + file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + @staticmethod + def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + """ + verify signature + """ + data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + @staticmethod + def create_file_by_raw( + *, + user_id: str, + tenant_id: str, + conversation_id: Optional[str], + file_binary: bytes, + mimetype: str, + filename: Optional[str] = None, + ) -> UploadFile: + extension = guess_extension(mimetype) or ".bin" + unique_name = uuid4().hex + unique_filename = f"{unique_name}{extension}" + # default just as before + present_filename = unique_filename + if filename is not None: + has_extension = len(filename.split(".")) > 1 + # Add extension flexibly + present_filename = filename if has_extension else f"{filename}{extension}" + filepath = f"datasources/{tenant_id}/{unique_filename}" + storage.save(filepath, file_binary) + + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=filepath, + name=present_filename, + size=len(file_binary), + extension=extension, + mime_type=mimetype, + created_by_role=CreatedByRole.ACCOUNT, + created_by=user_id, + used=False, + hash=hashlib.sha3_256(file_binary).hexdigest(), + source_url="", + ) + + db.session.add(upload_file) + db.session.commit() + db.session.refresh(upload_file) + + return upload_file + + @staticmethod + def create_file_by_url( + user_id: str, + tenant_id: str, + file_url: str, + conversation_id: Optional[str] = None, + ) -> UploadFile: + # try to download image + try: + response = ssrf_proxy.get(file_url) + response.raise_for_status() + blob = response.content + except httpx.TimeoutException: + raise ValueError(f"timeout when downloading file from {file_url}") + + mimetype = ( + guess_type(file_url)[0] + or response.headers.get("Content-Type", "").split(";")[0].strip() + or "application/octet-stream" + ) + extension = guess_extension(mimetype) or ".bin" + unique_name = uuid4().hex + filename = f"{unique_name}{extension}" + filepath = f"tools/{tenant_id}/{filename}" + storage.save(filepath, blob) + + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=filepath, + name=filename, + size=len(blob), + extension=extension, + mime_type=mimetype, + created_by_role=CreatedByRole.ACCOUNT, + created_by=user_id, + used=False, + hash=hashlib.sha3_256(blob).hexdigest(), + source_url=file_url, + ) + + db.session.add(upload_file) + db.session.commit() + 
+ return upload_file + + @staticmethod + def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: + """ + get file binary + + :param id: the id of the file + + :return: the binary of the file, mime type + """ + upload_file: UploadFile | None = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == id, + ) + .first() + ) + + if not upload_file: + return None + + blob = storage.load_once(upload_file.key) + + return blob, upload_file.mime_type + + @staticmethod + def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]: + """ + get file binary + + :param id: the id of the file + + :return: the binary of the file, mime type + """ + message_file: MessageFile | None = ( + db.session.query(MessageFile) + .filter( + MessageFile.id == id, + ) + .first() + ) + + # Check if message_file is not None + if message_file is not None: + # get tool file id + if message_file.url is not None: + tool_file_id = message_file.url.split("/")[-1] + # trim extension + tool_file_id = tool_file_id.split(".")[0] + else: + tool_file_id = None + else: + tool_file_id = None + + tool_file: ToolFile | None = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() + ) + + if not tool_file: + return None + + blob = storage.load_once(tool_file.file_key) + + return blob, tool_file.mimetype + + @staticmethod + def get_file_generator_by_upload_file_id(upload_file_id: str): + """ + get file binary + + :param tool_file_id: the id of the tool file + + :return: the binary of the file, mime type + """ + upload_file: UploadFile | None = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == upload_file_id, + ) + .first() + ) + + if not upload_file: + return None, None + + stream = storage.load_stream(upload_file.key) + + return stream, upload_file.mime_type + + +# init tool_file_parser +from core.file.datasource_file_parser import datasource_file_manager + +datasource_file_manager["manager"] = DatasourceFileManager diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py new file mode 100644 index 0000000000..9cde52e39f --- /dev/null +++ b/api/core/datasource/datasource_manager.py @@ -0,0 +1,95 @@ + +import logging +from threading import Lock +from typing import Union + +import contexts +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.entities.common_entities import I18nObject +from core.datasource.entities.datasource_entities import DatasourceProviderType +from core.datasource.errors import ToolProviderNotFoundError +from core.plugin.manager.tool import PluginToolManager + +logger = logging.getLogger(__name__) + + +class DatasourceManager: + _builtin_provider_lock = Lock() + _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {} + _builtin_providers_loaded = False + _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} + + @classmethod + def get_datasource_plugin_provider(cls, provider: str, tenant_id: str) -> DatasourcePluginProviderController: + """ + get the datasource plugin provider + """ + # check if context is set + try: + contexts.datasource_plugin_providers.get() + except LookupError: + contexts.datasource_plugin_providers.set({}) + contexts.datasource_plugin_providers_lock.set(Lock()) + + with contexts.datasource_plugin_providers_lock.get(): + datasource_plugin_providers = contexts.datasource_plugin_providers.get() + if provider in 
datasource_plugin_providers: + return datasource_plugin_providers[provider] + + manager = PluginToolManager() + provider_entity = manager.fetch_tool_provider(tenant_id, provider) + if not provider_entity: + raise ToolProviderNotFoundError(f"plugin provider {provider} not found") + + controller = DatasourcePluginProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + + datasource_plugin_providers[provider] = controller + + return controller + + @classmethod + def get_datasource_runtime( + cls, + provider_type: DatasourceProviderType, + provider_id: str, + datasource_name: str, + tenant_id: str, + ) -> DatasourcePlugin: + """ + get the datasource runtime + + :param provider_type: the type of the provider + :param provider_id: the id of the provider + :param datasource_name: the name of the datasource + :param tenant_id: the tenant id + + :return: the datasource plugin + """ + if provider_type == DatasourceProviderType.RAG_PIPELINE: + return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name) + else: + raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") + + + @classmethod + def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]: + """ + list all the datasource providers + """ + manager = PluginToolManager() + provider_entities = manager.fetch_tool_providers(tenant_id) + return [ + DatasourcePluginProviderController( + entity=provider.declaration, + plugin_id=provider.plugin_id, + plugin_unique_identifier=provider.plugin_unique_identifier, + tenant_id=tenant_id, + ) + for provider in provider_entities + ] diff --git a/api/core/datasource/entities/constants.py b/api/core/datasource/entities/constants.py index 199c9f0d53..a4dbf6f11f 100644 --- a/api/core/datasource/entities/constants.py +++ b/api/core/datasource/entities/constants.py @@ -1 +1 @@ -TOOL_SELECTOR_MODEL_IDENTITY = "__dify__tool_selector__" +DATASOURCE_SELECTOR_MODEL_IDENTITY = "__dify__datasource_selector__" diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index de580b270e..80e89ef1a9 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -6,6 +6,7 @@ from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator +from core.datasource.entities.constants import DATASOURCE_SELECTOR_MODEL_IDENTITY from core.entities.provider_entities import ProviderConfig from core.plugin.entities.parameters import ( PluginParameter, @@ -16,7 +17,6 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject -from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY class ToolLabelEnum(Enum): @@ -400,7 +400,7 @@ class DatasourceInvokeFrom(Enum): class DatasourceSelector(BaseModel): - dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY + dify_model_identity: str = DATASOURCE_SELECTOR_MODEL_IDENTITY class Parameter(BaseModel): name: str = Field(..., description="The name of the parameter") diff --git a/api/core/datasource/entities/file_entities.py b/api/core/datasource/entities/file_entities.py deleted file mode 100644 index 8b13789179..0000000000 --- 
a/api/core/datasource/entities/file_entities.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/api/core/datasource/entities/tool_bundle.py b/api/core/datasource/entities/tool_bundle.py deleted file mode 100644 index ffeeabbc1c..0000000000 --- a/api/core/datasource/entities/tool_bundle.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -from core.tools.entities.tool_entities import ToolParameter - - -class ApiToolBundle(BaseModel): - """ - This class is used to store the schema information of an api based tool. - such as the url, the method, the parameters, etc. - """ - - # server_url - server_url: str - # method - method: str - # summary - summary: Optional[str] = None - # operation_id - operation_id: Optional[str] = None - # parameters - parameters: Optional[list[ToolParameter]] = None - # author - author: str - # icon - icon: Optional[str] = None - # openapi operation - openapi: dict diff --git a/api/core/datasource/tool_manager.py b/api/core/datasource/tool_manager.py deleted file mode 100644 index f2d0b74f7c..0000000000 --- a/api/core/datasource/tool_manager.py +++ /dev/null @@ -1,870 +0,0 @@ -import json -import logging -import mimetypes -from collections.abc import Generator -from os import listdir, path -from threading import Lock -from typing import TYPE_CHECKING, Any, Union, cast - -from yarl import URL - -import contexts -from core.plugin.entities.plugin import ToolProviderID -from core.plugin.manager.tool import PluginToolManager -from core.tools.__base.tool_provider import ToolProviderController -from core.tools.__base.tool_runtime import ToolRuntime -from core.tools.plugin_tool.provider import PluginToolProviderController -from core.tools.plugin_tool.tool import PluginTool -from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - -if TYPE_CHECKING: - from core.workflow.nodes.tool.entities import ToolEntity - - -from configs import dify_config -from core.agent.entities import AgentToolEntity -from core.app.entities.app_invoke_entities import InvokeFrom -from core.helper.module_import_helper import load_single_subclass_from_source -from core.helper.position_helper import is_filtered -from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.__base.tool import Tool -from core.tools.builtin_tool.provider import BuiltinToolProviderController -from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort -from core.tools.builtin_tool.tool import BuiltinTool -from core.tools.custom_tool.provider import ApiToolProviderController -from core.tools.custom_tool.tool import ApiTool -from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ApiProviderAuthType, - ToolInvokeFrom, - ToolParameter, - ToolProviderType, -) -from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError -from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ( - ProviderConfigEncrypter, - ToolParameterConfigurationManager, -) -from core.tools.workflow_as_tool.tool import WorkflowTool -from extensions.ext_database import db -from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider -from services.tools.tools_transform_service import ToolTransformService - -logger = logging.getLogger(__name__) - - -class ToolManager: - _builtin_provider_lock = Lock() - _hardcoded_providers: 
dict[str, BuiltinToolProviderController] = {} - _builtin_providers_loaded = False - _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} - - @classmethod - def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: - """ - get the hardcoded provider - """ - if len(cls._hardcoded_providers) == 0: - # init the builtin providers - cls.load_hardcoded_providers_cache() - - return cls._hardcoded_providers[provider] - - @classmethod - def get_builtin_provider( - cls, provider: str, tenant_id: str - ) -> BuiltinToolProviderController | PluginToolProviderController: - """ - get the builtin provider - - :param provider: the name of the provider - :param tenant_id: the id of the tenant - :return: the provider - """ - # split provider to - - if len(cls._hardcoded_providers) == 0: - # init the builtin providers - cls.load_hardcoded_providers_cache() - - if provider not in cls._hardcoded_providers: - # get plugin provider - plugin_provider = cls.get_plugin_provider(provider, tenant_id) - if plugin_provider: - return plugin_provider - - return cls._hardcoded_providers[provider] - - @classmethod - def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController: - """ - get the plugin provider - """ - # check if context is set - try: - contexts.plugin_tool_providers.get() - except LookupError: - contexts.plugin_tool_providers.set({}) - contexts.plugin_tool_providers_lock.set(Lock()) - - with contexts.plugin_tool_providers_lock.get(): - plugin_tool_providers = contexts.plugin_tool_providers.get() - if provider in plugin_tool_providers: - return plugin_tool_providers[provider] - - manager = PluginToolManager() - provider_entity = manager.fetch_tool_provider(tenant_id, provider) - if not provider_entity: - raise ToolProviderNotFoundError(f"plugin provider {provider} not found") - - controller = PluginToolProviderController( - entity=provider_entity.declaration, - plugin_id=provider_entity.plugin_id, - plugin_unique_identifier=provider_entity.plugin_unique_identifier, - tenant_id=tenant_id, - ) - - plugin_tool_providers[provider] = controller - - return controller - - @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: - """ - get the builtin tool - - :param provider: the name of the provider - :param tool_name: the name of the tool - :param tenant_id: the id of the tenant - :return: the provider, the tool - """ - provider_controller = cls.get_builtin_provider(provider, tenant_id) - tool = provider_controller.get_tool(tool_name) - if tool is None: - raise ToolNotFoundError(f"tool {tool_name} not found") - - return tool - - @classmethod - def get_tool_runtime( - cls, - provider_type: ToolProviderType, - provider_id: str, - tool_name: str, - tenant_id: str, - invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, - ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]: - """ - get the tool runtime - - :param provider_type: the type of the provider - :param provider_id: the id of the provider - :param tool_name: the name of the tool - :param tenant_id: the tenant id - :param invoke_from: invoke from - :param tool_invoke_from: the tool invoke from - - :return: the tool - """ - if provider_type == ToolProviderType.BUILT_IN: - # check if the builtin tool need credentials - provider_controller = cls.get_builtin_provider(provider_id, tenant_id) - - builtin_tool = provider_controller.get_tool(tool_name) - if not builtin_tool: - raise 
ToolProviderNotFoundError(f"builtin tool {tool_name} not found") - - if not provider_controller.need_credentials: - return cast( - BuiltinTool, - builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), - ) - - if isinstance(provider_controller, PluginToolProviderController): - provider_id_entity = ToolProviderID(provider_id) - # get credentials - builtin_provider: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == str(provider_id_entity)) - | (BuiltinToolProvider.provider == provider_id_entity.provider_name), - ) - .first() - ) - - if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - else: - builtin_provider = ( - db.session.query(BuiltinToolProvider) - .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) - .first() - ) - - if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - - # decrypt the credentials - credentials = builtin_provider.credentials - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - - decrypted_credentials = tool_configuration.decrypt(credentials) - - return cast( - BuiltinTool, - builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials=decrypted_credentials, - runtime_parameters={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), - ) - - elif provider_type == ToolProviderType.API: - api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - - # decrypt the credentials - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], - provider_type=api_provider.provider_type.value, - provider_identity=api_provider.entity.identity.name, - ) - decrypted_credentials = tool_configuration.decrypt(credentials) - - return cast( - ApiTool, - api_provider.get_tool(tool_name).fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials=decrypted_credentials, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), - ) - elif provider_type == ToolProviderType.WORKFLOW: - workflow_provider = ( - db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) - .first() - ) - - if workflow_provider is None: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) - controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id) - if controller_tools is None or len(controller_tools) == 0: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - return cast( - WorkflowTool, - controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), - ) - elif provider_type 
== ToolProviderType.APP: - raise NotImplementedError("app provider not implemented") - elif provider_type == ToolProviderType.PLUGIN: - return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) - else: - raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") - - @classmethod - def get_agent_tool_runtime( - cls, - tenant_id: str, - app_id: str, - agent_tool: AgentToolEntity, - invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - ) -> Tool: - """ - get the agent tool runtime - """ - tool_entity = cls.get_tool_runtime( - provider_type=agent_tool.provider_type, - provider_id=agent_tool.provider_id, - tool_name=agent_tool.tool_name, - tenant_id=tenant_id, - invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.AGENT, - ) - runtime_parameters = {} - parameters = tool_entity.get_merged_runtime_parameters() - for parameter in parameters: - # check file types - if ( - parameter.type - in { - ToolParameter.ToolParameterType.SYSTEM_FILES, - ToolParameter.ToolParameterType.FILE, - ToolParameter.ToolParameterType.FILES, - } - and parameter.required - ): - raise ValueError(f"file type parameter {parameter.name} not supported in agent") - - if parameter.form == ToolParameter.ToolParameterForm.FORM: - # save tool parameter to tool entity memory - value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name)) - runtime_parameters[parameter.name] = value - - # decrypt runtime parameters - encryption_manager = ToolParameterConfigurationManager( - tenant_id=tenant_id, - tool_runtime=tool_entity, - provider_name=agent_tool.provider_id, - provider_type=agent_tool.provider_type, - identity_id=f"AGENT.{app_id}", - ) - runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) - if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None: - raise ValueError("runtime not found or runtime parameters not found") - - tool_entity.runtime.runtime_parameters.update(runtime_parameters) - return tool_entity - - @classmethod - def get_workflow_tool_runtime( - cls, - tenant_id: str, - app_id: str, - node_id: str, - workflow_tool: "ToolEntity", - invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - ) -> Tool: - """ - get the workflow tool runtime - """ - tool_runtime = cls.get_tool_runtime( - provider_type=workflow_tool.provider_type, - provider_id=workflow_tool.provider_id, - tool_name=workflow_tool.tool_name, - tenant_id=tenant_id, - invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.WORKFLOW, - ) - runtime_parameters = {} - parameters = tool_runtime.get_merged_runtime_parameters() - - for parameter in parameters: - # save tool parameter to tool entity memory - if parameter.form == ToolParameter.ToolParameterForm.FORM: - value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name)) - runtime_parameters[parameter.name] = value - - # decrypt runtime parameters - encryption_manager = ToolParameterConfigurationManager( - tenant_id=tenant_id, - tool_runtime=tool_runtime, - provider_name=workflow_tool.provider_id, - provider_type=workflow_tool.provider_type, - identity_id=f"WORKFLOW.{app_id}.{node_id}", - ) - - if runtime_parameters: - runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) - - tool_runtime.runtime.runtime_parameters.update(runtime_parameters) - return tool_runtime - - @classmethod - def get_tool_runtime_from_plugin( - cls, - tool_type: ToolProviderType, - tenant_id: str, - provider: str, - tool_name: str, - tool_parameters: 
dict[str, Any], - ) -> Tool: - """ - get tool runtime from plugin - """ - tool_entity = cls.get_tool_runtime( - provider_type=tool_type, - provider_id=provider, - tool_name=tool_name, - tenant_id=tenant_id, - invoke_from=InvokeFrom.SERVICE_API, - tool_invoke_from=ToolInvokeFrom.PLUGIN, - ) - runtime_parameters = {} - parameters = tool_entity.get_merged_runtime_parameters() - for parameter in parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM: - # save tool parameter to tool entity memory - value = parameter.init_frontend_parameter(tool_parameters.get(parameter.name)) - runtime_parameters[parameter.name] = value - - tool_entity.runtime.runtime_parameters.update(runtime_parameters) - return tool_entity - - @classmethod - def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]: - """ - get the absolute path of the icon of the hardcoded provider - - :param provider: the name of the provider - :return: the absolute path of the icon, the mime type of the icon - """ - # get provider - provider_controller = cls.get_hardcoded_provider(provider) - - absolute_path = path.join( - path.dirname(path.realpath(__file__)), - "builtin_tool", - "providers", - provider, - "_assets", - provider_controller.entity.identity.icon, - ) - # check if the icon exists - if not path.exists(absolute_path): - raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found") - - # get the mime type - mime_type, _ = mimetypes.guess_type(absolute_path) - mime_type = mime_type or "application/octet-stream" - - return absolute_path, mime_type - - @classmethod - def list_hardcoded_providers(cls): - # use cache first - if cls._builtin_providers_loaded: - yield from list(cls._hardcoded_providers.values()) - return - - with cls._builtin_provider_lock: - if cls._builtin_providers_loaded: - yield from list(cls._hardcoded_providers.values()) - return - - yield from cls._list_hardcoded_providers() - - @classmethod - def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]: - """ - list all the plugin providers - """ - manager = PluginToolManager() - provider_entities = manager.fetch_tool_providers(tenant_id) - return [ - PluginToolProviderController( - entity=provider.declaration, - plugin_id=provider.plugin_id, - plugin_unique_identifier=provider.plugin_unique_identifier, - tenant_id=tenant_id, - ) - for provider in provider_entities - ] - - @classmethod - def list_builtin_providers( - cls, tenant_id: str - ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]: - """ - list all the builtin providers - """ - yield from cls.list_hardcoded_providers() - # get plugin providers - yield from cls.list_plugin_providers(tenant_id) - - @classmethod - def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: - """ - list all the builtin providers - """ - for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")): - if provider_path.startswith("__"): - continue - - if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)): - if provider_path.startswith("__"): - continue - - # init provider - try: - provider_class = load_single_subclass_from_source( - module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}", - script_path=path.join( - path.dirname(path.realpath(__file__)), - "builtin_tool", - "providers", - provider_path, - f"{provider_path}.py", - ), - 
parent_type=BuiltinToolProviderController, - ) - provider: BuiltinToolProviderController = provider_class() - cls._hardcoded_providers[provider.entity.identity.name] = provider - for tool in provider.get_tools(): - cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label - yield provider - - except Exception: - logger.exception(f"load builtin provider {provider}") - continue - # set builtin providers loaded - cls._builtin_providers_loaded = True - - @classmethod - def load_hardcoded_providers_cache(cls): - for _ in cls.list_hardcoded_providers(): - pass - - @classmethod - def clear_hardcoded_providers_cache(cls): - cls._hardcoded_providers = {} - cls._builtin_providers_loaded = False - - @classmethod - def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: - """ - get the tool label - - :param tool_name: the name of the tool - - :return: the label of the tool - """ - if len(cls._builtin_tools_labels) == 0: - # init the builtin providers - cls.load_hardcoded_providers_cache() - - if tool_name not in cls._builtin_tools_labels: - return None - - return cls._builtin_tools_labels[tool_name] - - @classmethod - def list_providers_from_api( - cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral - ) -> list[ToolProviderApiEntity]: - result_providers: dict[str, ToolProviderApiEntity] = {} - - filters = [] - if not typ: - filters.extend(["builtin", "api", "workflow"]) - else: - filters.append(typ) - - with db.session.no_autoflush: - if "builtin" in filters: - # get builtin providers - builtin_providers = cls.list_builtin_providers(tenant_id) - - # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() - ) - - # rewrite db_builtin_providers - for db_provider in db_builtin_providers: - tool_provider_id = str(ToolProviderID(db_provider.provider)) - db_provider.provider = tool_provider_id - - def find_db_builtin_provider(provider): - return next((x for x in db_builtin_providers if x.provider == provider), None) - - # append builtin providers - for provider in builtin_providers: - # handle include, exclude - if is_filtered( - include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET), - exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET), - data=provider, - name_func=lambda x: x.identity.name, - ): - continue - - user_provider = ToolTransformService.builtin_provider_to_user_provider( - provider_controller=provider, - db_provider=find_db_builtin_provider(provider.entity.identity.name), - decrypt_credentials=False, - ) - - if isinstance(provider, PluginToolProviderController): - result_providers[f"plugin_provider.{user_provider.name}"] = user_provider - else: - result_providers[f"builtin_provider.{user_provider.name}"] = user_provider - - # get db api providers - - if "api" in filters: - db_api_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() - ) - - api_provider_controllers: list[dict[str, Any]] = [ - {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} - for provider in db_api_providers - ] - - # get labels - labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) - - for api_provider_controller in api_provider_controllers: - user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller=api_provider_controller["controller"], - 
db_provider=api_provider_controller["provider"], - decrypt_credentials=False, - labels=labels.get(api_provider_controller["controller"].provider_id, []), - ) - result_providers[f"api_provider.{user_provider.name}"] = user_provider - - if "workflow" in filters: - # get workflow providers - workflow_providers: list[WorkflowToolProvider] = ( - db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() - ) - - workflow_provider_controllers: list[WorkflowToolProviderController] = [] - for provider in workflow_providers: - try: - workflow_provider_controllers.append( - ToolTransformService.workflow_provider_to_controller(db_provider=provider) - ) - except Exception: - # app has been deleted - pass - - labels = ToolLabelManager.get_tools_labels( - [cast(ToolProviderController, controller) for controller in workflow_provider_controllers] - ) - - for provider_controller in workflow_provider_controllers: - user_provider = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=provider_controller, - labels=labels.get(provider_controller.provider_id, []), - ) - result_providers[f"workflow_provider.{user_provider.name}"] = user_provider - - return BuiltinToolProviderSort.sort(list(result_providers.values())) - - @classmethod - def get_api_provider_controller( - cls, tenant_id: str, provider_id: str - ) -> tuple[ApiToolProviderController, dict[str, Any]]: - """ - get the api provider - - :param tenant_id: the id of the tenant - :param provider_id: the id of the provider - - :return: the provider controller, the credentials - """ - provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) - .filter( - ApiToolProvider.id == provider_id, - ApiToolProvider.tenant_id == tenant_id, - ) - .first() - ) - - if provider is None: - raise ToolProviderNotFoundError(f"api provider {provider_id} not found") - - controller = ApiToolProviderController.from_db( - provider, - ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, - ) - controller.load_bundled_tools(provider.tools) - - return controller, provider.credentials - - @classmethod - def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: - """ - get api provider - """ - """ - get tool provider - """ - provider_name = provider - provider_obj: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) - .filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, - ) - .first() - ) - - if provider_obj is None: - raise ValueError(f"you have not added provider {provider_name}") - - try: - credentials = json.loads(provider_obj.credentials_str) or {} - except Exception: - credentials = {} - - # package tool provider controller - controller = ApiToolProviderController.from_db( - provider_obj, - ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, - ) - # init tool configuration - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], - provider_type=controller.provider_type.value, - provider_identity=controller.entity.identity.name, - ) - - decrypted_credentials = tool_configuration.decrypt(credentials) - masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) - - try: - icon = json.loads(provider_obj.icon) - except Exception: - icon = {"background": "#252525", "content": "\ud83d\ude01"} - - # add tool labels - labels = 
ToolLabelManager.get_tool_labels(controller) - - return cast( - dict, - jsonable_encoder( - { - "schema_type": provider_obj.schema_type, - "schema": provider_obj.schema, - "tools": provider_obj.tools, - "icon": icon, - "description": provider_obj.description, - "credentials": masked_credentials, - "privacy_policy": provider_obj.privacy_policy, - "custom_disclaimer": provider_obj.custom_disclaimer, - "labels": labels, - } - ), - ) - - @classmethod - def generate_builtin_tool_icon_url(cls, provider_id: str) -> str: - return str( - URL(dify_config.CONSOLE_API_URL or "/") - / "console" - / "api" - / "workspaces" - / "current" - / "tool-provider" - / "builtin" - / provider_id - / "icon" - ) - - @classmethod - def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str: - return str( - URL(dify_config.CONSOLE_API_URL or "/") - / "console" - / "api" - / "workspaces" - / "current" - / "plugin" - / "icon" - % {"tenant_id": tenant_id, "filename": filename} - ) - - @classmethod - def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict: - try: - workflow_provider: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) - .first() - ) - - if workflow_provider is None: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - icon: dict = json.loads(workflow_provider.icon) - return icon - except Exception: - return {"background": "#252525", "content": "\ud83d\ude01"} - - @classmethod - def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict: - try: - api_provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) - .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) - .first() - ) - - if api_provider is None: - raise ToolProviderNotFoundError(f"api provider {provider_id} not found") - - icon: dict = json.loads(api_provider.icon) - return icon - except Exception: - return {"background": "#252525", "content": "\ud83d\ude01"} - - @classmethod - def get_tool_icon( - cls, - tenant_id: str, - provider_type: ToolProviderType, - provider_id: str, - ) -> Union[str, dict]: - """ - get the tool icon - - :param tenant_id: the id of the tenant - :param provider_type: the type of the provider - :param provider_id: the id of the provider - :return: - """ - provider_type = provider_type - provider_id = provider_id - if provider_type == ToolProviderType.BUILT_IN: - provider = ToolManager.get_builtin_provider(provider_id, tenant_id) - if isinstance(provider, PluginToolProviderController): - try: - return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) - except Exception: - return {"background": "#252525", "content": "\ud83d\ude01"} - return cls.generate_builtin_tool_icon_url(provider_id) - elif provider_type == ToolProviderType.API: - return cls.generate_api_tool_icon_url(tenant_id, provider_id) - elif provider_type == ToolProviderType.WORKFLOW: - return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) - elif provider_type == ToolProviderType.PLUGIN: - provider = ToolManager.get_builtin_provider(provider_id, tenant_id) - if isinstance(provider, PluginToolProviderController): - try: - return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) - except Exception: - return {"background": "#252525", "content": "\ud83d\ude01"} - raise ValueError(f"plugin provider {provider_id} not found") - else: - raise 
ValueError(f"provider type {provider_type} not found") - - -ToolManager.load_hardcoded_providers_cache() diff --git a/api/core/file/datasource_file_parser.py b/api/core/file/datasource_file_parser.py new file mode 100644 index 0000000000..52687951ac --- /dev/null +++ b/api/core/file/datasource_file_parser.py @@ -0,0 +1,15 @@ +from typing import TYPE_CHECKING, Any, cast + +from core.datasource import datasource_file_manager +from core.datasource.datasource_file_manager import DatasourceFileManager + +if TYPE_CHECKING: + from core.datasource.datasource_file_manager import DatasourceFileManager + +tool_file_manager: dict[str, Any] = {"manager": None} + + +class DatasourceFileParser: + @staticmethod + def get_datasource_file_manager() -> "DatasourceFileManager": + return cast("DatasourceFileManager", datasource_file_manager["manager"]) From d4007ae0736a1e67d3772c5eb456bd4548e0485d Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 25 Apr 2025 15:49:36 +0800 Subject: [PATCH 009/155] r2 --- .../agent_tool_callback_handler.py | 31 ++ api/core/datasource/datasource_engine.py | 265 +++++------------- api/core/datasource/errors.py | 20 +- api/core/datasource/tool_file_manager.py | 234 ---------------- api/core/datasource/tool_label_manager.py | 101 ------- api/core/ops/entities/trace_entity.py | 1 + 6 files changed, 105 insertions(+), 547 deletions(-) delete mode 100644 api/core/datasource/tool_file_manager.py delete mode 100644 api/core/datasource/tool_label_manager.py diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 65d899a002..38f5e51b63 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -4,6 +4,7 @@ from typing import Any, Optional, TextIO, Union from pydantic import BaseModel from configs import dify_config +from core.datasource.entities.datasource_entities import DatasourceInvokeMessage from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.entities.tool_entities import ToolInvokeMessage @@ -105,6 +106,36 @@ class DifyAgentCallbackHandler(BaseModel): self.current_loop += 1 + def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None: + """Run on datasource start.""" + if dify_config.DEBUG: + print_text("\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + + str(datasource_inputs) + "\n", color=self.color) + + def on_datasource_end(self, datasource_name: str, datasource_inputs: Mapping[str, Any], datasource_outputs: + Iterable[DatasourceInvokeMessage] | str, message_id: Optional[str] = None, + timer: Optional[Any] = None, + trace_manager: Optional[TraceQueueManager] = None) -> None: + """Run on datasource end.""" + if dify_config.DEBUG: + print_text("\n[on_datasource_end]\n", color=self.color) + print_text("Datasource: " + datasource_name + "\n", color=self.color) + print_text("Inputs: " + str(datasource_inputs) + "\n", color=self.color) + print_text("Outputs: " + str(datasource_outputs)[:1000] + "\n", color=self.color) + print_text("\n") + + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.DATASOURCE_TRACE, + message_id=message_id, + datasource_name=datasource_name, + datasource_inputs=datasource_inputs, + datasource_outputs=datasource_outputs, + timer=timer, + ) + ) + @property def ignore_agent(self) -> bool: """Whether to ignore agent 
callbacks.""" diff --git a/api/core/datasource/datasource_engine.py b/api/core/datasource/datasource_engine.py index 423f78a787..86a3b9d0a0 100644 --- a/api/core/datasource/datasource_engine.py +++ b/api/core/datasource/datasource_engine.py @@ -1,36 +1,19 @@ import json from collections.abc import Generator, Iterable -from copy import deepcopy -from datetime import UTC, datetime from mimetypes import guess_type -from typing import Any, Optional, Union, cast +from typing import Any, Optional, cast from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom -from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.entities.datasource_entities import ( + DatasourceInvokeMessage, + DatasourceInvokeMessageBinary, +) from core.file import FileType from core.file.models import FileTransferMethod -from core.ops.ops_trace_manager import TraceQueueManager -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ( - ToolInvokeMessage, - ToolInvokeMessageBinary, - ToolInvokeMeta, - ToolParameter, -) -from core.tools.errors import ( - ToolEngineInvokeError, - ToolInvokeError, - ToolNotFoundError, - ToolNotSupportedError, - ToolParameterValidationError, - ToolProviderCredentialValidationError, - ToolProviderNotFoundError, -) -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.enums import CreatedByRole from models.model import Message, MessageFile @@ -42,149 +25,39 @@ class DatasourceEngine: """ @staticmethod - def agent_invoke( - tool: Tool, - tool_parameters: Union[str, dict], - user_id: str, - tenant_id: str, - message: Message, - invoke_from: InvokeFrom, - agent_tool_callback: DifyAgentCallbackHandler, - trace_manager: Optional[TraceQueueManager] = None, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> tuple[str, list[str], ToolInvokeMeta]: - """ - Agent invokes the tool with the given arguments. 
- """ - # check if arguments is a string - if isinstance(tool_parameters, str): - # check if this tool has only one parameter - parameters = [ - parameter - for parameter in tool.get_runtime_parameters() - if parameter.form == ToolParameter.ToolParameterForm.LLM - ] - if parameters and len(parameters) == 1: - tool_parameters = {parameters[0].name: tool_parameters} - else: - try: - tool_parameters = json.loads(tool_parameters) - except Exception: - pass - if not isinstance(tool_parameters, dict): - raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") - - try: - # hit the callback handler - agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) - - messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id) - invocation_meta_dict: dict[str, ToolInvokeMeta] = {} - - def message_callback( - invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None] - ): - for message in messages: - if isinstance(message, ToolInvokeMeta): - invocation_meta_dict["meta"] = message - else: - yield message - - messages = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=message_callback(invocation_meta_dict, messages), - user_id=user_id, - tenant_id=tenant_id, - conversation_id=message.conversation_id, - ) - - message_list = list(messages) - - # extract binary data from tool invoke message - binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list) - # create message file - message_files = ToolEngine._create_message_files( - tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id - ) - - plain_text = ToolEngine._convert_tool_response_to_str(message_list) - - meta = invocation_meta_dict["meta"] - - # hit the callback handler - agent_tool_callback.on_tool_end( - tool_name=tool.entity.identity.name, - tool_inputs=tool_parameters, - tool_outputs=plain_text, - message_id=message.id, - trace_manager=trace_manager, - ) - - # transform tool invoke message to get LLM friendly message - return plain_text, message_files, meta - except ToolProviderCredentialValidationError as e: - error_response = "Please check your tool provider credentials" - agent_tool_callback.on_tool_error(e) - except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e: - error_response = f"there is not a tool named {tool.entity.identity.name}" - agent_tool_callback.on_tool_error(e) - except ToolParameterValidationError as e: - error_response = f"tool parameters validation error: {e}, please check your tool parameters" - agent_tool_callback.on_tool_error(e) - except ToolInvokeError as e: - error_response = f"tool invoke error: {e}" - agent_tool_callback.on_tool_error(e) - except ToolEngineInvokeError as e: - meta = e.meta - error_response = f"tool invoke error: {meta.error}" - agent_tool_callback.on_tool_error(e) - return error_response, [], meta - except Exception as e: - error_response = f"unknown error: {e}" - agent_tool_callback.on_tool_error(e) - - return error_response, [], ToolInvokeMeta.error_instance(error_response) - - @staticmethod - def x( - tool: Tool, - tool_parameters: dict[str, Any], + def invoke_first_step( + datasource: DatasourcePlugin, + datasource_parameters: dict[str, Any], user_id: str, workflow_tool_callback: DifyWorkflowCallbackHandler, - workflow_call_depth: int, - thread_pool_id: Optional[str] = None, conversation_id: Optional[str] = None, app_id: Optional[str] = None, 
message_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage, None, None]: + ) -> Generator[DatasourceInvokeMessage, None, None]: """ - Workflow invokes the tool with the given arguments. + Workflow invokes the datasource with the given arguments. """ try: # hit the callback handler - workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) + workflow_tool_callback.on_datasource_start(datasource_name=datasource.entity.identity.name, + datasource_inputs=datasource_parameters) - if isinstance(tool, WorkflowTool): - tool.workflow_call_depth = workflow_call_depth + 1 - tool.thread_pool_id = thread_pool_id + if datasource.runtime and datasource.runtime.runtime_parameters: + datasource_parameters = {**datasource.runtime.runtime_parameters, **datasource_parameters} - if tool.runtime and tool.runtime.runtime_parameters: - tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} - - response = tool.invoke( + response = datasource._invoke_first_step( user_id=user_id, - tool_parameters=tool_parameters, + datasource_parameters=datasource_parameters, conversation_id=conversation_id, app_id=app_id, message_id=message_id, ) # hit the callback handler - response = workflow_tool_callback.on_tool_execution( - tool_name=tool.entity.identity.name, - tool_inputs=tool_parameters, - tool_outputs=response, + response = workflow_tool_callback.on_datasource_end( + datasource_name=datasource.entity.identity.name, + datasource_inputs=datasource_parameters, + datasource_outputs=response, ) return response @@ -193,61 +66,49 @@ class DatasourceEngine: raise e @staticmethod - def _invoke( - tool: Tool, - tool_parameters: dict, + def invoke_second_step( + datasource: DatasourcePlugin, + datasource_parameters: dict[str, Any], user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]: + workflow_tool_callback: DifyWorkflowCallbackHandler, + ) -> Generator[DatasourceInvokeMessage, None, None]: """ - Invoke the tool with the given arguments. + Workflow invokes the datasource with the given arguments. 
""" - started_at = datetime.now(UTC) - meta = ToolInvokeMeta( - time_cost=0.0, - error=None, - tool_config={ - "tool_name": tool.entity.identity.name, - "tool_provider": tool.entity.identity.provider, - "tool_provider_type": tool.tool_provider_type().value, - "tool_parameters": deepcopy(tool.runtime.runtime_parameters), - "tool_icon": tool.entity.identity.icon, - }, - ) try: - yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id) + response = datasource._invoke_second_step( + user_id=user_id, + datasource_parameters=datasource_parameters, + ) + + return response except Exception as e: - meta.error = str(e) - raise ToolEngineInvokeError(meta) - finally: - ended_at = datetime.now(UTC) - meta.time_cost = (ended_at - started_at).total_seconds() - yield meta + workflow_tool_callback.on_tool_error(e) + raise e + @staticmethod - def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: + def _convert_datasource_response_to_str(datasource_response: list[DatasourceInvokeMessage]) -> str: """ - Handle tool response + Handle datasource response """ result = "" - for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.TEXT: - result += cast(ToolInvokeMessage.TextMessage, response.message).text - elif response.type == ToolInvokeMessage.MessageType.LINK: + for response in datasource_response: + if response.type == DatasourceInvokeMessage.MessageType.TEXT: + result += cast(DatasourceInvokeMessage.TextMessage, response.message).text + elif response.type == DatasourceInvokeMessage.MessageType.LINK: result += ( - f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}." + f"result link: {cast(DatasourceInvokeMessage.TextMessage, response.message).text}." + " please tell user to check it." ) - elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: + elif response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}: result += ( "image has been created and sent to user already, " + "you do not need to create it, just tell the user to check it now." 
) - elif response.type == ToolInvokeMessage.MessageType.JSON: + elif response.type == DatasourceInvokeMessage.MessageType.JSON: result = json.dumps( - cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False + cast(DatasourceInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False ) else: result += str(response.message) @@ -255,14 +116,14 @@ class DatasourceEngine: return result @staticmethod - def _extract_tool_response_binary_and_text( - tool_response: list[ToolInvokeMessage], - ) -> Generator[ToolInvokeMessageBinary, None, None]: + def _extract_datasource_response_binary_and_text( + datasource_response: list[DatasourceInvokeMessage], + ) -> Generator[DatasourceInvokeMessageBinary, None, None]: """ - Extract tool response binary + Extract datasource response binary """ - for response in tool_response: - if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: + for response in datasource_response: + if response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}: mimetype = None if not response.meta: raise ValueError("missing meta data") @@ -270,7 +131,7 @@ class DatasourceEngine: mimetype = response.meta.get("mime_type") else: try: - url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text) + url = URL(cast(DatasourceInvokeMessage.TextMessage, response.message).text) extension = url.suffix guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: @@ -281,31 +142,31 @@ class DatasourceEngine: if not mimetype: mimetype = "image/jpeg" - yield ToolInvokeMessageBinary( + yield DatasourceInvokeMessageBinary( mimetype=response.meta.get("mime_type", "image/jpeg"), - url=cast(ToolInvokeMessage.TextMessage, response.message).text, + url=cast(DatasourceInvokeMessage.TextMessage, response.message).text, ) - elif response.type == ToolInvokeMessage.MessageType.BLOB: + elif response.type == DatasourceInvokeMessage.MessageType.BLOB: if not response.meta: raise ValueError("missing meta data") - yield ToolInvokeMessageBinary( + yield DatasourceInvokeMessageBinary( mimetype=response.meta.get("mime_type", "application/octet-stream"), - url=cast(ToolInvokeMessage.TextMessage, response.message).text, + url=cast(DatasourceInvokeMessage.TextMessage, response.message).text, ) - elif response.type == ToolInvokeMessage.MessageType.LINK: + elif response.type == DatasourceInvokeMessage.MessageType.LINK: # check if there is a mime type in meta if response.meta and "mime_type" in response.meta: - yield ToolInvokeMessageBinary( + yield DatasourceInvokeMessageBinary( mimetype=response.meta.get("mime_type", "application/octet-stream") if response.meta else "application/octet-stream", - url=cast(ToolInvokeMessage.TextMessage, response.message).text, + url=cast(DatasourceInvokeMessage.TextMessage, response.message).text, ) @staticmethod def _create_message_files( - tool_messages: Iterable[ToolInvokeMessageBinary], + datasource_messages: Iterable[DatasourceInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str, @@ -317,7 +178,7 @@ class DatasourceEngine: """ result = [] - for message in tool_messages: + for message in datasource_messages: if "image" in message.mimetype: file_type = FileType.IMAGE elif "video" in message.mimetype: diff --git a/api/core/datasource/errors.py b/api/core/datasource/errors.py index c5f9ca4774..c7fc2f85b9 100644 --- a/api/core/datasource/errors.py +++ b/api/core/datasource/errors.py @@ -1,36 +1,36 @@ 
-from core.tools.entities.tool_entities import ToolInvokeMeta +from core.datasource.entities.datasource_entities import DatasourceInvokeMeta -class ToolProviderNotFoundError(ValueError): +class DatasourceProviderNotFoundError(ValueError): pass -class ToolNotFoundError(ValueError): +class DatasourceNotFoundError(ValueError): pass -class ToolParameterValidationError(ValueError): +class DatasourceParameterValidationError(ValueError): pass -class ToolProviderCredentialValidationError(ValueError): +class DatasourceProviderCredentialValidationError(ValueError): pass -class ToolNotSupportedError(ValueError): +class DatasourceNotSupportedError(ValueError): pass -class ToolInvokeError(ValueError): +class DatasourceInvokeError(ValueError): pass -class ToolApiSchemaError(ValueError): +class DatasourceApiSchemaError(ValueError): pass -class ToolEngineInvokeError(Exception): - meta: ToolInvokeMeta +class DatasourceEngineInvokeError(Exception): + meta: DatasourceInvokeMeta def __init__(self, meta, **kwargs): self.meta = meta diff --git a/api/core/datasource/tool_file_manager.py b/api/core/datasource/tool_file_manager.py deleted file mode 100644 index 7e8d4280d4..0000000000 --- a/api/core/datasource/tool_file_manager.py +++ /dev/null @@ -1,234 +0,0 @@ -import base64 -import hashlib -import hmac -import logging -import os -import time -from mimetypes import guess_extension, guess_type -from typing import Optional, Union -from uuid import uuid4 - -import httpx - -from configs import dify_config -from core.helper import ssrf_proxy -from extensions.ext_database import db -from extensions.ext_storage import storage -from models.model import MessageFile -from models.tools import ToolFile - -logger = logging.getLogger(__name__) - - -class ToolFileManager: - @staticmethod - def sign_file(tool_file_id: str, extension: str) -> str: - """ - sign file to get a temporary url - """ - base_url = dify_config.FILES_URL - file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" - - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - - return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" - - @staticmethod - def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - """ - verify signature - """ - data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - # verify signature - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT - - @staticmethod - def create_file_by_raw( - *, - user_id: str, - tenant_id: str, - conversation_id: Optional[str], - file_binary: bytes, - mimetype: str, - filename: Optional[str] = None, - ) -> ToolFile: - extension = guess_extension(mimetype) or ".bin" - unique_name = uuid4().hex - unique_filename = f"{unique_name}{extension}" - # default just as before - present_filename = unique_filename - if filename is not None: - has_extension = len(filename.split(".")) > 1 - # Add 
extension flexibly - present_filename = filename if has_extension else f"{filename}{extension}" - filepath = f"tools/{tenant_id}/{unique_filename}" - storage.save(filepath, file_binary) - - tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=conversation_id, - file_key=filepath, - mimetype=mimetype, - name=present_filename, - size=len(file_binary), - ) - - db.session.add(tool_file) - db.session.commit() - db.session.refresh(tool_file) - - return tool_file - - @staticmethod - def create_file_by_url( - user_id: str, - tenant_id: str, - file_url: str, - conversation_id: Optional[str] = None, - ) -> ToolFile: - # try to download image - try: - response = ssrf_proxy.get(file_url) - response.raise_for_status() - blob = response.content - except httpx.TimeoutException: - raise ValueError(f"timeout when downloading file from {file_url}") - - mimetype = ( - guess_type(file_url)[0] - or response.headers.get("Content-Type", "").split(";")[0].strip() - or "application/octet-stream" - ) - extension = guess_extension(mimetype) or ".bin" - unique_name = uuid4().hex - filename = f"{unique_name}{extension}" - filepath = f"tools/{tenant_id}/{filename}" - storage.save(filepath, blob) - - tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=conversation_id, - file_key=filepath, - mimetype=mimetype, - original_url=file_url, - name=filename, - size=len(blob), - ) - - db.session.add(tool_file) - db.session.commit() - - return tool_file - - @staticmethod - def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: - """ - get file binary - - :param id: the id of the file - - :return: the binary of the file, mime type - """ - tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == id, - ) - .first() - ) - - if not tool_file: - return None - - blob = storage.load_once(tool_file.file_key) - - return blob, tool_file.mimetype - - @staticmethod - def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]: - """ - get file binary - - :param id: the id of the file - - :return: the binary of the file, mime type - """ - message_file: MessageFile | None = ( - db.session.query(MessageFile) - .filter( - MessageFile.id == id, - ) - .first() - ) - - # Check if message_file is not None - if message_file is not None: - # get tool file id - if message_file.url is not None: - tool_file_id = message_file.url.split("/")[-1] - # trim extension - tool_file_id = tool_file_id.split(".")[0] - else: - tool_file_id = None - else: - tool_file_id = None - - tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == tool_file_id, - ) - .first() - ) - - if not tool_file: - return None - - blob = storage.load_once(tool_file.file_key) - - return blob, tool_file.mimetype - - @staticmethod - def get_file_generator_by_tool_file_id(tool_file_id: str): - """ - get file binary - - :param tool_file_id: the id of the tool file - - :return: the binary of the file, mime type - """ - tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == tool_file_id, - ) - .first() - ) - - if not tool_file: - return None, None - - stream = storage.load_stream(tool_file.file_key) - - return stream, tool_file - - -# init tool_file_parser -from core.file.tool_file_parser import tool_file_manager - -tool_file_manager["manager"] = ToolFileManager diff --git a/api/core/datasource/tool_label_manager.py b/api/core/datasource/tool_label_manager.py deleted file mode 100644 index 4787d7d79c..0000000000 --- 
a/api/core/datasource/tool_label_manager.py +++ /dev/null @@ -1,101 +0,0 @@ -from core.tools.__base.tool_provider import ToolProviderController -from core.tools.builtin_tool.provider import BuiltinToolProviderController -from core.tools.custom_tool.provider import ApiToolProviderController -from core.tools.entities.values import default_tool_label_name_list -from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from extensions.ext_database import db -from models.tools import ToolLabelBinding - - -class ToolLabelManager: - @classmethod - def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]: - """ - Filter tool labels - """ - tool_labels = [label for label in tool_labels if label in default_tool_label_name_list] - return list(set(tool_labels)) - - @classmethod - def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]): - """ - Update tool labels - """ - labels = cls.filter_tool_labels(labels) - - if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - provider_id = controller.provider_id - else: - raise ValueError("Unsupported tool type") - - # delete old labels - db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() - - # insert new labels - for label in labels: - db.session.add( - ToolLabelBinding( - tool_id=provider_id, - tool_type=controller.provider_type.value, - label_name=label, - ) - ) - - db.session.commit() - - @classmethod - def get_tool_labels(cls, controller: ToolProviderController) -> list[str]: - """ - Get tool labels - """ - if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - provider_id = controller.provider_id - elif isinstance(controller, BuiltinToolProviderController): - return controller.tool_labels - else: - raise ValueError("Unsupported tool type") - - labels = ( - db.session.query(ToolLabelBinding.label_name) - .filter( - ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, - ) - .all() - ) - - return [label.label_name for label in labels] - - @classmethod - def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: - """ - Get tools labels - - :param tool_providers: list of tool providers - - :return: dict of tool labels - :key: tool id - :value: list of tool labels - """ - if not tool_providers: - return {} - - for controller in tool_providers: - if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - raise ValueError("Unsupported tool type") - - provider_ids = [] - for controller in tool_providers: - assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) - provider_ids.append(controller.provider_id) - - labels: list[ToolLabelBinding] = ( - db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() - ) - - tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} - - for label in labels: - tool_labels[label.tool_id].append(label.label_name) - - return tool_labels diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index f0e34c0cd7..be6f3c007a 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -132,3 +132,4 @@ class TraceTaskName(StrEnum): DATASET_RETRIEVAL_TRACE = "dataset_retrieval" TOOL_TRACE = "tool" GENERATE_NAME_TRACE = "generate_conversation_name" + DATASOURCE_TRACE = "datasource" \ No newline at 
end of file From 49d1846e63ebeeea5d0d77b41d53b54592c7947e Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 28 Apr 2025 16:19:12 +0800 Subject: [PATCH 010/155] r2 --- .../datasource/utils/message_transformer.py | 76 ++-- api/core/plugin/manager/datasource.py | 3 +- .../nodes/datasource/datasource_node.py | 122 +++--- api/core/workflow/nodes/datasource/exc.py | 12 +- api/core/workflow/nodes/enums.py | 1 + .../nodes/knowledge_index/__init__.py | 3 + .../nodes/knowledge_index/entities.py | 147 +++++++ .../workflow/nodes/knowledge_index/exc.py | 22 + .../knowledge_index/knowledge_index_node.py | 154 +++++++ .../nodes/knowledge_index/template_prompts.py | 66 +++ .../nodes/knowledge_retrieval/entities.py | 1 - api/models/dataset.py | 1 + api/services/dataset_service.py | 403 ++++++++++++++++++ 13 files changed, 902 insertions(+), 109 deletions(-) create mode 100644 api/core/workflow/nodes/knowledge_index/__init__.py create mode 100644 api/core/workflow/nodes/knowledge_index/entities.py create mode 100644 api/core/workflow/nodes/knowledge_index/exc.py create mode 100644 api/core/workflow/nodes/knowledge_index/knowledge_index_node.py create mode 100644 api/core/workflow/nodes/knowledge_index/template_prompts.py diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 6fd0c201e3..a10030d93b 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -3,58 +3,58 @@ from collections.abc import Generator from mimetypes import guess_extension from typing import Optional +from core.datasource.datasource_file_manager import DatasourceFileManager +from core.datasource.entities.datasource_entities import DatasourceInvokeMessage from core.file import File, FileTransferMethod, FileType -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) -class ToolFileMessageTransformer: +class DatasourceFileMessageTransformer: @classmethod - def transform_tool_invoke_messages( + def transform_datasource_invoke_messages( cls, - messages: Generator[ToolInvokeMessage, None, None], + messages: Generator[DatasourceInvokeMessage, None, None], user_id: str, tenant_id: str, conversation_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage, None, None]: + ) -> Generator[DatasourceInvokeMessage, None, None]: """ - Transform tool message and handle file download + Transform datasource message and handle file download """ for message in messages: - if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}: + if message.type in {DatasourceInvokeMessage.MessageType.TEXT, DatasourceInvokeMessage.MessageType.LINK}: yield message - elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance( - message.message, ToolInvokeMessage.TextMessage + elif message.type == DatasourceInvokeMessage.MessageType.IMAGE and isinstance( + message.message, DatasourceInvokeMessage.TextMessage ): # try to download image try: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) - file = ToolFileManager.create_file_by_url( + file = DatasourceFileManager.create_file_by_url( user_id=user_id, tenant_id=tenant_id, file_url=message.message.text, conversation_id=conversation_id, ) - url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}" + url = 
f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}" - yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=ToolInvokeMessage.TextMessage(text=url), + yield DatasourceInvokeMessage( + type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, + message=DatasourceInvokeMessage.TextMessage(text=url), meta=message.meta.copy() if message.meta is not None else {}, ) except Exception as e: - yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=ToolInvokeMessage.TextMessage( + yield DatasourceInvokeMessage( + type=DatasourceInvokeMessage.MessageType.TEXT, + message=DatasourceInvokeMessage.TextMessage( text=f"Failed to download image: {message.message.text}: {e}" ), meta=message.meta.copy() if message.meta is not None else {}, ) - elif message.type == ToolInvokeMessage.MessageType.BLOB: + elif message.type == DatasourceInvokeMessage.MessageType.BLOB: # get mime type and save blob to storage meta = message.meta or {} @@ -63,12 +63,12 @@ class ToolFileMessageTransformer: filename = meta.get("file_name", None) # if message is str, encode it to bytes - if not isinstance(message.message, ToolInvokeMessage.BlobMessage): + if not isinstance(message.message, DatasourceInvokeMessage.BlobMessage): raise ValueError("unexpected message type") # FIXME: should do a type check here. assert isinstance(message.message.blob, bytes) - file = ToolFileManager.create_file_by_raw( + file = DatasourceFileManager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, @@ -77,22 +77,22 @@ class ToolFileMessageTransformer: filename=filename, ) - url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype)) + url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype)) # check if file is image if "image" in mimetype: - yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=ToolInvokeMessage.TextMessage(text=url), + yield DatasourceInvokeMessage( + type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, + message=DatasourceInvokeMessage.TextMessage(text=url), meta=meta.copy() if meta is not None else {}, ) else: - yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.BINARY_LINK, - message=ToolInvokeMessage.TextMessage(text=url), + yield DatasourceInvokeMessage( + type=DatasourceInvokeMessage.MessageType.BINARY_LINK, + message=DatasourceInvokeMessage.TextMessage(text=url), meta=meta.copy() if meta is not None else {}, ) - elif message.type == ToolInvokeMessage.MessageType.FILE: + elif message.type == DatasourceInvokeMessage.MessageType.FILE: meta = message.meta or {} file = meta.get("file", None) if isinstance(file, File): @@ -100,15 +100,15 @@ class ToolFileMessageTransformer: assert file.related_id is not None url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) if file.type == FileType.IMAGE: - yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=ToolInvokeMessage.TextMessage(text=url), + yield DatasourceInvokeMessage( + type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, + message=DatasourceInvokeMessage.TextMessage(text=url), meta=meta.copy() if meta is not None else {}, ) else: - yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text=url), + yield DatasourceInvokeMessage( + type=DatasourceInvokeMessage.MessageType.LINK, + message=DatasourceInvokeMessage.TextMessage(text=url), meta=meta.copy() 
if meta is not None else {}, ) else: @@ -117,5 +117,5 @@ class ToolFileMessageTransformer: yield message @classmethod - def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str: - return f"/files/tools/{tool_file_id}{extension or '.bin'}" + def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str: + return f"/files/datasources/{datasource_file_id}{extension or '.bin'}" diff --git a/api/core/plugin/manager/datasource.py b/api/core/plugin/manager/datasource.py index 5a6f557e4b..efb42cd259 100644 --- a/api/core/plugin/manager/datasource.py +++ b/api/core/plugin/manager/datasource.py @@ -88,13 +88,14 @@ class PluginDatasourceManager(BasePluginManager): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/invoke_first_step", + f"plugin/{tenant_id}/dispatch/datasource/{online_document}/pages", ToolInvokeMessage, data={ "user_id": user_id, "data": { "provider": datasource_provider_id.provider_name, "datasource": datasource_name, + "credentials": credentials, "datasource_parameters": datasource_parameters, }, diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 1752ba36fa..8ecf66c0d6 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -5,13 +5,13 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.datasource.datasource_engine import DatasourceEngine +from core.datasource.entities.datasource_entities import DatasourceInvokeMessage, DatasourceParameter +from core.datasource.errors import DatasourceInvokeError +from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer from core.file import File, FileTransferMethod from core.plugin.manager.exc import PluginDaemonClientSideError from core.plugin.manager.plugin import PluginInstallationManager -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.tools.errors import ToolInvokeError -from core.tools.tool_engine import ToolEngine -from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.variables.segments import ArrayAnySegment from core.variables.variables import ArrayAnyVariable from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult @@ -29,11 +29,7 @@ from models.workflow import WorkflowNodeExecutionStatus from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .entities import DatasourceNodeData -from .exc import ( - ToolFileError, - ToolNodeError, - ToolParameterError, -) +from .exc import DatasourceNodeError, DatasourceParameterError, ToolFileError class DatasourceNode(BaseNode[DatasourceNodeData]): @@ -60,12 +56,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): # get datasource runtime try: - from core.tools.tool_manager import ToolManager + from core.datasource.datasource_manager import DatasourceManager - tool_runtime = ToolManager.get_workflow_tool_runtime( + datasource_runtime = DatasourceManager.get_workflow_datasource_runtime( self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) - except ToolNodeError as e: + except DatasourceNodeError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -78,14 +74,14 @@ class 
DatasourceNode(BaseNode[DatasourceNodeData]): return # get parameters - tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] + datasource_parameters = datasource_runtime.get_merged_runtime_parameters() or [] parameters = self._generate_parameters( - tool_parameters=tool_parameters, + datasource_parameters=datasource_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=self.node_data, ) parameters_for_log = self._generate_parameters( - tool_parameters=tool_parameters, + datasource_parameters=datasource_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=self.node_data, for_log=True, @@ -95,9 +91,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) try: - message_stream = ToolEngine.generic_invoke( - tool=tool_runtime, - tool_parameters=parameters, + message_stream = DatasourceEngine.generic_invoke( + datasource=datasource_runtime, + datasource_parameters=parameters, user_id=self.user_id, workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, @@ -105,28 +101,28 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): app_id=self.app_id, conversation_id=conversation_id.text if conversation_id else None, ) - except ToolNodeError as e: + except DatasourceNodeError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool: {str(e)}", + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to invoke datasource: {str(e)}", error_type=type(e).__name__, ) ) return try: - # convert tool messages - yield from self._transform_message(message_stream, tool_info, parameters_for_log) - except (PluginDaemonClientSideError, ToolInvokeError) as e: + # convert datasource messages + yield from self._transform_message(message_stream, datasource_info, parameters_for_log) + except (PluginDaemonClientSideError, DatasourceInvokeError) as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to transform tool message: {str(e)}", + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to transform datasource message: {str(e)}", error_type=type(e).__name__, ) ) @@ -134,9 +130,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): def _generate_parameters( self, *, - tool_parameters: Sequence[ToolParameter], + datasource_parameters: Sequence[DatasourceParameter], variable_pool: VariablePool, - node_data: ToolNodeData, + node_data: DatasourceNodeData, for_log: bool = False, ) -> dict[str, Any]: """ @@ -151,25 +147,25 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): Mapping[str, Any]: A dictionary containing the generated parameters. 
""" - tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} + datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters} result: dict[str, Any] = {} - for parameter_name in node_data.tool_parameters: - parameter = tool_parameters_dictionary.get(parameter_name) + for parameter_name in node_data.datasource_parameters: + parameter = datasource_parameters_dictionary.get(parameter_name) if not parameter: result[parameter_name] = None continue - tool_input = node_data.tool_parameters[parameter_name] - if tool_input.type == "variable": - variable = variable_pool.get(tool_input.value) + datasource_input = node_data.datasource_parameters[parameter_name] + if datasource_input.type == "variable": + variable = variable_pool.get(datasource_input.value) if variable is None: - raise ToolParameterError(f"Variable {tool_input.value} does not exist") + raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist") parameter_value = variable.value - elif tool_input.type in {"mixed", "constant"}: - segment_group = variable_pool.convert_template(str(tool_input.value)) + elif datasource_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(datasource_input.value)) parameter_value = segment_group.log if for_log else segment_group.text else: - raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") + raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'") result[parameter_name] = parameter_value return result @@ -181,15 +177,15 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): def _transform_message( self, - messages: Generator[ToolInvokeMessage, None, None], - tool_info: Mapping[str, Any], + messages: Generator[DatasourceInvokeMessage, None, None], + datasource_info: Mapping[str, Any], parameters_for_log: dict[str, Any], ) -> Generator: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ # transform message and handle file storage - message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( + message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages( messages=messages, user_id=self.user_id, tenant_id=self.tenant_id, @@ -207,11 +203,11 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): for message in message_stream: if message.type in { - ToolInvokeMessage.MessageType.IMAGE_LINK, - ToolInvokeMessage.MessageType.BINARY_LINK, - ToolInvokeMessage.MessageType.IMAGE, + DatasourceInvokeMessage.MessageType.IMAGE_LINK, + DatasourceInvokeMessage.MessageType.BINARY_LINK, + DatasourceInvokeMessage.MessageType.IMAGE, }: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) url = message.message.text if message.meta: @@ -238,9 +234,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): tenant_id=self.tenant_id, ) files.append(file) - elif message.type == ToolInvokeMessage.MessageType.BLOB: + elif message.type == DatasourceInvokeMessage.MessageType.BLOB: # get tool file id - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) assert message.meta tool_file_id = message.message.text.split("/")[-1].split(".")[0] @@ -261,14 +257,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): tenant_id=self.tenant_id, ) ) - elif message.type == ToolInvokeMessage.MessageType.TEXT: - assert 
isinstance(message.message, ToolInvokeMessage.TextMessage) + elif message.type == DatasourceInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) text += message.message.text yield RunStreamChunkEvent( chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] ) - elif message.type == ToolInvokeMessage.MessageType.JSON: - assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + elif message.type == DatasourceInvokeMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage) if self.node_type == NodeType.AGENT: msg_metadata = message.message.json_object.pop("execution_metadata", {}) agent_execution_metadata = { @@ -277,13 +273,13 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): if key in NodeRunMetadataKey.__members__.values() } json.append(message.message.json_object) - elif message.type == ToolInvokeMessage.MessageType.LINK: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + elif message.type == DatasourceInvokeMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) stream_text = f"Link: {message.message.text}\n" text += stream_text yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) - elif message.type == ToolInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage) variable_name = message.message.variable_name variable_value = message.message.variable_value if message.message.stream: @@ -298,13 +294,13 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): ) else: variables[variable_name] = variable_value - elif message.type == ToolInvokeMessage.MessageType.FILE: + elif message.type == DatasourceInvokeMessage.MessageType.FILE: assert message.meta is not None files.append(message.meta["file"]) - elif message.type == ToolInvokeMessage.MessageType.LOG: - assert isinstance(message.message, ToolInvokeMessage.LogMessage) + elif message.type == DatasourceInvokeMessage.MessageType.LOG: + assert isinstance(message.message, DatasourceInvokeMessage.LogMessage) if message.message.metadata: - icon = tool_info.get("icon", "") + icon = datasource_info.get("icon", "") dict_metadata = dict(message.message.metadata) if dict_metadata.get("provider"): manager = PluginInstallationManager() @@ -366,7 +362,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): outputs={"text": text, "files": files, "json": json, **variables}, metadata={ **agent_execution_metadata, - NodeRunMetadataKey.TOOL_INFO: tool_info, + NodeRunMetadataKey.DATASOURCE_INFO: datasource_info, NodeRunMetadataKey.AGENT_LOG: agent_logs, }, inputs=parameters_for_log, @@ -379,7 +375,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: ToolNodeData, + node_data: DatasourceNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -389,8 +385,8 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): :return: """ result = {} - for parameter_name in node_data.tool_parameters: - input = node_data.tool_parameters[parameter_name] + for parameter_name in node_data.datasource_parameters: + input = node_data.datasource_parameters[parameter_name] if input.type == "mixed": assert isinstance(input.value, str) selectors = 
VariableTemplateParser(input.value).extract_variable_selectors() diff --git a/api/core/workflow/nodes/datasource/exc.py b/api/core/workflow/nodes/datasource/exc.py index 7212e8bfc0..89980e6f45 100644 --- a/api/core/workflow/nodes/datasource/exc.py +++ b/api/core/workflow/nodes/datasource/exc.py @@ -1,16 +1,16 @@ -class ToolNodeError(ValueError): - """Base exception for tool node errors.""" +class DatasourceNodeError(ValueError): + """Base exception for datasource node errors.""" pass -class ToolParameterError(ToolNodeError): - """Exception raised for errors in tool parameters.""" +class DatasourceParameterError(DatasourceNodeError): + """Exception raised for errors in datasource parameters.""" pass -class ToolFileError(ToolNodeError): - """Exception raised for errors related to tool files.""" +class DatasourceFileError(DatasourceNodeError): + """Exception raised for errors related to datasource files.""" pass diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 673d0ba049..7edc73b6ba 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -7,6 +7,7 @@ class NodeType(StrEnum): ANSWER = "answer" LLM = "llm" KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + KNOWLEDGE_INDEX = "knowledge-index" IF_ELSE = "if-else" CODE = "code" TEMPLATE_TRANSFORM = "template-transform" diff --git a/api/core/workflow/nodes/knowledge_index/__init__.py b/api/core/workflow/nodes/knowledge_index/__init__.py new file mode 100644 index 0000000000..01d59b87b2 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/__init__.py @@ -0,0 +1,3 @@ +from .knowledge_index_node import KnowledgeRetrievalNode + +__all__ = ["KnowledgeRetrievalNode"] diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py new file mode 100644 index 0000000000..a87032dba6 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -0,0 +1,147 @@ +from collections.abc import Sequence +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel, Field + +from core.workflow.nodes.base import BaseNodeData +from core.workflow.nodes.llm.entities import VisionConfig + + +class RerankingModelConfig(BaseModel): + """ + Reranking Model Config. + """ + + provider: str + model: str + +class VectorSetting(BaseModel): + """ + Vector Setting. + """ + + vector_weight: float + embedding_provider_name: str + embedding_model_name: str + + +class KeywordSetting(BaseModel): + """ + Keyword Setting. + """ + + keyword_weight: float + +class WeightedScoreConfig(BaseModel): + """ + Weighted score Config. + """ + + vector_setting: VectorSetting + keyword_setting: KeywordSetting + + +class EmbeddingSetting(BaseModel): + """ + Embedding Setting. + """ + embedding_provider_name: str + embedding_model_name: str + + +class EconomySetting(BaseModel): + """ + Economy Setting. + """ + + keyword_number: int + + +class RetrievalSetting(BaseModel): + """ + Retrieval Setting. + """ + search_method: Literal["semantic_search", "keyword_search", "hybrid_search"] + top_k: int + score_threshold: Optional[float] = 0.5 + score_threshold_enabled: bool = False + reranking_mode: str = "reranking_model" + reranking_enable: bool = True + reranking_model: Optional[RerankingModelConfig] = None + weights: Optional[WeightedScoreConfig] = None + +class IndexMethod(BaseModel): + """ + Knowledge Index Setting. 
+ """ + indexing_technique: Literal["high_quality", "economy"] + embedding_setting: EmbeddingSetting + economy_setting: EconomySetting + +class FileInfo(BaseModel): + """ + File Info. + """ + file_id: str + +class OnlineDocumentIcon(BaseModel): + """ + Document Icon. + """ + icon_url: str + icon_type: str + icon_emoji: str + +class OnlineDocumentInfo(BaseModel): + """ + Online document info. + """ + provider: str + workspace_id: str + page_id: str + page_type: str + icon: OnlineDocumentIcon + +class WebsiteInfo(BaseModel): + """ + website import info. + """ + provider: str + url: str + +class GeneralStructureChunk(BaseModel): + """ + General Structure Chunk. + """ + general_chunk: list[str] + data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo] + + +class ParentChildChunk(BaseModel): + """ + Parent Child Chunk. + """ + parent_content: str + child_content: list[str] + + +class ParentChildStructureChunk(BaseModel): + """ + Parent Child Structure Chunk. + """ + parent_child_chunks: list[ParentChildChunk] + data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo] + + +class KnowledgeIndexNodeData(BaseNodeData): + """ + Knowledge index Node Data. + """ + + type: str = "knowledge-index" + dataset_id: str + index_chunk_variable_selector: list[str] + chunk_structure: Literal["general", "parent-child"] + index_method: IndexMethod + retrieval_setting: RetrievalSetting + diff --git a/api/core/workflow/nodes/knowledge_index/exc.py b/api/core/workflow/nodes/knowledge_index/exc.py new file mode 100644 index 0000000000..afdde9c0c5 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/exc.py @@ -0,0 +1,22 @@ +class KnowledgeIndexNodeError(ValueError): + """Base class for KnowledgeIndexNode errors.""" + + +class ModelNotExistError(KnowledgeIndexNodeError): + """Raised when the model does not exist.""" + + +class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError): + """Raised when the model credentials are not initialized.""" + + +class ModelNotSupportedError(KnowledgeIndexNodeError): + """Raised when the model is not supported.""" + + +class ModelQuotaExceededError(KnowledgeIndexNodeError): + """Raised when the model provider quota is exceeded.""" + + +class InvalidModelTypeError(KnowledgeIndexNodeError): + """Raised when the model is not a Large Language Model.""" diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py new file mode 100644 index 0000000000..543a170fa7 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -0,0 +1,154 @@ +import json +import logging +import re +import time +from collections import defaultdict +from collections.abc import Mapping, Sequence +from typing import Any, Optional, cast + +from sqlalchemy import Integer, and_, func, or_, text +from sqlalchemy import cast as sqlalchemy_cast + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.agent_entities import PlanningStrategy +from core.entities.model_entities import ModelStatus +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.simple_prompt_transform import ModelMode +from 
core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.metadata_entities import Condition, MetadataCondition +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.variables import StringSegment +from core.variables.segments import ObjectSegment +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event.event import ModelInvokeCompletedEvent +from core.workflow.nodes.knowledge_retrieval.template_prompts import ( + METADATA_FILTER_ASSISTANT_PROMPT_1, + METADATA_FILTER_ASSISTANT_PROMPT_2, + METADATA_FILTER_COMPLETION_PROMPT, + METADATA_FILTER_SYSTEM_PROMPT, + METADATA_FILTER_USER_PROMPT_1, + METADATA_FILTER_USER_PROMPT_3, +) +from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate +from core.workflow.nodes.llm.node import LLMNode +from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2 +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from libs.json_in_md_parser import parse_and_check_json_markdown +from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog +from models.workflow import WorkflowNodeExecutionStatus +from services.dataset_service import DatasetService +from services.feature_service import FeatureService + +from .entities import KnowledgeIndexNodeData, KnowledgeRetrievalNodeData, ModelConfig +from .exc import ( + InvalidModelTypeError, + KnowledgeIndexNodeError, + KnowledgeRetrievalNodeError, + ModelCredentialsNotInitializedError, + ModelNotExistError, + ModelNotSupportedError, + ModelQuotaExceededError, +) + +logger = logging.getLogger(__name__) + +default_retrieval_model = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, +} + + +class KnowledgeIndexNode(LLMNode): + _node_data_cls = KnowledgeIndexNodeData # type: ignore + _node_type = NodeType.KNOWLEDGE_INDEX + + def _run(self) -> NodeRunResult: # type: ignore + node_data = cast(KnowledgeIndexNodeData, self.node_data) + # extract variables + variable = self.graph_runtime_state.variable_pool.get(node_data.index_chunk_variable_selector) + if not isinstance(variable, ObjectSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Query variable is not object type.", + ) + chunks = variable.value + variables = {"chunks": chunks} + if not chunks: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required." 
+ ) + # check rate limit + if self.tenant_id: + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) + if knowledge_rate_limit.enabled: + current_time = int(time.time() * 1000) + key = f"rate_limit_{self.tenant_id}" + redis_client.zadd(key, {current_time: current_time}) + redis_client.zremrangebyscore(key, 0, current_time - 60000) + request_count = redis_client.zcard(key) + if request_count > knowledge_rate_limit.limit: + # add ratelimit record + rate_limit_log = RateLimitLog( + tenant_id=self.tenant_id, + subscription_plan=knowledge_rate_limit.subscription_plan, + operation="knowledge", + ) + db.session.add(rate_limit_log) + db.session.commit() + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error="Sorry, you have reached the knowledge base request rate limit of your subscription.", + error_type="RateLimitExceeded", + ) + + # retrieve knowledge + try: + results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks) + outputs = {"result": results} + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs + ) + + except KnowledgeIndexNodeError as e: + logger.warning("Error when running knowledge index node") + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + error_type=type(e).__name__, + ) + # Temporary handle all exceptions from DatasetRetrieval class here. + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + error_type=type(e).__name__, + ) + + + def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: list[any]) -> Any: + dataset = Dataset.query.filter_by(id=node_data.dataset_id).first() + if not dataset: + raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.") + + DatasetService.invoke_knowledge_index( + dataset=dataset, + chunks=chunks, + index_method=node_data.index_method, + retrieval_setting=node_data.retrieval_setting, + ) + + pass diff --git a/api/core/workflow/nodes/knowledge_index/template_prompts.py b/api/core/workflow/nodes/knowledge_index/template_prompts.py new file mode 100644 index 0000000000..7abd55d798 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/template_prompts.py @@ -0,0 +1,66 @@ +METADATA_FILTER_SYSTEM_PROMPT = """ + ### Job Description', + You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value + ### Task + Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator". + ### Format + The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. + ### Constraint + DO NOT include anything other than the JSON array in your response. 
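The rate-limit check in _run above is a one-minute sliding window kept in a Redis sorted set: each request is recorded with its millisecond timestamp as the score, entries older than 60 s are trimmed, and the remaining cardinality is compared to the plan limit. A standalone sketch of the same pattern with redis-py; it assumes a reachable local Redis, and the tenant id and limit are made up rather than taken from FeatureService:

    import time
    import redis

    r = redis.Redis()  # assumes a local Redis instance

    def within_limit(tenant_id: str, limit: int, window_ms: int = 60_000) -> bool:
        """Return True if this request stays inside the sliding-window limit."""
        now_ms = int(time.time() * 1000)
        key = f"rate_limit_{tenant_id}"
        r.zadd(key, {str(now_ms): now_ms})               # record this request
        r.zremrangebyscore(key, 0, now_ms - window_ms)   # drop entries older than the window
        return r.zcard(key) <= limit                     # count what is left in the window

    print(within_limit("tenant-123", limit=10))
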
+""" # noqa: E501 + +METADATA_FILTER_USER_PROMPT_1 = """ + { "input_text": "I want to know which company’s email address test@example.com is?", + "metadata_fields": ["filename", "email", "phone", "address"] + } +""" + +METADATA_FILTER_ASSISTANT_PROMPT_1 = """ +```json + {"metadata_map": [ + {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="} + ] + } +``` +""" + +METADATA_FILTER_USER_PROMPT_2 = """ + {"input_text": "What are the movies with a score of more than 9 in 2024?", + "metadata_fields": ["name", "year", "rating", "country"]} +""" + +METADATA_FILTER_ASSISTANT_PROMPT_2 = """ +```json + {"metadata_map": [ + {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, + {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}, + ]} +``` +""" + +METADATA_FILTER_USER_PROMPT_3 = """ + '{{"input_text": "{input_text}",', + '"metadata_fields": {metadata_fields}}}' +""" + +METADATA_FILTER_COMPLETION_PROMPT = """ +### Job Description +You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value +### Task +# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator". +### Format +The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. +### Constraint +DO NOT include anything other than the JSON array in your response. +### Example +Here is the chat example between human and assistant, inside XML tags. + +User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}} +Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}} +User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}} +Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}} + +### User Input +{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}} +### Assistant Output +""" # noqa: E501 diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index d2e5a15545..17b3308a06 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -59,7 +59,6 @@ class MultipleRetrievalConfig(BaseModel): class ModelConfig(BaseModel): """ Model Config. 
- """ provider: str name: str diff --git a/api/models/dataset.py b/api/models/dataset.py index a344ab2964..3c44fb4b45 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -59,6 +59,7 @@ class Dataset(db.Model): # type: ignore[name-defined] updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) + keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10")) collection_binding_id = db.Column(StringUUID, nullable=True) retrieval_model = db.Column(JSONB, nullable=True) built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index b019cf6b63..19962d66b9 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -21,6 +21,7 @@ from core.plugin.entities.plugin import ModelProviderID from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db @@ -1131,6 +1132,408 @@ class DocumentService: return documents, batch @staticmethod + def save_document_with_dataset_id( + dataset: Dataset, + knowledge_config: KnowledgeConfig, + account: Account | Any, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", + ): + # check document limit + features = FeatureService.get_features(current_user.current_tenant_id) + + if features.billing.enabled: + if not knowledge_config.original_document_id: + count = 0 + if knowledge_config.data_source: + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + count = len(upload_file_list) + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + for notion_info in notion_info_list: # type: ignore + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + count = len(website_info.urls) # type: ignore + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + + if features.billing.subscription.plan == "sandbox" and count > 1: + raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + + DocumentService.check_documents_upload_quota(count, features) + + # if dataset is empty, update dataset data_source_type + if not dataset.data_source_type: + dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore + + if not dataset.indexing_technique: + if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + raise ValueError("Indexing technique is invalid") + + dataset.indexing_technique = knowledge_config.indexing_technique + if 
knowledge_config.indexing_technique == "high_quality": + model_manager = ModelManager() + if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: + dataset_embedding_model = knowledge_config.embedding_model + dataset_embedding_model_provider = knowledge_config.embedding_model_provider + else: + embedding_model = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + ) + dataset_embedding_model = embedding_model.model + dataset_embedding_model_provider = embedding_model.provider + dataset.embedding_model = dataset_embedding_model + dataset.embedding_model_provider = dataset_embedding_model_provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + dataset_embedding_model_provider, dataset_embedding_model + ) + dataset.collection_binding_id = dataset_collection_binding.id + if not dataset.retrieval_model: + default_retrieval_model = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, + } + + dataset.retrieval_model = ( + knowledge_config.retrieval_model.model_dump() + if knowledge_config.retrieval_model + else default_retrieval_model + ) # type: ignore + + documents = [] + if knowledge_config.original_document_id: + document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) + documents.append(document) + batch = document.batch + else: + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + # save process rule + if not dataset_process_rule: + process_rule = knowledge_config.process_rule + if process_rule: + if process_rule.mode in ("custom", "hierarchical"): + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + created_by=account.id, + ) + elif process_rule.mode == "automatic": + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + created_by=account.id, + ) + else: + logging.warn( + f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" + ) + return + db.session.add(dataset_process_rule) + db.session.commit() + lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) + with redis_client.lock(lock_name, timeout=600): + position = DocumentService.get_documents_position(dataset.id) + document_ids = [] + duplicate_document_ids = [] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + for file_id in upload_file_list: + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) + + # raise error if file not found + if not file: + raise FileNotExistsError() + + file_name = file.name + data_source_info = { + "upload_file_id": file_id, + } + # check duplicate + if knowledge_config.duplicate: + document = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="upload_file", + enabled=True, + name=file_name, + ).first() + if document: + document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + 
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.created_from = created_from + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language + document.data_source_info = json.dumps(data_source_info) + document.batch = batch + document.indexing_status = "waiting" + db.session.add(document) + documents.append(document) + duplicate_document_ids.append(document.id) + continue + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + file_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore + notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + if not notion_info_list: + raise ValueError("No notion info list found.") + exist_page_ids = [] + exist_document = {} + documents = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="notion_import", + enabled=True, + ).all() + if documents: + for document in documents: + data_source_info = json.loads(document.data_source_info) + exist_page_ids.append(data_source_info["notion_page_id"]) + exist_document[data_source_info["notion_page_id"]] = document.id + for notion_info in notion_info_list: + workspace_id = notion_info.workspace_id + data_source_binding = DataSourceOauthBinding.query.filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + ) + ).first() + if not data_source_binding: + raise ValueError("Data source binding not found.") + for page in notion_info.pages: + if page.page_id not in exist_page_ids: + data_source_info = { + "notion_workspace_id": workspace_id, + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + "type": page.type, + } + # Truncate page name to 255 characters to prevent DB field length errors + truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + truncated_page_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + else: + exist_document.pop(page.page_id) + # delete not selected documents + if len(exist_document) > 0: + clean_notion_document_task.delay(list(exist_document.values()), dataset.id) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore + website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + if not website_info: + raise ValueError("No website info list found.") + urls = website_info.urls + for url in urls: + data_source_info = { + "url": url, + "provider": 
website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, + "mode": "crawl", + } + if len(url) > 255: + document_name = url[:200] + "..." + else: + document_name = url + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + document_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + db.session.commit() + + # trigger async task + if document_ids: + document_indexing_task.delay(dataset.id, document_ids) + if duplicate_document_ids: + duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + + return documents, batch + + @staticmethod + def invoke_knowledge_index( + dataset: Dataset, + chunks: list[Any], + index_method: IndexMethod, + retrieval_setting: RetrievalSetting, + original_document_id: str | None = None, + account: Account | Any, + created_from: str = "rag-pipline", + ): + + if not dataset.indexing_technique: + if index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + raise ValueError("Indexing technique is invalid") + + dataset.indexing_technique = index_method.indexing_technique + if index_method.indexing_technique == "high_quality": + model_manager = ModelManager() + if index_method.embedding_setting.embedding_model and index_method.embedding_setting.embedding_model_provider: + dataset_embedding_model = index_method.embedding_setting.embedding_model + dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider + else: + embedding_model = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + ) + dataset_embedding_model = embedding_model.model + dataset_embedding_model_provider = embedding_model.provider + dataset.embedding_model = dataset_embedding_model + dataset.embedding_model_provider = dataset_embedding_model_provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + dataset_embedding_model_provider, dataset_embedding_model + ) + dataset.collection_binding_id = dataset_collection_binding.id + if not dataset.retrieval_model: + default_retrieval_model = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, + } + + dataset.retrieval_model = ( + retrieval_setting.model_dump() + if retrieval_setting + else default_retrieval_model + ) # type: ignore + + documents = [] + if original_document_id: + document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) + documents.append(document) + batch = document.batch + else: + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + + lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) + with redis_client.lock(lock_name, timeout=600): + position = DocumentService.get_documents_position(dataset.id) + document_ids = [] + duplicate_document_ids = [] + for chunk in chunks: + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) + + # raise error if file not found + if 
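For website crawls, the branch above stores the crawl settings as a JSON data_source_info blob and truncates over-long URLs before using them as document names. A compact sketch of those two details; the field names are copied from the hunk, while the provider and job id values are illustrative placeholders:

    import json

    def website_document_name(url: str) -> str:
        # keep names DB-friendly: URLs longer than 255 chars are cut to 200 plus an ellipsis
        return url[:200] + "..." if len(url) > 255 else url

    def website_data_source_info(url: str, provider: str, job_id: str, only_main_content: bool) -> str:
        return json.dumps({
            "url": url,
            "provider": provider,
            "job_id": job_id,
            "only_main_content": only_main_content,
            "mode": "crawl",
        })

    print(website_document_name("https://example.com/" + "a" * 300))
    print(website_data_source_info("https://example.com", "example-provider", "job-1", True))
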
not file: + raise FileNotExistsError() + + file_name = file.name + data_source_info = { + "upload_file_id": file_id, + } + # check duplicate + if knowledge_config.duplicate: + document = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="upload_file", + enabled=True, + name=file_name, + ).first() + if document: + document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.created_from = created_from + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language + document.data_source_info = json.dumps(data_source_info) + document.batch = batch + document.indexing_status = "waiting" + db.session.add(document) + documents.append(document) + duplicate_document_ids.append(document.id) + continue + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + file_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + + db.session.commit() + + # trigger async task + if document_ids: + document_indexing_task.delay(dataset.id, document_ids) + if duplicate_document_ids: + duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + + return documents, batch + @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size if count > can_upload_size: From a25cc4e8afba463b8b02404e738ea92c7f7f90ea Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 6 May 2025 13:56:13 +0800 Subject: [PATCH 011/155] r2 --- .../processor/paragraph_index_processor.py | 7 +- .../knowledge_index/knowledge_index_node.py | 12 ++- api/services/dataset_service.py | 85 ++----------------- 3 files changed, 22 insertions(+), 82 deletions(-) diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index dca84b9041..79c2c16b90 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,7 +1,7 @@ """Paragraph index processor.""" import uuid -from typing import Optional +from typing import Any, Mapping, Optional from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword @@ -125,3 +125,8 @@ class ParagraphIndexProcessor(BaseIndexProcessor): doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) return docs + + def index(self, dataset: Dataset, document: Document, chunks: list[Mapping[str, Any]]): + for chunk in chunks: + GeneralDocument.create( + pass diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 543a170fa7..7aa6b9379f 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -43,7 +43,7 @@ from extensions.ext_redis import redis_client from libs.json_in_md_parser import 
parse_and_check_json_markdown from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog from models.workflow import WorkflowNodeExecutionStatus -from services.dataset_service import DatasetService +from services.dataset_service import DatasetService, DocumentService from services.feature_service import FeatureService from .entities import KnowledgeIndexNodeData, KnowledgeRetrievalNodeData, ModelConfig @@ -139,14 +139,20 @@ class KnowledgeIndexNode(LLMNode): ) - def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: list[any]) -> Any: + def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, document_id: str, chunks: list[any]) -> Any: dataset = Dataset.query.filter_by(id=node_data.dataset_id).first() if not dataset: raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.") - DatasetService.invoke_knowledge_index( + document = Document.query.filter_by(id=document_id).first() + if not document: + raise KnowledgeIndexNodeError(f"Document {document_id} not found.") + + DocumentService.invoke_knowledge_index( dataset=dataset, + document=document, chunks=chunks, + chunk_structure=node_data.chunk_structure, index_method=node_data.index_method, retrieval_setting=node_data.retrieval_setting, ) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3e14a92f36..60d3ccf131 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import random import time import uuid from collections import Counter -from typing import Any, Optional +from typing import Any, Literal, Optional from flask_login import current_user # type: ignore from sqlalchemy import func @@ -20,6 +20,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.plugin import ModelProviderID from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting from events.dataset_event import dataset_was_deleted @@ -1435,9 +1436,11 @@ class DocumentService: @staticmethod def invoke_knowledge_index( dataset: Dataset, + document: Document, chunks: list[Any], index_method: IndexMethod, retrieval_setting: RetrievalSetting, + chunk_structure: Literal["text_model", "hierarchical_model"], original_document_id: str | None = None, account: Account | Any, created_from: str = "rag-pipline", @@ -1479,85 +1482,11 @@ class DocumentService: if retrieval_setting else default_retrieval_model ) # type: ignore - - documents = [] - if original_document_id: - document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) - documents.append(document) - batch = document.batch - else: - batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) - - lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) - with redis_client.lock(lock_name, timeout=600): - position = DocumentService.get_documents_position(dataset.id) - document_ids = [] - duplicate_document_ids = [] - for chunk in chunks: - file = ( - db.session.query(UploadFile) - .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) - .first() - ) - - # raise error if file not found - if not file: - raise FileNotExistsError() - - file_name = file.name - 
data_source_info = { - "upload_file_id": file_id, - } - # check duplicate - if knowledge_config.duplicate: - document = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="upload_file", - enabled=True, - name=file_name, - ).first() - if document: - document.dataset_process_rule_id = dataset_process_rule.id # type: ignore - document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - document.created_from = created_from - document.doc_form = knowledge_config.doc_form - document.doc_language = knowledge_config.doc_language - document.data_source_info = json.dumps(data_source_info) - document.batch = batch - document.indexing_status = "waiting" - db.session.add(document) - documents.append(document) - duplicate_document_ids.append(document.id) - continue - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - file_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - - db.session.commit() - - # trigger async task - if document_ids: - document_indexing_task.delay(dataset.id, document_ids) - if duplicate_document_ids: - duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() + index_processor.index(dataset, document, chunks) return documents, batch + @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size From a998022c12de1f71d5d923071a2dad6b1719b79b Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 6 May 2025 16:18:34 +0800 Subject: [PATCH 012/155] r2 --- .../index_processor/processor/paragraph_index_processor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 79c2c16b90..7c2031258e 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -12,6 +12,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import Document from core.tools.utils.text_processing_utils import remove_leading_symbols +from core.workflow.nodes.knowledge_index.entities import GeneralStructureChunk from libs import helper from models.dataset import Dataset, DatasetProcessRule from services.entities.knowledge_entities.knowledge_entities import Rule @@ -126,7 +127,6 @@ class ParagraphIndexProcessor(BaseIndexProcessor): docs.append(doc) return docs - def index(self, dataset: Dataset, document: Document, chunks: list[Mapping[str, Any]]): - for chunk in chunks: - GeneralDocument.create( + def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): + paragraph = GeneralStructureChunk(**chunks) pass From 3f1363503b4ea92b1f16d30ccfedf1fac7302054 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 7 May 2025 16:19:09 +0800 Subject: [PATCH 013/155] r2 --- 
api/core/plugin/impl/datasource.py | 1 - .../index_processor/index_processor_base.py | 7 +- .../processor/paragraph_index_processor.py | 26 ++++++- .../processor/parent_child_index_processor.py | 35 +++++++++- .../nodes/knowledge_index/entities.py | 32 ++++++--- .../knowledge_index/knowledge_index_node.py | 68 +++++-------------- api/services/dataset_service.py | 20 +++--- 7 files changed, 116 insertions(+), 73 deletions(-) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 029f752d25..c69fa2fe32 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -95,7 +95,6 @@ class PluginDatasourceManager(BasePluginClient): "data": { "provider": datasource_provider_id.provider_name, "datasource": datasource_name, - "credentials": credentials, "datasource_parameters": datasource_parameters, }, diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 2bcd1c79bb..d796c9fd24 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,7 +1,8 @@ """Abstract interface for document loader implementations.""" from abc import ABC, abstractmethod -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from configs import dify_config from core.model_manager import ModelInstance @@ -33,6 +34,10 @@ class BaseIndexProcessor(ABC): def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): raise NotImplementedError + @abstractmethod + def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): + raise NotImplementedError + @abstractmethod def retrieve( self, diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 7c2031258e..43d201af73 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,12 +1,14 @@ """Paragraph index processor.""" import uuid -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any, Optional from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor @@ -129,4 +131,24 @@ class ParagraphIndexProcessor(BaseIndexProcessor): def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): paragraph = GeneralStructureChunk(**chunks) - pass + documents = [] + for content in paragraph.general_chunk: + metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(content), + } + doc = Document(page_content=content, metadata=metadata) + documents.append(doc) + if documents: + # save node to document segment + doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) + # add document segments + doc_store.add_documents(docs=documents, 
save_child=False) + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + vector.create(documents) + elif dataset.indexing_technique == "economy": + keyword = Keyword(dataset) + keyword.add_texts(documents) diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 1cde5e1c8f..ce64bb2a54 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -1,17 +1,20 @@ """Paragraph index processor.""" import uuid -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from configs import dify_config from core.model_manager import ModelInstance from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import ChildDocument, Document +from core.workflow.nodes.knowledge_index.entities import ParentChildStructureChunk from extensions.ext_database import db from libs import helper from models.dataset import ChildChunk, Dataset, DocumentSegment @@ -202,3 +205,33 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_document.page_content = child_page_content child_nodes.append(child_document) return child_nodes + + def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): + parent_childs = ParentChildStructureChunk(**chunks) + documents = [] + for parent_child in parent_childs.parent_child_chunks: + metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(parent_child.parent_content), + } + child_documents = [] + for child in parent_child.child_contents: + child_metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(child), + } + child_documents.append(ChildDocument(page_content=child, metadata=child_metadata)) + doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents) + documents.append(doc) + if documents: + # save node to document segment + doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) + # add document segments + doc_store.add_documents(docs=documents, save_child=True) + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + vector.create(documents) diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index a87032dba6..635748799b 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,10 +1,8 @@ -from collections.abc import Sequence -from typing import Any, Literal, Optional, Union +from typing import Literal, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm.entities import VisionConfig class RerankingModelConfig(BaseModel): 
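The index() implementations added above map each incoming chunk to a RAG Document carrying a fresh doc_id and a content hash before handing the batch to the document store and the vector or keyword index. A stripped-down sketch of the general (paragraph) mapping; hashing is inlined with hashlib instead of the project's helper, and a bare dataclass stands in for the Document model:

    import hashlib
    import uuid
    from dataclasses import dataclass, field

    @dataclass
    class Doc:  # stand-in for core.rag.models.document.Document
        page_content: str
        metadata: dict = field(default_factory=dict)

    def to_documents(dataset_id: str, document_id: str, chunks: list[str]) -> list[Doc]:
        docs = []
        for content in chunks:
            docs.append(
                Doc(
                    page_content=content,
                    metadata={
                        "dataset_id": dataset_id,
                        "document_id": document_id,
                        "doc_id": str(uuid.uuid4()),
                        "doc_hash": hashlib.sha256(content.encode("utf-8")).hexdigest(),
                    },
                )
            )
        return docs

    for d in to_documents("ds-1", "doc-1", ["first paragraph", "second paragraph"]):
        print(d.metadata["doc_id"], d.metadata["doc_hash"][:12])
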
@@ -15,6 +13,7 @@ class RerankingModelConfig(BaseModel): provider: str model: str + class VectorSetting(BaseModel): """ Vector Setting. @@ -32,6 +31,7 @@ class KeywordSetting(BaseModel): keyword_weight: float + class WeightedScoreConfig(BaseModel): """ Weighted score Config. @@ -45,6 +45,7 @@ class EmbeddingSetting(BaseModel): """ Embedding Setting. """ + embedding_provider_name: str embedding_model_name: str @@ -61,6 +62,7 @@ class RetrievalSetting(BaseModel): """ Retrieval Setting. """ + search_method: Literal["semantic_search", "keyword_search", "hybrid_search"] top_k: int score_threshold: Optional[float] = 0.5 @@ -70,49 +72,61 @@ class RetrievalSetting(BaseModel): reranking_model: Optional[RerankingModelConfig] = None weights: Optional[WeightedScoreConfig] = None + class IndexMethod(BaseModel): """ Knowledge Index Setting. """ + indexing_technique: Literal["high_quality", "economy"] embedding_setting: EmbeddingSetting economy_setting: EconomySetting + class FileInfo(BaseModel): """ File Info. """ + file_id: str + class OnlineDocumentIcon(BaseModel): """ Document Icon. """ + icon_url: str icon_type: str icon_emoji: str + class OnlineDocumentInfo(BaseModel): """ Online document info. """ + provider: str workspace_id: str page_id: str page_type: str icon: OnlineDocumentIcon + class WebsiteInfo(BaseModel): """ website import info. """ - provider: str + + provider: str url: str + class GeneralStructureChunk(BaseModel): """ General Structure Chunk. """ + general_chunk: list[str] data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo] @@ -121,14 +135,16 @@ class ParentChildChunk(BaseModel): """ Parent Child Chunk. """ + parent_content: str - child_content: list[str] + child_contents: list[str] class ParentChildStructureChunk(BaseModel): """ Parent Child Structure Chunk. """ + parent_child_chunks: list[ParentChildChunk] data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo] @@ -138,10 +154,10 @@ class KnowledgeIndexNodeData(BaseNodeData): Knowledge index Node Data. 
""" - type: str = "knowledge-index" + type: str = "knowledge-index" dataset_id: str + document_id: str index_chunk_variable_selector: list[str] chunk_structure: Literal["general", "parent-child"] index_method: IndexMethod retrieval_setting: RetrievalSetting - diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 7aa6b9379f..5f9ac78097 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,60 +1,22 @@ -import json import logging -import re import time -from collections import defaultdict -from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast -from sqlalchemy import Integer, and_, func, or_, text -from sqlalchemy import cast as sqlalchemy_cast - -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.agent_entities import PlanningStrategy -from core.entities.model_entities import ModelStatus -from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.model_runtime.entities.model_entities import ModelFeature, ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.simple_prompt_transform import ModelMode -from core.rag.datasource.retrieval_service import RetrievalService -from core.rag.entities.metadata_entities import Condition, MetadataCondition -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.variables import StringSegment from core.variables.segments import ObjectSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event.event import ModelInvokeCompletedEvent -from core.workflow.nodes.knowledge_retrieval.template_prompts import ( - METADATA_FILTER_ASSISTANT_PROMPT_1, - METADATA_FILTER_ASSISTANT_PROMPT_2, - METADATA_FILTER_COMPLETION_PROMPT, - METADATA_FILTER_SYSTEM_PROMPT, - METADATA_FILTER_USER_PROMPT_1, - METADATA_FILTER_USER_PROMPT_3, -) -from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate from core.workflow.nodes.llm.node import LLMNode -from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2 from extensions.ext_database import db from extensions.ext_redis import redis_client -from libs.json_in_md_parser import parse_and_check_json_markdown -from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog +from models.dataset import Dataset, Document, RateLimitLog from models.workflow import WorkflowNodeExecutionStatus -from services.dataset_service import DatasetService, DocumentService +from services.dataset_service import DocumentService from services.feature_service import FeatureService -from .entities import KnowledgeIndexNodeData, KnowledgeRetrievalNodeData, ModelConfig +from .entities import KnowledgeIndexNodeData from .exc import ( - InvalidModelTypeError, KnowledgeIndexNodeError, - KnowledgeRetrievalNodeError, - ModelCredentialsNotInitializedError, - ModelNotExistError, - ModelNotSupportedError, - ModelQuotaExceededError, ) logger = logging.getLogger(__name__) @@ 
-138,16 +100,15 @@ class KnowledgeIndexNode(LLMNode): error_type=type(e).__name__, ) - - def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, document_id: str, chunks: list[any]) -> Any: + def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: list[Any]) -> Any: dataset = Dataset.query.filter_by(id=node_data.dataset_id).first() if not dataset: raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.") - - document = Document.query.filter_by(id=document_id).first() + + document = Document.query.filter_by(id=node_data.document_id).first() if not document: - raise KnowledgeIndexNodeError(f"Document {document_id} not found.") - + raise KnowledgeIndexNodeError(f"Document {node_data.document_id} not found.") + DocumentService.invoke_knowledge_index( dataset=dataset, document=document, @@ -156,5 +117,12 @@ class KnowledgeIndexNode(LLMNode): index_method=node_data.index_method, retrieval_setting=node_data.retrieval_setting, ) - - pass + + return { + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "created_at": document.created_at, + "display_status": document.indexing_status, + } diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 60d3ccf131..af1c1028cf 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1441,11 +1441,7 @@ class DocumentService: index_method: IndexMethod, retrieval_setting: RetrievalSetting, chunk_structure: Literal["text_model", "hierarchical_model"], - original_document_id: str | None = None, - account: Account | Any, - created_from: str = "rag-pipline", ): - if not dataset.indexing_technique: if index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: raise ValueError("Indexing technique is invalid") @@ -1453,7 +1449,10 @@ class DocumentService: dataset.indexing_technique = index_method.indexing_technique if index_method.indexing_technique == "high_quality": model_manager = ModelManager() - if index_method.embedding_setting.embedding_model and index_method.embedding_setting.embedding_model_provider: + if ( + index_method.embedding_setting.embedding_model + and index_method.embedding_setting.embedding_model_provider + ): dataset_embedding_model = index_method.embedding_setting.embedding_model dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider else: @@ -1478,15 +1477,16 @@ class DocumentService: } dataset.retrieval_model = ( - retrieval_setting.model_dump() - if retrieval_setting - else default_retrieval_model + retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model ) # type: ignore index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() index_processor.index(dataset, document, chunks) - return documents, batch - + # update document status + document.indexing_status = "completed" + document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.commit() + @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size From 818eb46a8b02908b9310ac7208836bab994977d1 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 15 May 2025 15:14:52 +0800 Subject: [PATCH 014/155] r2 --- .../rag_pipeline/rag_pipeline_datasets.py | 170 ++++ .../rag_pipeline/rag_pipeline_import.py | 147 +++ .../rag_pipeline/rag_pipeline_workflow.py | 159 +++- 
.../datasource/__base/datasource_plugin.py | 6 +- .../datasource/__base/datasource_provider.py | 18 +- api/core/datasource/entities/api_entities.py | 73 ++ api/core/plugin/entities/plugin_daemon.py | 8 + api/core/plugin/impl/tool.py | 67 +- api/core/tools/tool_manager.py | 26 + .../nodes/knowledge_index/entities.py | 4 +- api/fields/dataset_fields.py | 6 + api/fields/rag_pipeline_fields.py | 163 ++++ api/models/dataset.py | 4 + api/models/tools.py | 34 + api/services/dataset_service.py | 59 ++ .../rag_pipeline_entities.py | 99 ++- api/services/rag_pipeline/rag_pipeline.py | 207 +++-- .../rag_pipeline/rag_pipeline_dsl_service.py | 841 ++++++++++++++++++ .../tools/builtin_tools_manage_service.py | 64 +- api/services/tools/tools_transform_service.py | 110 ++- api/services/workflow_service.py | 1 - 21 files changed, 2117 insertions(+), 149 deletions(-) create mode 100644 api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py create mode 100644 api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py create mode 100644 api/core/datasource/entities/api_entities.py create mode 100644 api/fields/rag_pipeline_fields.py create mode 100644 api/services/rag_pipeline/rag_pipeline_dsl_service.py diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py new file mode 100644 index 0000000000..6676deb63a --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -0,0 +1,170 @@ +from flask_login import current_user # type: ignore # type: ignore +from flask_restful import Resource, marshal, reqparse # type: ignore +from werkzeug.exceptions import Forbidden + +import services +from controllers.console import api +from controllers.console.datasets.error import DatasetNameDuplicateError +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_rate_limit_check, + setup_required, +) +from fields.dataset_fields import dataset_detail_fields +from libs.login import login_required +from models.dataset import DatasetPermissionEnum +from services.dataset_service import DatasetPermissionService, DatasetService +from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity + + +def _validate_name(name): + if not name or len(name) < 1 or len(name) > 40: + raise ValueError("Name must be between 1 to 40 characters.") + return name + + +def _validate_description_length(description): + if len(description) > 400: + raise ValueError("Description cannot exceed 400 characters.") + return description + + +class CreateRagPipelineDatasetApi(Resource): + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") + def post(self): + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. 
Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) + + parser.add_argument( + "icon_info", + type=dict, + nullable=True, + required=False, + default={}, + ) + + parser.add_argument( + "permission", + type=str, + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + nullable=True, + required=False, + default=DatasetPermissionEnum.ONLY_ME, + ) + + parser.add_argument( + "partial_member_list", + type=list, + nullable=True, + required=False, + default=[], + ) + + parser.add_argument( + "yaml_content", + type=str, + nullable=False, + required=True, + help="yaml_content is required.", + ) + + args = parser.parse_args() + + # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator + if not current_user.is_dataset_editor: + raise Forbidden() + rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args) + try: + import_info = DatasetService.create_rag_pipeline_dataset( + tenant_id=current_user.current_tenant_id, + rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, + ) + if rag_pipeline_dataset_create_entity.permission == "partial_members": + DatasetPermissionService.update_partial_member_list( + current_user.current_tenant_id, + import_info["dataset_id"], + rag_pipeline_dataset_create_entity.partial_member_list, + ) + except services.errors.dataset.DatasetNameDuplicateError: + raise DatasetNameDuplicateError() + + return import_info, 201 + + +class CreateEmptyRagPipelineDatasetApi(Resource): + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") + def post(self): + # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator + if not current_user.is_dataset_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. 
Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) + + parser.add_argument( + "icon_info", + type=dict, + nullable=True, + required=False, + default={}, + ) + + parser.add_argument( + "permission", + type=str, + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + nullable=True, + required=False, + default=DatasetPermissionEnum.ONLY_ME, + ) + + parser.add_argument( + "partial_member_list", + type=list, + nullable=True, + required=False, + default=[], + ) + + args = parser.parse_args() + dataset = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=current_user.current_tenant_id, + rag_pipeline_dataset_create_entity=args, + ) + return marshal(dataset, dataset_detail_fields), 201 + + +api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset") +api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py new file mode 100644 index 0000000000..853aef2e09 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -0,0 +1,147 @@ +from typing import cast + +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.datasets.wraps import get_rag_pipeline +from controllers.console.wraps import ( + account_initialization_required, + setup_required, +) +from extensions.ext_database import db +from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields +from libs.login import login_required +from models import Account +from models.dataset import Pipeline +from services.app_dsl_service import ImportStatus +from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService + + +class RagPipelineImportApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(pipeline_import_fields) + def post(self): + # Check user role first + if not current_user.is_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("mode", type=str, required=True, location="json") + parser.add_argument("yaml_content", type=str, location="json") + parser.add_argument("yaml_url", type=str, location="json") + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") + parser.add_argument("pipeline_id", type=str, location="json") + args = parser.parse_args() + + # Create service with session + with Session(db.engine) as session: + import_service = RagPipelineDslService(session) + # Import app + account = cast(Account, current_user) + result = import_service.import_rag_pipeline( + account=account, + import_mode=args["mode"], + yaml_content=args.get("yaml_content"), + yaml_url=args.get("yaml_url"), + pipeline_id=args.get("pipeline_id"), + ) + session.commit() + + # Return appropriate status code based on result + status = result.status + if status == 
ImportStatus.FAILED.value: + return result.model_dump(mode="json"), 400 + elif status == ImportStatus.PENDING.value: + return result.model_dump(mode="json"), 202 + return result.model_dump(mode="json"), 200 + + +class RagPipelineImportConfirmApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(pipeline_import_fields) + def post(self, import_id): + # Check user role first + if not current_user.is_editor: + raise Forbidden() + + # Create service with session + with Session(db.engine) as session: + import_service = RagPipelineDslService(session) + # Confirm import + account = cast(Account, current_user) + result = import_service.confirm_import(import_id=import_id, account=account) + session.commit() + + # Return appropriate status code based on result + if result.status == ImportStatus.FAILED.value: + return result.model_dump(mode="json"), 400 + return result.model_dump(mode="json"), 200 + + +class RagPipelineImportCheckDependenciesApi(Resource): + @setup_required + @login_required + @get_rag_pipeline + @account_initialization_required + @marshal_with(pipeline_import_check_dependencies_fields) + def get(self, pipeline: Pipeline): + if not current_user.is_editor: + raise Forbidden() + + with Session(db.engine) as session: + import_service = RagPipelineDslService(session) + result = import_service.check_dependencies(pipeline=pipeline) + + return result.model_dump(mode="json"), 200 + + +class RagPipelineExportApi(Resource): + @setup_required + @login_required + @get_rag_pipeline + @account_initialization_required + @marshal_with(pipeline_import_check_dependencies_fields) + def get(self, pipeline: Pipeline): + if not current_user.is_editor: + raise Forbidden() + + # Add include_secret params + parser = reqparse.RequestParser() + parser.add_argument("include_secret", type=bool, default=False, location="args") + args = parser.parse_args() + + with Session(db.engine) as session: + export_service = RagPipelineDslService(session) + result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"]) + + return {"data": result}, 200 + + +# Import Rag Pipeline +api.add_resource( + RagPipelineImportApi, + "/rag/pipelines/imports", +) +api.add_resource( + RagPipelineImportConfirmApi, + "/rag/pipelines/imports//confirm", +) +api.add_resource( + RagPipelineImportCheckDependenciesApi, + "/rag/pipelines/imports//check-dependencies", +) +api.add_resource( + RagPipelineExportApi, + "/rag/pipelines//exports", +) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index e67b3c0657..99d3b73d33 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -4,6 +4,7 @@ from typing import cast from flask import abort, request from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -23,12 +24,18 @@ from controllers.console.wraps import ( from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom +from core.model_runtime.utils.encoders import jsonable_encoder from 
extensions.ext_database import db from factories import variable_factory from fields.workflow_fields import workflow_fields, workflow_pagination_fields -from fields.workflow_run_fields import workflow_run_node_execution_fields +from fields.workflow_run_fields import ( + workflow_run_detail_fields, + workflow_run_node_execution_fields, + workflow_run_node_execution_list_fields, + workflow_run_pagination_fields, +) from libs import helper -from libs.helper import TimestampField +from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.account import Account from models.dataset import Pipeline @@ -36,6 +43,7 @@ from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.rag_pipeline import RagPipelineService +from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError logger = logging.getLogger(__name__) @@ -461,45 +469,6 @@ class DefaultRagPipelineBlockConfigApi(Resource): rag_pipeline_service = RagPipelineService() return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) - -class ConvertToRagPipelineApi(Resource): - @setup_required - @login_required - @account_initialization_required - @get_rag_pipeline - def post(self, pipeline: Pipeline): - """ - Convert basic mode of chatbot app to workflow mode - Convert expert mode of chatbot app to workflow mode - Convert Completion App to Workflow App - """ - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): - raise Forbidden() - - if request.data: - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=False, nullable=True, location="json") - parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") - parser.add_argument("icon", type=str, required=False, nullable=True, location="json") - parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") - args = parser.parse_args() - else: - args = {} - - # convert to workflow mode - rag_pipeline_service = RagPipelineService() - new_app_model = rag_pipeline_service.convert_to_workflow(pipeline=pipeline, account=current_user, args=args) - - # return app id - return { - "new_app_id": new_app_model.id, - } - - class RagPipelineConfigApi(Resource): """Resource for rag pipeline configuration.""" @@ -674,6 +643,85 @@ class RagPipelineSecondStepApi(Resource): ) +class RagPipelineWorkflowRunListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_run_pagination_fields) + def get(self, pipeline: Pipeline): + """ + Get workflow run list + """ + parser = reqparse.RequestParser() + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + args = parser.parse_args() + + rag_pipeline_service = RagPipelineService() + result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args) + + return result + + +class RagPipelineWorkflowRunDetailApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + 
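+    # get_rag_pipeline resolves the pipeline id in the URL into the Pipeline object passed below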
@marshal_with(workflow_run_detail_fields) + def get(self, pipeline: Pipeline, run_id): + """ + Get workflow run detail + """ + run_id = str(run_id) + + rag_pipeline_service = RagPipelineService() + workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id) + + return workflow_run + + +class RagPipelineWorkflowRunNodeExecutionListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_run_node_execution_list_fields) + def get(self, pipeline: Pipeline, run_id): + """ + Get workflow run node execution list + """ + run_id = str(run_id) + + rag_pipeline_service = RagPipelineService() + node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions( + pipeline=pipeline, + run_id=run_id, + ) + + return {"data": node_executions} + + +class DatasourceListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user = current_user + + tenant_id = user.current_tenant_id + + return jsonable_encoder( + [ + provider.to_dict() + for provider in BuiltinToolManageService.list_rag_pipeline_datasources( + tenant_id, + ) + ] + ) + + api.add_resource( DraftRagPipelineApi, "/rag/pipelines//workflows/draft", @@ -694,10 +742,10 @@ api.add_resource( RagPipelineDraftNodeRunApi, "/rag/pipelines//workflows/draft/nodes//run", ) -api.add_resource( - RagPipelinePublishedNodeRunApi, - "/rag/pipelines//workflows/published/nodes//run", -) +# api.add_resource( +# RagPipelinePublishedNodeRunApi, +# "/rag/pipelines//workflows/published/nodes//run", +# ) api.add_resource( RagPipelineDraftRunIterationNodeApi, @@ -724,11 +772,24 @@ api.add_resource( DefaultRagPipelineBlockConfigApi, "/rag/pipelines//workflows/default-workflow-block-configs/", ) -api.add_resource( - ConvertToRagPipelineApi, - "/rag/pipelines//convert-to-workflow", -) + api.add_resource( RagPipelineByIdApi, "/rag/pipelines//workflows/", ) +api.add_resource( + RagPipelineWorkflowRunListApi, + "/rag/pipelines//workflow-runs", +) +api.add_resource( + RagPipelineWorkflowRunDetailApi, + "/rag/pipelines//workflow-runs/", +) +api.add_resource( + RagPipelineWorkflowRunNodeExecutionListApi, + "/rag/pipelines//workflow-runs//node-executions", +) +api.add_resource( + DatasourceListApi, + "/rag/pipelines/datasources", +) diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index 991bceb422..86bd66a3f9 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -2,13 +2,13 @@ from collections.abc import Generator from typing import Any, Optional from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import ( DatasourceEntity, DatasourceInvokeMessage, DatasourceParameter, DatasourceProviderType, ) -from core.plugin.manager.datasource import PluginDatasourceManager from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -44,7 +44,7 @@ class DatasourcePlugin: datasource_parameters: dict[str, Any], rag_pipeline_id: Optional[str] = None, ) -> Generator[DatasourceInvokeMessage, None, None]: - manager = PluginDatasourceManager() + manager = DatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) @@ -64,7 +64,7 @@ class DatasourcePlugin: datasource_parameters: dict[str, Any], 
rag_pipeline_id: Optional[str] = None, ) -> Generator[DatasourceInvokeMessage, None, None]: - manager = PluginDatasourceManager() + manager = DatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index e9efb7b9dc..ef3382b948 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -4,12 +4,11 @@ from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType from core.entities.provider_entities import ProviderConfig -from core.plugin.manager.tool import PluginToolManager -from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.plugin.impl.tool import PluginToolManager from core.tools.errors import ToolProviderCredentialValidationError -class DatasourcePluginProviderController(BuiltinToolProviderController): +class DatasourcePluginProviderController: entity: DatasourceProviderEntityWithPlugin tenant_id: str plugin_id: str @@ -32,12 +31,21 @@ class DatasourcePluginProviderController(BuiltinToolProviderController): """ return DatasourceProviderType.RAG_PIPELINE + @property + def need_credentials(self) -> bool: + """ + returns whether the provider needs credentials + + :return: whether the provider needs credentials + """ + return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0 + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider """ manager = PluginToolManager() - if not manager.validate_provider_credentials( + if not manager.validate_datasource_credentials( tenant_id=self.tenant_id, user_id=user_id, provider=self.entity.identity.name, @@ -69,7 +77,7 @@ class DatasourcePluginProviderController(BuiltinToolProviderController): plugin_unique_identifier=self.plugin_unique_identifier, ) - def get_datasources(self) -> list[DatasourceTool]: # type: ignore + def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore """ get all datasources """ diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py new file mode 100644 index 0000000000..2d42484a30 --- /dev/null +++ b/api/core/datasource/entities/api_entities.py @@ -0,0 +1,73 @@ +from typing import Literal, Optional + +from pydantic import BaseModel, Field, field_validator + +from core.datasource.entities.datasource_entities import DatasourceParameter +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.__base.tool import ToolParameter +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolProviderType + + +class DatasourceApiEntity(BaseModel): + author: str + name: str # identifier + label: I18nObject # label + description: I18nObject + parameters: Optional[list[ToolParameter]] = None + labels: list[str] = Field(default_factory=list) + output_schema: Optional[dict] = None + + +ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] + + +class DatasourceProviderApiEntity(BaseModel): + id: str + author: str + name: str # identifier + description: I18nObject + icon: str | dict + label: I18nObject # label + type: ToolProviderType + 
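+    # masked_credentials is what to_dict() exposes as "team_credentials"; original_credentials is not serialized there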
masked_credentials: Optional[dict] = None + original_credentials: Optional[dict] = None + is_team_authorization: bool = False + allow_delete: bool = True + plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool") + plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") + datasources: list[DatasourceApiEntity] = Field(default_factory=list) + labels: list[str] = Field(default_factory=list) + + @field_validator("datasources", mode="before") + @classmethod + def convert_none_to_empty_list(cls, v): + return v if v is not None else [] + + def to_dict(self) -> dict: + # ------------- + # overwrite datasource parameter types for temp fix + datasources = jsonable_encoder(self.datasources) + for datasource in datasources: + if datasource.get("parameters"): + for parameter in datasource.get("parameters"): + if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value: + parameter["type"] = "files" + # ------------- + + return { + "id": self.id, + "author": self.author, + "name": self.name, + "plugin_id": self.plugin_id, + "plugin_unique_identifier": self.plugin_unique_identifier, + "description": self.description.to_dict(), + "icon": self.icon, + "label": self.label.to_dict(), + "type": self.type.value, + "team_credentials": self.masked_credentials, + "is_team_authorization": self.is_team_authorization, + "allow_delete": self.allow_delete, + "datasources": datasources, + "labels": self.labels, + } diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 1588cbc3c7..40e753671c 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -5,6 +5,7 @@ from typing import Generic, Optional, TypeVar from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.base import BasePluginEntity @@ -46,6 +47,13 @@ class PluginToolProviderEntity(BaseModel): declaration: ToolProviderEntityWithPlugin +class PluginDatasourceProviderEntity(BaseModel): + provider: str + plugin_unique_identifier: str + plugin_id: str + declaration: DatasourceProviderEntityWithPlugin + + class PluginAgentProviderEntity(BaseModel): provider: str plugin_unique_identifier: str diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 19b26c8fe3..f4360a70de 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -4,7 +4,11 @@ from typing import Any, Optional from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID -from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity +from core.plugin.entities.plugin_daemon import ( + PluginBasicBooleanResponse, + PluginDatasourceProviderEntity, + PluginToolProviderEntity, +) from core.plugin.impl.base import BasePluginClient from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter @@ -41,6 +45,37 @@ class PluginToolManager(BasePluginClient): return response + def fetch_datasources(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]: + """ + Fetch datasources for the given tenant. 
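+        Provider and datasource identities are rewritten to the "plugin_id/provider_name" form before returning.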
+ """ + + def transformer(json_response: dict[str, Any]) -> dict: + for provider in json_response.get("data", []): + declaration = provider.get("declaration", {}) or {} + provider_name = declaration.get("identity", {}).get("name") + for tool in declaration.get("tools", []): + tool["identity"]["provider"] = provider_name + + return json_response + + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/datasources", + list[PluginToolProviderEntity], + params={"page": 1, "page_size": 256}, + transformer=transformer, + ) + + for provider in response: + provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" + + # override the provider name for each tool to plugin_id/provider_name + for tool in provider.declaration.tools: + tool.identity.provider = provider.declaration.identity.name + + return response + def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: """ Fetch tool provider for the given tenant and plugin. @@ -197,6 +232,36 @@ class PluginToolManager(BasePluginClient): return False + def validate_datasource_credentials( + self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] + ) -> bool: + """ + validate the credentials of the datasource + """ + tool_provider_id = GenericProviderID(provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/validate_credentials", + PluginBasicBooleanResponse, + data={ + "user_id": user_id, + "data": { + "provider": tool_provider_id.provider_name, + "credentials": credentials, + }, + }, + headers={ + "X-Plugin-ID": tool_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp.result + + return False + def get_runtime_parameters( self, tenant_id: str, diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index aa2661fe63..682a32d26f 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast from yarl import URL import contexts +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController @@ -495,6 +496,31 @@ class ToolManager: # get plugin providers yield from cls.list_plugin_providers(tenant_id) + @classmethod + def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]: + """ + list all the datasource providers + """ + manager = PluginToolManager() + provider_entities = manager.fetch_datasources(tenant_id) + return [ + DatasourcePluginProviderController( + entity=provider.declaration, + plugin_id=provider.plugin_id, + plugin_unique_identifier=provider.plugin_unique_identifier, + tenant_id=tenant_id, + ) + for provider in provider_entities + ] + + @classmethod + def list_builtin_datasources(cls, tenant_id: str) -> Generator[DatasourcePluginProviderController, None, None]: + """ + list all the builtin datasources + """ + # get builtin datasources + yield from cls.list_datasource_providers(tenant_id) + @classmethod def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: """ diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 
635748799b..05661a6cc8 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -10,8 +10,8 @@ class RerankingModelConfig(BaseModel): Reranking Model Config. """ - provider: str - model: str + reranking_provider_name: str + reranking_model_name: str class VectorSetting(BaseModel): diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 67d183c70d..9d34734af7 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -56,6 +56,8 @@ external_knowledge_info_fields = { doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String} +icon_info_fields = {"icon_type": fields.String, "icon": fields.String, "icon_background": fields.String} + dataset_detail_fields = { "id": fields.String, "name": fields.String, @@ -81,6 +83,10 @@ dataset_detail_fields = { "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), "doc_metadata": fields.List(fields.Nested(doc_metadata_fields)), "built_in_field_enabled": fields.Boolean, + "pipeline_id": fields.String, + "runtime_mode": fields.String, + "chunk_structure": fields.String, + "icon_info": fields.Nested(icon_info_fields), } dataset_query_detail_fields = { diff --git a/api/fields/rag_pipeline_fields.py b/api/fields/rag_pipeline_fields.py new file mode 100644 index 0000000000..0bb74e3259 --- /dev/null +++ b/api/fields/rag_pipeline_fields.py @@ -0,0 +1,163 @@ +from flask_restful import fields # type: ignore + +from fields.workflow_fields import workflow_partial_fields +from libs.helper import AppIconUrlField, TimestampField + +pipeline_detail_kernel_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, +} + +related_app_list = { + "data": fields.List(fields.Nested(pipeline_detail_kernel_fields)), + "total": fields.Integer, +} + +app_detail_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon": fields.String, + "icon_background": fields.String, + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "tracing": fields.Raw, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, +} + + +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} + +app_partial_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String(attribute="desc_or_prompt"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, + "tags": fields.List(fields.Nested(tag_fields)), +} + + +app_pagination_fields = { + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(app_partial_fields), attribute="items"), +} + +template_fields = { + "name": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "mode": fields.String, +} + +template_list_fields = { + "data": 
fields.List(fields.Nested(template_fields)), +} + +site_fields = { + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "description": fields.String, + "default_language": fields.String, + "chat_color_theme": fields.String, + "chat_color_theme_inverted": fields.Boolean, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "app_base_url": fields.String, + "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, +} + +deleted_tool_fields = { + "type": fields.String, + "tool_name": fields.String, + "provider_id": fields.String, +} + +app_detail_fields_with_site = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "enable_site": fields.Boolean, + "enable_api": fields.Boolean, + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "site": fields.Nested(site_fields), + "api_base_url": fields.String, + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, +} + + +app_site_fields = { + "app_id": fields.String, + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "default_language": fields.String, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, +} + +leaked_dependency_fields = {"type": fields.String, "value": fields.Raw, "current_identifier": fields.String} + +pipeline_import_fields = { + "id": fields.String, + "status": fields.String, + "pipeline_id": fields.String, + "current_dsl_version": fields.String, + "imported_dsl_version": fields.String, + "error": fields.String, +} + +pipeline_import_check_dependencies_fields = { + "leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)), +} diff --git a/api/models/dataset.py b/api/models/dataset.py index 3c44fb4b45..6d23973bba 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -63,6 +63,10 @@ class Dataset(db.Model): # type: ignore[name-defined] collection_binding_id = db.Column(StringUUID, nullable=True) retrieval_model = db.Column(JSONB, nullable=True) built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + icon_info = db.Column(JSONB, nullable=True) + runtime_mode = db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying")) + pipeline_id = db.Column(StringUUID, nullable=True) + chunk_structure = db.Column(db.String(255), nullable=True) @property def dataset_keyword_table(self): diff --git a/api/models/tools.py 
b/api/models/tools.py index aef1490729..6d08ba61aa 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -51,6 +51,40 @@ class BuiltinToolProvider(Base): return cast(dict, json.loads(self.encrypted_credentials)) +class BuiltinDatasourceProvider(Base): + """ + This table stores the datasource provider information for built-in datasources for each tenant. + """ + + __tablename__ = "tool_builtin_datasource_providers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_builtin_datasource_provider_pkey"), + # one tenant can only have one tool provider with the same name + db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_datasource_provider"), + ) + + # id of the tool provider + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # id of the tenant + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) + # who created this tool provider + user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # name of the tool provider + provider: Mapped[str] = mapped_column(db.String(256), nullable=False) + # credential of the tool provider + encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) + created_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + updated_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + + @property + def credentials(self) -> dict: + return cast(dict, json.loads(self.encrypted_credentials)) + + class ApiToolProvider(Base): """ The table stores the api providers. diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index af1c1028cf..42748dbf96 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -52,6 +52,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( RetrievalModel, SegmentUpdateArgs, ) +from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity from services.errors.account import InvalidActionError, NoPermissionError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError @@ -59,6 +60,7 @@ from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureModel, FeatureService +from services.rag_pipeline.rag_pipeline_dsl_service import ImportMode, RagPipelineDslService, RagPipelineImportInfo from services.tag_service import TagService from services.vector_service import VectorService from tasks.batch_clean_document_task import batch_clean_document_task @@ -235,6 +237,63 @@ class DatasetService: db.session.commit() return dataset + @staticmethod + def create_empty_rag_pipeline_dataset( + tenant_id: str, + rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, + ): + # check if dataset name already exists + if Dataset.query.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first(): + raise DatasetNameDuplicateError( + f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." 
+ ) + + dataset = Dataset( + name=rag_pipeline_dataset_create_entity.name, + description=rag_pipeline_dataset_create_entity.description, + permission=rag_pipeline_dataset_create_entity.permission, + provider="vendor", + runtime_mode="rag_pipeline", + icon_info=rag_pipeline_dataset_create_entity.icon_info, + ) + db.session.add(dataset) + db.session.commit() + return dataset + + @staticmethod + def create_rag_pipeline_dataset( + tenant_id: str, + rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, + ): + # check if dataset name already exists + if Dataset.query.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first(): + raise DatasetNameDuplicateError( + f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." + ) + + dataset = Dataset( + name=rag_pipeline_dataset_create_entity.name, + description=rag_pipeline_dataset_create_entity.description, + permission=rag_pipeline_dataset_create_entity.permission, + provider="vendor", + runtime_mode="rag_pipeline", + icon_info=rag_pipeline_dataset_create_entity.icon_info, + ) + + if rag_pipeline_dataset_create_entity.yaml_content: + rag_pipeline_import_info: RagPipelineImportInfo = RagPipelineDslService.import_rag_pipeline( + current_user, ImportMode.YAML_CONTENT, rag_pipeline_dataset_create_entity.yaml_content, dataset + ) + return { + "id": rag_pipeline_import_info.id, + "dataset_id": dataset.id, + "pipeline_id": rag_pipeline_import_info.pipeline_id, + "status": rag_pipeline_import_info.status, + "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version, + "current_dsl_version": rag_pipeline_import_info.current_dsl_version, + "error": rag_pipeline_import_info.error, + } + @staticmethod def get_dataset(dataset_id) -> Optional[Dataset]: dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index d59d47bbce..5f581f1360 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional from pydantic import BaseModel @@ -14,3 +14,100 @@ class PipelineTemplateInfoEntity(BaseModel): name: str description: str icon_info: IconInfo + + +class RagPipelineDatasetCreateEntity(BaseModel): + name: str + description: str + icon_info: IconInfo + permission: str + partial_member_list: list[str] + yaml_content: str + + +class RerankingModelConfig(BaseModel): + """ + Reranking Model Config. + """ + + reranking_provider_name: str + reranking_model_name: str + + +class VectorSetting(BaseModel): + """ + Vector Setting. + """ + + vector_weight: float + embedding_provider_name: str + embedding_model_name: str + + +class KeywordSetting(BaseModel): + """ + Keyword Setting. + """ + + keyword_weight: float + + +class WeightedScoreConfig(BaseModel): + """ + Weighted score Config. + """ + + vector_setting: VectorSetting + keyword_setting: KeywordSetting + + +class EmbeddingSetting(BaseModel): + """ + Embedding Setting. + """ + + embedding_provider_name: str + embedding_model_name: str + + +class EconomySetting(BaseModel): + """ + Economy Setting. + """ + + keyword_number: int + + +class RetrievalSetting(BaseModel): + """ + Retrieval Setting. 
+ """ + + search_method: Literal["semantic_search", "keyword_search", "hybrid_search"] + top_k: int + score_threshold: Optional[float] = 0.5 + score_threshold_enabled: bool = False + reranking_mode: str = "reranking_model" + reranking_enable: bool = True + reranking_model: Optional[RerankingModelConfig] = None + weights: Optional[WeightedScoreConfig] = None + + +class IndexMethod(BaseModel): + """ + Knowledge Index Setting. + """ + + indexing_technique: Literal["high_quality", "economy"] + embedding_setting: EmbeddingSetting + economy_setting: EconomySetting + + +class KnowledgeConfiguration(BaseModel): + """ + Knowledge Configuration. + """ + + chunk_structure: str + index_method: IndexMethod + retrieval_setting: RetrievalSetting diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 79f6e79cf5..1e6447d80f 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1,28 +1,45 @@ import json +import threading import time from collections.abc import Callable, Generator, Sequence from datetime import UTC, datetime from typing import Any, Literal, Optional +from uuid import uuid4 from flask_login import current_user from sqlalchemy import select from sqlalchemy.orm import Session +import contexts from configs import dify_config +from core.model_runtime.utils.encoders import jsonable_encoder +from core.repository.repository_factory import RepositoryFactory +from core.repository.workflow_node_execution_repository import OrderConfig from core.variables.variables import Variable +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base.node import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.event.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore -from models.workflow import Workflow, WorkflowNodeExecution, WorkflowType +from models.enums import CreatedByRole, WorkflowRunTriggeredFrom +from models.workflow import ( + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, + WorkflowType, +) from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.errors.app import WorkflowHashNotEqualError -from services.errors.workflow_service import DraftWorkflowDeletionError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory @@ -180,7 +197,6 @@ class RagPipelineService: *, pipeline: Pipeline, graph: dict, - features: dict, unique_hash: Optional[str], account: Account, environment_variables: Sequence[Variable], @@ -197,9 +213,6 @@ class RagPipelineService: if workflow and workflow.unique_hash != unique_hash: raise WorkflowHashNotEqualError() - # validate features structure - self.validate_features_structure(pipeline=pipeline, features=features) - # create draft 
workflow if not found if not workflow: workflow = Workflow( @@ -208,7 +221,6 @@ class RagPipelineService: type=WorkflowType.RAG_PIPELINE.value, version="draft", graph=json.dumps(graph), - features=json.dumps(features), created_by=account.id, environment_variables=environment_variables, conversation_variables=conversation_variables, @@ -218,7 +230,6 @@ class RagPipelineService: # update draft workflow if found else: workflow.graph = json.dumps(graph) - workflow.features = json.dumps(features) workflow.updated_by = account.id workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) workflow.environment_variables = environment_variables @@ -227,8 +238,8 @@ class RagPipelineService: # commit db session changes db.session.commit() - # trigger app workflow events - app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow) + # trigger workflow events TODO + # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow) # return draft workflow return workflow @@ -269,8 +280,8 @@ class RagPipelineService: # commit db session changes session.add(workflow) - # trigger app workflow events - app_published_workflow_was_updated.send(pipeline, published_workflow=workflow) + # trigger app workflow events TODO + # app_published_workflow_was_updated.send(pipeline, published_workflow=workflow) # return new workflow return workflow @@ -508,46 +519,6 @@ class RagPipelineService: return workflow_node_execution - def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: - """ - Basic mode of chatbot app(expert mode) to workflow - Completion App to Workflow App - - :param app_model: App instance - :param account: Account instance - :param args: dict - :return: - """ - # chatbot convert to workflow mode - workflow_converter = WorkflowConverter() - - if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: - raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") - - # convert to workflow - new_app: App = workflow_converter.convert_to_workflow( - app_model=app_model, - account=account, - name=args.get("name", "Default Name"), - icon_type=args.get("icon_type", "emoji"), - icon=args.get("icon", "🤖"), - icon_background=args.get("icon_background", "#FFEAD5"), - ) - - return new_app - - def validate_features_structure(self, app_model: App, features: dict) -> dict: - if app_model.mode == AppMode.ADVANCED_CHAT.value: - return AdvancedChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=features, only_structure_validate=True - ) - elif app_model.mode == AppMode.WORKFLOW.value: - return WorkflowAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=features, only_structure_validate=True - ) - else: - raise ValueError(f"Invalid app mode: {app_model.mode}") - def update_workflow( self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict ) -> Optional[Workflow]: @@ -578,38 +549,6 @@ class RagPipelineService: return workflow - def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool: - """ - Delete a workflow - - :param session: SQLAlchemy database session - :param workflow_id: Workflow ID - :param tenant_id: Tenant ID - :return: True if successful - :raises: ValueError if workflow not found - :raises: WorkflowInUseError if workflow is in use - :raises: DraftWorkflowDeletionError if workflow is a draft version - """ - stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) 
- workflow = session.scalar(stmt) - - if not workflow: - raise ValueError(f"Workflow with ID {workflow_id} not found") - - # Check if workflow is a draft version - if workflow.version == "draft": - raise DraftWorkflowDeletionError("Cannot delete draft workflow versions") - - # Check if this workflow is currently referenced by an app - stmt = select(App).where(App.workflow_id == workflow_id) - app = session.scalar(stmt) - if app: - # Cannot delete a workflow that's currently in use by an app - raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'") - - session.delete(workflow) - return True - def get_second_step_parameters(self, pipeline: Pipeline, datasource_provider: str) -> dict: """ Get second step parameters of rag pipeline @@ -627,3 +566,101 @@ class RagPipelineService: datasource_provider_variables = pipeline_variables.get(datasource_provider, []) shared_variables = pipeline_variables.get("shared", []) return datasource_provider_variables + shared_variables + + def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination: + """ + Get debug workflow run list + Only return triggered_from == debugging + + :param app_model: app model + :param args: request args + """ + limit = int(args.get("limit", 20)) + + base_query = db.session.query(WorkflowRun).filter( + WorkflowRun.tenant_id == pipeline.tenant_id, + WorkflowRun.app_id == pipeline.id, + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, + ) + + if args.get("last_id"): + last_workflow_run = base_query.filter( + WorkflowRun.id == args.get("last_id"), + ).first() + + if not last_workflow_run: + raise ValueError("Last workflow run not exists") + + workflow_runs = ( + base_query.filter( + WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id + ) + .order_by(WorkflowRun.created_at.desc()) + .limit(limit) + .all() + ) + else: + workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() + + has_more = False + if len(workflow_runs) == limit: + current_page_first_workflow_run = workflow_runs[-1] + rest_count = base_query.filter( + WorkflowRun.created_at < current_page_first_workflow_run.created_at, + WorkflowRun.id != current_page_first_workflow_run.id, + ).count() + + if rest_count > 0: + has_more = True + + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) + + def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]: + """ + Get workflow run detail + + :param app_model: app model + :param run_id: workflow run id + """ + workflow_run = ( + db.session.query(WorkflowRun) + .filter( + WorkflowRun.tenant_id == pipeline.tenant_id, + WorkflowRun.app_id == pipeline.id, + WorkflowRun.id == run_id, + ) + .first() + ) + + return workflow_run + + def get_rag_pipeline_workflow_run_node_executions( + self, + pipeline: Pipeline, + run_id: str, + ) -> list[WorkflowNodeExecution]: + """ + Get workflow run node execution list + """ + workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id) + + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + if not workflow_run: + return [] + + # Use the repository to get the node executions + repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": pipeline.tenant_id, + "app_id": pipeline.id, + "session_factory": db.session.get_bind(), + } + ) + + # Use the repository 
to get the node executions with ordering
+        order_config = OrderConfig(order_by=["index"], order_direction="desc")
+        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
+
+        return list(node_executions)
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
new file mode 100644
index 0000000000..80e7c6af0b
--- /dev/null
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -0,0 +1,841 @@
+import base64
+import hashlib
+import logging
+import uuid
+from collections.abc import Mapping
+from enum import StrEnum
+from typing import Optional
+from urllib.parse import urlparse
+from uuid import uuid4
+
+import yaml  # type: ignore
+from Crypto.Cipher import AES
+from Crypto.Util.Padding import pad, unpad
+from packaging import version
+from pydantic import BaseModel, Field
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from core.helper import ssrf_proxy
+from core.model_runtime.utils.encoders import jsonable_encoder
+from core.plugin.entities.plugin import PluginDependency
+from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
+from core.workflow.nodes.llm.entities import LLMNodeData
+from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
+from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
+from core.workflow.nodes.tool.entities import ToolNodeData
+from extensions.ext_redis import redis_client
+from factories import variable_factory
+from models import Account
+from models.dataset import Dataset, Pipeline
+from models.workflow import Workflow
+from services.dataset_service import DatasetCollectionBindingService
+from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration
+from services.plugin.dependencies_analysis import DependenciesAnalysisService
+from services.rag_pipeline.rag_pipeline import RagPipelineService
+
+logger = logging.getLogger(__name__)
+
+IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
+CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
+IMPORT_INFO_REDIS_EXPIRY = 10 * 60  # 10 minutes
+DSL_MAX_SIZE = 10 * 1024 * 1024  # 10MB
+CURRENT_DSL_VERSION = "0.1.0"
+
+
+class ImportMode(StrEnum):
+    YAML_CONTENT = "yaml-content"
+    YAML_URL = "yaml-url"
+
+
+class ImportStatus(StrEnum):
+    COMPLETED = "completed"
+    COMPLETED_WITH_WARNINGS = "completed-with-warnings"
+    PENDING = "pending"
+    FAILED = "failed"
+
+
+class RagPipelineImportInfo(BaseModel):
+    id: str
+    status: ImportStatus
+    pipeline_id: Optional[str] = None
+    current_dsl_version: str = CURRENT_DSL_VERSION
+    imported_dsl_version: str = ""
+    error: str = ""
+    dataset_id: Optional[str] = None
+
+
+class CheckDependenciesResult(BaseModel):
+    leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
+
+
+def _check_version_compatibility(imported_version: str) -> ImportStatus:
+    """Determine import status based on version comparison"""
+    try:
+        current_ver = version.parse(CURRENT_DSL_VERSION)
+        imported_ver = version.parse(imported_version)
+    except version.InvalidVersion:
+        return ImportStatus.FAILED
+
+    # If imported version is newer than current, always return PENDING
+    if imported_ver > current_ver:
+        return ImportStatus.PENDING
+
+    # If imported version is older than current's major, return PENDING
+    if imported_ver.major < current_ver.major:
+        return ImportStatus.PENDING
+
+ # If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS + if imported_ver.minor < current_ver.minor: + return ImportStatus.COMPLETED_WITH_WARNINGS + + # If imported version equals or is older than current's micro, return COMPLETED + return ImportStatus.COMPLETED + + +class RagPipelinePendingData(BaseModel): + import_mode: str + yaml_content: str + name: str | None + description: str | None + icon_type: str | None + icon: str | None + icon_background: str | None + pipeline_id: str | None + + +class CheckDependenciesPendingData(BaseModel): + dependencies: list[PluginDependency] + pipeline_id: str | None + + +class RagPipelineDslService: + def __init__(self, session: Session): + self._session = session + + def import_rag_pipeline( + self, + *, + account: Account, + import_mode: str, + yaml_content: Optional[str] = None, + yaml_url: Optional[str] = None, + pipeline_id: Optional[str] = None, + dataset: Optional[Dataset] = None, + ) -> RagPipelineImportInfo: + """Import an app from YAML content or URL.""" + import_id = str(uuid.uuid4()) + + # Validate import mode + try: + mode = ImportMode(import_mode) + except ValueError: + raise ValueError(f"Invalid import_mode: {import_mode}") + + # Get YAML content + content: str = "" + if mode == ImportMode.YAML_URL: + if not yaml_url: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_url is required when import_mode is yaml-url", + ) + try: + parsed_url = urlparse(yaml_url) + if ( + parsed_url.scheme == "https" + and parsed_url.netloc == "github.com" + and parsed_url.path.endswith((".yml", ".yaml")) + ): + yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com") + yaml_url = yaml_url.replace("/blob/", "/") + response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) + response.raise_for_status() + content = response.content.decode() + + if len(content) > DSL_MAX_SIZE: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="File size exceeds the limit of 10MB", + ) + + if not content: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Empty content from url", + ) + except Exception as e: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Error fetching YAML from URL: {str(e)}", + ) + elif mode == ImportMode.YAML_CONTENT: + if not yaml_content: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_content is required when import_mode is yaml-content", + ) + content = yaml_content + + # Process YAML content + try: + # Parse YAML to validate format + data = yaml.safe_load(content) + if not isinstance(data, dict): + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid YAML format: content must be a mapping", + ) + + # Validate and fix DSL version + if not data.get("version"): + data["version"] = "0.1.0" + if not data.get("kind") or data.get("kind") != "rag-pipeline": + data["kind"] = "rag-pipeline" + + imported_version = data.get("version", "0.1.0") + # check if imported_version is a float-like string + if not isinstance(imported_version, str): + raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}") + status = _check_version_compatibility(imported_version) + + # Extract app data + pipeline_data = data.get("pipeline") + if not pipeline_data: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + 
error="Missing pipeline data in YAML content", + ) + + # If app_id is provided, check if it exists + pipeline = None + if pipeline_id: + stmt = select(Pipeline).where( + Pipeline.id == pipeline_id, + Pipeline.tenant_id == account.current_tenant_id, + ) + pipeline = self._session.scalar(stmt) + + if not pipeline: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Pipeline not found", + ) + + # If major version mismatch, store import info in Redis + if status == ImportStatus.PENDING: + pending_data = RagPipelinePendingData( + import_mode=import_mode, + yaml_content=content, + pipeline_id=pipeline_id, + ) + redis_client.setex( + f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}", + IMPORT_INFO_REDIS_EXPIRY, + pending_data.model_dump_json(), + ) + + return RagPipelineImportInfo( + id=import_id, + status=status, + pipeline_id=pipeline_id, + imported_dsl_version=imported_version, + ) + + # Extract dependencies + dependencies = data.get("dependencies", []) + check_dependencies_pending_data = None + if dependencies: + check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies] + + # Create or update app + pipeline = self._create_or_update_pipeline( + pipeline=pipeline, + data=data, + account=account, + dependencies=check_dependencies_pending_data, + ) + # create dataset + name = pipeline.name + description = pipeline.description + icon_type = data.get("rag_pipeline", {}).get("icon_type") + icon = data.get("rag_pipeline", {}).get("icon") + icon_background = data.get("rag_pipeline", {}).get("icon_background") + icon_url = data.get("rag_pipeline", {}).get("icon_url") + workflow = data.get("workflow", {}) + graph = workflow.get("graph", {}) + nodes = graph.get("nodes", []) + dataset_id = None + for node in nodes: + if node.get("data", {}).get("type") == "knowledge_index": + knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) + knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + if not dataset: + dataset = Dataset( + tenant_id=account.current_tenant_id, + name=name, + description=description, + icon_info={ + "type": icon_type, + "icon": icon, + "background": icon_background, + "url": icon_url, + }, + indexing_technique=knowledge_configuration.index_method.indexing_technique, + created_by=account.id, + retrieval_model=knowledge_configuration.retrieval_setting.model_dump(), + runtime_mode="rag_pipeline", + chunk_structure=knowledge_configuration.chunk_structure, + ) + else: + dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique + dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() + dataset.runtime_mode = "rag_pipeline" + dataset.chunk_structure = knowledge_configuration.chunk_structure + if knowledge_configuration.index_method.indexing_technique == "high_quality": + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore + knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore + ) + dataset_collection_binding_id = dataset_collection_binding.id + dataset.collection_binding_id = dataset_collection_binding_id + dataset.embedding_model = ( + knowledge_configuration.index_method.embedding_setting.embedding_model_name + ) + dataset.embedding_model_provider = ( + knowledge_configuration.index_method.embedding_setting.embedding_provider_name + ) + elif 
knowledge_configuration.index_method.indexing_technique == "economy": + dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number + dataset.pipeline_id = pipeline.id + self._session.add(dataset) + self._session.commit() + dataset_id = dataset.id + if not dataset_id: + raise ValueError("DSL is not valid, please check the Knowledge Index node.") + + return RagPipelineImportInfo( + id=import_id, + status=status, + pipeline_id=pipeline.id, + dataset_id=dataset_id, + imported_dsl_version=imported_version, + ) + + except yaml.YAMLError as e: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Invalid YAML format: {str(e)}", + ) + + except Exception as e: + logger.exception("Failed to import app") + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def confirm_import(self, *, import_id: str, account: Account) -> RagPipelineImportInfo: + """ + Confirm an import that requires confirmation + """ + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + pending_data = redis_client.get(redis_key) + + if not pending_data: + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Import information expired or does not exist", + ) + + try: + if not isinstance(pending_data, str | bytes): + return RagPipelineImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid import information", + ) + pending_data = RagPipelinePendingData.model_validate_json(pending_data) + data = yaml.safe_load(pending_data.yaml_content) + + pipeline = None + if pending_data.pipeline_id: + stmt = select(Pipeline).where( + Pipeline.id == pending_data.pipeline_id, + Pipeline.tenant_id == account.current_tenant_id, + ) + pipeline = self._session.scalar(stmt) + + # Create or update app + pipeline = self._create_or_update_pipeline( + pipeline=pipeline, + data=data, + account=account, + ) + + # create dataset + name = pipeline.name + description = pipeline.description + icon_type = data.get("rag_pipeline", {}).get("icon_type") + icon = data.get("rag_pipeline", {}).get("icon") + icon_background = data.get("rag_pipeline", {}).get("icon_background") + icon_url = data.get("rag_pipeline", {}).get("icon_url") + workflow = data.get("workflow", {}) + graph = workflow.get("graph", {}) + nodes = graph.get("nodes", []) + dataset_id = None + for node in nodes: + if node.get("data", {}).get("type") == "knowledge_index": + knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) + knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + if not dataset: + dataset = Dataset( + tenant_id=account.current_tenant_id, + name=name, + description=description, + icon_info={ + "type": icon_type, + "icon": icon, + "background": icon_background, + "url": icon_url, + }, + indexing_technique=knowledge_configuration.index_method.indexing_technique, + created_by=account.id, + retrieval_model=knowledge_configuration.retrieval_setting.model_dump(), + runtime_mode="rag_pipeline", + chunk_structure=knowledge_configuration.chunk_structure, + ) + else: + dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique + dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() + dataset.runtime_mode = "rag_pipeline" + dataset.chunk_structure = knowledge_configuration.chunk_structure + if knowledge_configuration.index_method.indexing_technique == "high_quality": + dataset_collection_binding = 
DatasetCollectionBindingService.get_dataset_collection_binding(
+                            knowledge_configuration.index_method.embedding_setting.embedding_provider_name,  # type: ignore
+                            knowledge_configuration.index_method.embedding_setting.embedding_model_name,  # type: ignore
+                        )
+                        dataset_collection_binding_id = dataset_collection_binding.id
+                        dataset.collection_binding_id = dataset_collection_binding_id
+                        dataset.embedding_model = (
+                            knowledge_configuration.index_method.embedding_setting.embedding_model_name
+                        )
+                        dataset.embedding_model_provider = (
+                            knowledge_configuration.index_method.embedding_setting.embedding_provider_name
+                        )
+                    elif knowledge_configuration.index_method.indexing_technique == "economy":
+                        dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
+                    dataset.pipeline_id = pipeline.id
+                    self._session.add(dataset)
+                    self._session.commit()
+                    dataset_id = dataset.id
+            if not dataset_id:
+                raise ValueError("DSL is not valid, please check the Knowledge Index node.")
+
+            # Delete import info from Redis
+            redis_client.delete(redis_key)
+
+            return RagPipelineImportInfo(
+                id=import_id,
+                status=ImportStatus.COMPLETED,
+                pipeline_id=pipeline.id,
+                dataset_id=dataset_id,
+                current_dsl_version=CURRENT_DSL_VERSION,
+                imported_dsl_version=data.get("version", "0.1.0"),
+            )
+
+        except Exception as e:
+            logger.exception("Error confirming import")
+            return RagPipelineImportInfo(
+                id=import_id,
+                status=ImportStatus.FAILED,
+                error=str(e),
+            )
+
+    def check_dependencies(
+        self,
+        *,
+        pipeline: Pipeline,
+    ) -> CheckDependenciesResult:
+        """Check dependencies"""
+        # Get dependencies from Redis
+        redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}"
+        dependencies = redis_client.get(redis_key)
+        if not dependencies:
+            return CheckDependenciesResult()
+
+        # Extract dependencies
+        dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)
+
+        # Get leaked dependencies
+        leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
+            tenant_id=pipeline.tenant_id, dependencies=dependencies.dependencies
+        )
+        return CheckDependenciesResult(
+            leaked_dependencies=leaked_dependencies,
+        )
+
+    def _create_or_update_pipeline(
+        self,
+        *,
+        pipeline: Optional[Pipeline],
+        data: dict,
+        account: Account,
+        dependencies: Optional[list[PluginDependency]] = None,
+    ) -> Pipeline:
+        """Create a new pipeline or update an existing one."""
+        pipeline_data = data.get("pipeline", {})
+        pipeline_mode = pipeline_data.get("mode")
+        if not pipeline_mode:
+            raise ValueError("Missing pipeline mode")
+        # Set icon type
+        icon_type_value = pipeline_data.get("icon_type")
+        if icon_type_value in ["emoji", "link"]:
+            icon_type = icon_type_value
+        else:
+            icon_type = "emoji"
+        icon = str(pipeline_data.get("icon", ""))
+
+        if pipeline:
+            # Update existing pipeline
+            pipeline.name = pipeline_data.get("name", pipeline.name)
+            pipeline.description = pipeline_data.get("description", pipeline.description)
+            pipeline.icon_type = icon_type
+            pipeline.icon = icon
+            pipeline.icon_background = pipeline_data.get("icon_background", pipeline.icon_background)
+            pipeline.updated_by = account.id
+        else:
+            if account.current_tenant_id is None:
+                raise ValueError("Current tenant is not set")
+
+            # Create a new pipeline
+            pipeline = Pipeline()
+            pipeline.id = str(uuid4())
+            pipeline.tenant_id = account.current_tenant_id
+            pipeline.mode = pipeline_mode.value
+            pipeline.name = pipeline_data.get("name", "")
+            pipeline.description = pipeline_data.get("description", "")
+
pipeline.icon_type = icon_type + pipeline.icon = icon + pipeline.icon_background = pipeline_data.get("icon_background", "#FFFFFF") + pipeline.enable_site = True + pipeline.enable_api = True + pipeline.use_icon_as_answer_icon = pipeline_data.get("use_icon_as_answer_icon", False) + pipeline.created_by = account.id + pipeline.updated_by = account.id + + self._session.add(pipeline) + self._session.commit() + # save dependencies + if dependencies: + redis_client.setex( + f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}", + IMPORT_INFO_REDIS_EXPIRY, + CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(), + ) + + # Initialize pipeline based on mode + workflow_data = data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise ValueError("Missing workflow data for rag pipeline") + + environment_variables_list = workflow_data.get("environment_variables", []) + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables", []) + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) + rag_pipeline_variables = [ + variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list + ] + + rag_pipeline_service = RagPipelineService() + current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + if current_draft_workflow: + unique_hash = current_draft_workflow.unique_hash + else: + unique_hash = None + graph = workflow_data.get("graph", {}) + for node in graph.get("nodes", []): + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + decrypted_id + for dataset_id in dataset_ids + if ( + decrypted_id := self.decrypt_dataset_id( + encrypted_data=dataset_id, + tenant_id=pipeline.tenant_id, + ) + ) + ] + rag_pipeline_service.sync_draft_workflow( + pipeline=pipeline, + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), + unique_hash=unique_hash, + account=account, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + + return pipeline + + @classmethod + def export_rag_pipeline_dsl(cls, pipeline: Pipeline, include_secret: bool = False) -> str: + """ + Export pipeline + :param pipeline: Pipeline instance + :param include_secret: Whether include secret variable + :return: + """ + export_data = { + "version": CURRENT_DSL_VERSION, + "kind": "rag_pipeline", + "pipeline": { + "name": pipeline.name, + "mode": pipeline.mode, + "icon": "🤖" if pipeline.icon_type == "image" else pipeline.icon, + "icon_background": "#FFEAD5" if pipeline.icon_type == "image" else pipeline.icon_background, + "description": pipeline.description, + "use_icon_as_answer_icon": pipeline.use_icon_as_answer_icon, + }, + } + + cls._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret) + + return yaml.dump(export_data, allow_unicode=True) # type: ignore + + @classmethod + def _append_workflow_export_data(cls, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None: + """ + Append workflow export data + :param export_data: export data + :param pipeline: Pipeline 
instance + """ + rag_pipeline_service = RagPipelineService() + workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Missing draft workflow configuration, please check.") + + workflow_dict = workflow.to_dict(include_secret=include_secret) + for node in workflow_dict.get("graph", {}).get("nodes", []): + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id) + for dataset_id in dataset_ids + ] + export_data["workflow"] = workflow_dict + dependencies = cls._extract_dependencies_from_workflow(workflow) + export_data["dependencies"] = [ + jsonable_encoder(d.model_dump()) + for d in DependenciesAnalysisService.generate_dependencies( + tenant_id=pipeline.tenant_id, dependencies=dependencies + ) + ] + + @classmethod + def _append_model_config_export_data(cls, export_data: dict, pipeline: Pipeline) -> None: + """ + Append model config export data + :param export_data: export data + :param pipeline: Pipeline instance + """ + app_model_config = pipeline.app_model_config + if not app_model_config: + raise ValueError("Missing app configuration, please check.") + + export_data["model_config"] = app_model_config.to_dict() + dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict()) + export_data["dependencies"] = [ + jsonable_encoder(d.model_dump()) + for d in DependenciesAnalysisService.generate_dependencies( + tenant_id=pipeline.tenant_id, dependencies=dependencies + ) + ] + + @classmethod + def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]: + """ + Extract dependencies from workflow + :param workflow: Workflow instance + :return: dependencies list format like ["langgenius/google"] + """ + graph = workflow.graph_dict + dependencies = cls._extract_dependencies_from_workflow_graph(graph) + return dependencies + + @classmethod + def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]: + """ + Extract dependencies from workflow graph + :param graph: Workflow graph + :return: dependencies list format like ["langgenius/google"] + """ + dependencies = [] + for node in graph.get("nodes", []): + try: + typ = node.get("data", {}).get("type") + match typ: + case NodeType.TOOL.value: + tool_entity = ToolNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), + ) + case NodeType.LLM.value: + llm_entity = LLMNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), + ) + case NodeType.QUESTION_CLASSIFIER.value: + question_classifier_entity = QuestionClassifierNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + question_classifier_entity.model.provider + ), + ) + case NodeType.PARAMETER_EXTRACTOR.value: + parameter_extractor_entity = ParameterExtractorNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + parameter_extractor_entity.model.provider + ), + ) + case NodeType.KNOWLEDGE_RETRIEVAL.value: + knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"]) + if knowledge_retrieval_entity.retrieval_mode == "multiple": + if knowledge_retrieval_entity.multiple_retrieval_config: + if ( + 
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode + == "reranking_model" + ): + if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model: + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider + ), + ) + elif ( + knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode + == "weighted_score" + ): + if knowledge_retrieval_entity.multiple_retrieval_config.weights: + vector_setting = ( + knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting + ) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + vector_setting.embedding_provider_name + ), + ) + elif knowledge_retrieval_entity.retrieval_mode == "single": + model_config = knowledge_retrieval_entity.single_retrieval_config + if model_config: + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + model_config.model.provider + ), + ) + case _: + # TODO: Handle default case or unknown node types + pass + except Exception as e: + logger.exception("Error extracting node dependency", exc_info=e) + + return dependencies + + @classmethod + def _extract_dependencies_from_model_config(cls, model_config: Mapping) -> list[str]: + """ + Extract dependencies from model config + :param model_config: model config dict + :return: dependencies list format like ["langgenius/google"] + """ + dependencies = [] + + try: + # completion model + model_dict = model_config.get("model", {}) + if model_dict: + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", "")) + ) + + # reranking model + dataset_configs = model_config.get("dataset_configs", {}) + if dataset_configs: + for dataset_config in dataset_configs.get("datasets", {}).get("datasets", []): + if dataset_config.get("reranking_model"): + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + dataset_config.get("reranking_model", {}) + .get("reranking_provider_name", {}) + .get("provider") + ) + ) + + # tools + agent_configs = model_config.get("agent_mode", {}) + if agent_configs: + for agent_config in agent_configs.get("tools", []): + dependencies.append( + DependenciesAnalysisService.analyze_tool_dependency(agent_config.get("provider_id")) + ) + + except Exception as e: + logger.exception("Error extracting model config dependency", exc_info=e) + + return dependencies + + @classmethod + def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]: + """ + Returns the leaked dependencies in current workspace + """ + dependencies = [PluginDependency(**dep) for dep in dsl_dependencies] + if not dependencies: + return [] + + return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies) + + @staticmethod + def _generate_aes_key(tenant_id: str) -> bytes: + """Generate AES key based on tenant_id""" + return hashlib.sha256(tenant_id.encode()).digest() + + @classmethod + def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str: + """Encrypt dataset_id using AES-CBC mode""" + key = cls._generate_aes_key(tenant_id) + iv = key[:16] + cipher = AES.new(key, AES.MODE_CBC, iv) + ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size)) + return base64.b64encode(ct_bytes).decode() + + @classmethod + def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None: + """AES 
decryption""" + try: + key = cls._generate_aes_key(tenant_id) + iv = key[:16] + cipher = AES.new(key, AES.MODE_CBC, iv) + pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size) + return pt.decode() + except Exception: + return None diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 3ccd14415d..daf3773309 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -5,6 +5,7 @@ from pathlib import Path from sqlalchemy.orm import Session from configs import dify_config +from core.datasource.entities.api_entities import DatasourceProviderApiEntity from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import GenericProviderID, ToolProviderID @@ -16,7 +17,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ProviderConfigEncrypter from extensions.ext_database import db -from models.tools import BuiltinToolProvider +from models.tools import BuiltinDatasourceProvider, BuiltinToolProvider from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) @@ -286,6 +287,67 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) + @staticmethod + def list_rag_pipeline_datasources(tenant_id: str) -> list[DatasourceProviderApiEntity]: + """ + list rag pipeline datasources + """ + # get all builtin providers + datasource_provider_controllers = ToolManager.list_datasource_providers(tenant_id) + + with db.session.no_autoflush: + # get all user added providers + db_providers: list[BuiltinDatasourceProvider] = ( + db.session.query(BuiltinDatasourceProvider) + .filter(BuiltinDatasourceProvider.tenant_id == tenant_id) + .all() + or [] + ) + + # find provider + def find_provider(provider): + return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) + + result: list[DatasourceProviderApiEntity] = [] + + for provider_controller in datasource_provider_controllers: + try: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore + data=provider_controller, + name_func=lambda x: x.identity.name, + ): + continue + + # convert provider controller to user provider + user_builtin_provider = ToolTransformService.builtin_datasource_provider_to_user_provider( + provider_controller=provider_controller, + db_provider=find_provider(provider_controller.entity.identity.name), + decrypt_credentials=True, + ) + + # add icon + ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) + + datasources = provider_controller.get_datasources() + for datasource in datasources or []: + user_builtin_provider.datasources.append( + ToolTransformService.convert_datasource_entity_to_api_entity( + tenant_id=tenant_id, + datasource=datasource, + credentials=user_builtin_provider.original_credentials, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) + + result.append(user_builtin_provider) + except Exception as e: + raise e + + return result + @staticmethod def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: try: diff --git a/api/services/tools/tools_transform_service.py 
b/api/services/tools/tools_transform_service.py index 367121125b..e0c1ce7217 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,6 +5,11 @@ from typing import Optional, Union, cast from yarl import URL from configs import dify_config +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.api_entities import DatasourceApiEntity, DatasourceProviderApiEntity +from core.datasource.entities.datasource_entities import DatasourceProviderType from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -21,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool -from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider +from models.tools import ApiToolProvider, BuiltinDatasourceProvider, BuiltinToolProvider, WorkflowToolProvider logger = logging.getLogger(__name__) @@ -140,6 +145,64 @@ class ToolTransformService: return result + @classmethod + def builtin_datasource_provider_to_user_provider( + cls, + provider_controller: DatasourcePluginProviderController, + db_provider: Optional[BuiltinDatasourceProvider], + decrypt_credentials: bool = True, + ) -> DatasourceProviderApiEntity: + """ + convert provider controller to user provider + """ + result = DatasourceProviderApiEntity( + id=provider_controller.entity.identity.name, + author=provider_controller.entity.identity.author, + name=provider_controller.entity.identity.name, + description=provider_controller.entity.identity.description, + icon=provider_controller.entity.identity.icon, + label=provider_controller.entity.identity.label, + type=DatasourceProviderType.RAG_PIPELINE, + masked_credentials={}, + is_team_authorization=False, + plugin_id=provider_controller.plugin_id, + plugin_unique_identifier=provider_controller.plugin_unique_identifier, + datasources=[], + ) + + # get credentials schema + schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} + + for name, value in schema.items(): + if result.masked_credentials: + result.masked_credentials[name] = "" + + # check if the provider need credentials + if not provider_controller.need_credentials: + result.is_team_authorization = True + result.allow_delete = False + elif db_provider: + result.is_team_authorization = True + + if decrypt_credentials: + credentials = db_provider.credentials + + # init tool configuration + tool_configuration = ProviderConfigEncrypter( + tenant_id=db_provider.tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) + # decrypt the credentials and mask the credentials + decrypted_credentials = tool_configuration.decrypt(data=credentials) + masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) + + result.masked_credentials = masked_credentials + result.original_credentials = 
decrypted_credentials + + return result + @staticmethod def api_provider_to_controller( db_provider: ApiToolProvider, @@ -304,3 +367,48 @@ class ToolTransformService: parameters=tool.parameters, labels=labels or [], ) + + @staticmethod + def convert_datasource_entity_to_api_entity( + datasource: DatasourcePlugin, + tenant_id: str, + credentials: dict | None = None, + labels: list[str] | None = None, + ) -> DatasourceApiEntity: + """ + convert tool to user tool + """ + # fork tool runtime + datasource = datasource.fork_datasource_runtime( + runtime=DatasourceRuntime( + credentials=credentials or {}, + tenant_id=tenant_id, + ) + ) + + # get datasource parameters + parameters = datasource.entity.parameters or [] + # get datasource runtime parameters + runtime_parameters = datasource.get_runtime_parameters() + # override parameters + current_parameters = parameters.copy() + for runtime_parameter in runtime_parameters: + found = False + for index, parameter in enumerate(current_parameters): + if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: + current_parameters[index] = runtime_parameter + found = True + break + + if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + current_parameters.append(runtime_parameter) + + return DatasourceApiEntity( + author=datasource.entity.identity.author, + name=datasource.entity.identity.name, + label=datasource.entity.identity.label, + description=datasource.entity.description.human if datasource.entity.description else I18nObject(en_US=""), + output_schema=datasource.entity.output_schema, + parameters=current_parameters, + labels=labels or [], + ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 63e3791147..c0f4578474 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -203,7 +203,6 @@ class WorkflowService: type=draft_workflow.type, version=str(datetime.now(UTC).replace(tzinfo=None)), graph=draft_workflow.graph, - features=draft_workflow.features, created_by=account.id, environment_variables=draft_workflow.environment_variables, conversation_variables=draft_workflow.conversation_variables, From e710a8402c02dd83b72c7d990334f79fadde0745 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 15 May 2025 16:07:17 +0800 Subject: [PATCH 015/155] r2 --- .../datasource/__base/datasource_plugin.py | 6 +- api/core/datasource/datasource_manager.py | 12 +- .../entities/datasource_entities.py | 3 - api/core/plugin/entities/plugin.py | 3 + api/core/plugin/impl/tool.py | 32 ++++- ..._15_1558-b35c3db83d09_add_pipeline_info.py | 113 ++++++++++++++++++ api/models/dataset.py | 6 +- api/models/workflow.py | 12 +- 8 files changed, 165 insertions(+), 22 deletions(-) create mode 100644 api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index 86bd66a3f9..8fb89e1172 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -2,13 +2,13 @@ from collections.abc import Generator from typing import Any, Optional from core.datasource.__base.datasource_runtime import DatasourceRuntime -from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import ( DatasourceEntity, DatasourceInvokeMessage, DatasourceParameter, DatasourceProviderType, ) +from core.plugin.impl.datasource import 
PluginDatasourceManager from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -44,7 +44,7 @@ class DatasourcePlugin: datasource_parameters: dict[str, Any], rag_pipeline_id: Optional[str] = None, ) -> Generator[DatasourceInvokeMessage, None, None]: - manager = DatasourceManager() + manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) @@ -64,7 +64,7 @@ class DatasourcePlugin: datasource_parameters: dict[str, Any], rag_pipeline_id: Optional[str] = None, ) -> Generator[DatasourceInvokeMessage, None, None]: - manager = DatasourceManager() + manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 195d430015..fa141a679a 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -7,8 +7,8 @@ from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.entities.common_entities import I18nObject from core.datasource.entities.datasource_entities import DatasourceProviderType -from core.datasource.errors import ToolProviderNotFoundError -from core.plugin.manager.tool import PluginToolManager +from core.datasource.errors import DatasourceProviderNotFoundError +from core.plugin.impl.tool import PluginToolManager logger = logging.getLogger(__name__) @@ -37,9 +37,9 @@ class DatasourceManager: return datasource_plugin_providers[provider] manager = PluginToolManager() - provider_entity = manager.fetch_tool_provider(tenant_id, provider) + provider_entity = manager.fetch_datasource_provider(tenant_id, provider) if not provider_entity: - raise ToolProviderNotFoundError(f"plugin provider {provider} not found") + raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found") controller = DatasourcePluginProviderController( entity=provider_entity.declaration, @@ -73,7 +73,7 @@ class DatasourceManager: if provider_type == DatasourceProviderType.RAG_PIPELINE: return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name) else: - raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") + raise DatasourceProviderNotFoundError(f"provider type {provider_type.value} not found") @classmethod def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]: @@ -81,7 +81,7 @@ class DatasourceManager: list all the datasource providers """ manager = PluginToolManager() - provider_entities = manager.fetch_tool_providers(tenant_id) + provider_entities = manager.fetch_datasources(tenant_id) return [ DatasourcePluginProviderController( entity=provider.declaration, diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 80e89ef1a9..6fc23e88cc 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -321,9 +321,6 @@ class DatasourceEntity(BaseModel): output_schema: Optional[dict] = None has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - @field_validator("parameters", mode="before") @classmethod def set_parameters(cls, v, validation_info: 
ValidationInfo) -> list[DatasourceParameter]: diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index bdf7d5ce1f..85d4d130ba 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -192,6 +192,9 @@ class ToolProviderID(GenericProviderID): if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]: self.plugin_name = f"{self.provider_name}_tool" +class DatasourceProviderID(GenericProviderID): + def __init__(self, value: str, is_hardcoded: bool = False) -> None: + super().__init__(value, is_hardcoded) class PluginDependency(BaseModel): class Type(enum.StrEnum): diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index f4360a70de..54f5418bb4 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -3,7 +3,7 @@ from typing import Any, Optional from pydantic import BaseModel -from core.plugin.entities.plugin import GenericProviderID, ToolProviderID +from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDatasourceProviderEntity, @@ -76,6 +76,36 @@ class PluginToolManager(BasePluginClient): return response + def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: + """ + Fetch datasource provider for the given tenant and plugin. + """ + datasource_provider_id = DatasourceProviderID(provider) + + def transformer(json_response: dict[str, Any]) -> dict: + data = json_response.get("data") + if data: + for tool in data.get("declaration", {}).get("tools", []): + tool["identity"]["provider"] = datasource_provider_id.provider_name + + return json_response + + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/datasource", + PluginDatasourceProviderEntity, + params={"provider": datasource_provider_id.provider_name, "plugin_id": datasource_provider_id.plugin_id}, + transformer=transformer, + ) + + response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" + + # override the provider name for each tool to plugin_id/provider_name + for tool in response.declaration.tools: + tool.identity.provider = response.declaration.identity.name + + return response + def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: """ Fetch tool provider for the given tenant and plugin. diff --git a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py new file mode 100644 index 0000000000..89fcc6aa29 --- /dev/null +++ b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py @@ -0,0 +1,113 @@ +"""add_pipeline_info + +Revision ID: b35c3db83d09 +Revises: d28f2004b072 +Create Date: 2025-05-15 15:58:05.179877 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'b35c3db83d09' +down_revision = 'd28f2004b072' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('pipeline_built_in_templates', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('pipeline_id', models.types.StringUUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('icon', sa.JSON(), nullable=False), + sa.Column('copyright', sa.String(length=255), nullable=False), + sa.Column('privacy_policy', sa.String(length=255), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('install_count', sa.Integer(), nullable=False), + sa.Column('language', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey') + ) + op.create_table('pipeline_customized_templates', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('pipeline_id', models.types.StringUUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('icon', sa.JSON(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('install_count', sa.Integer(), nullable=False), + sa.Column('language', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey') + ) + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False) + + op.create_table('pipelines', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), server_default=sa.text("''::character varying"), nullable=False), + sa.Column('mode', sa.String(length=255), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), nullable=True), + sa.Column('is_public', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('is_published', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_by', models.types.StringUUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='pipeline_pkey') + ) + op.create_table('tool_builtin_datasource_providers', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=True), + sa.Column('user_id', models.types.StringUUID(), nullable=False), + sa.Column('provider', sa.String(length=256), nullable=False), + 
sa.Column('encrypted_credentials', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_builtin_datasource_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_datasource_provider') + ) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True)) + batch_op.add_column(sa.Column('icon_info', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + batch_op.add_column(sa.Column('runtime_mode', sa.String(length=255), server_default=sa.text("'general'::character varying"), nullable=True)) + batch_op.add_column(sa.Column('pipeline_id', models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=True)) + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_column('rag_pipeline_variables') + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('chunk_structure') + batch_op.drop_column('pipeline_id') + batch_op.drop_column('runtime_mode') + batch_op.drop_column('icon_info') + batch_op.drop_column('keyword_number') + + op.drop_table('tool_builtin_datasource_providers') + op.drop_table('pipelines') + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.drop_index('pipeline_customized_template_tenant_idx') + + op.drop_table('pipeline_customized_templates') + op.drop_table('pipeline_built_in_templates') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 292f4aacfd..e60f110aef 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1149,7 +1149,7 @@ class DatasetMetadataBinding(Base): created_by = db.Column(StringUUID, nullable=False) -class PipelineBuiltInTemplate(db.Model): # type: ignore[name-defined] +class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] __tablename__ = "pipeline_built_in_templates" __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),) @@ -1167,7 +1167,7 @@ class PipelineBuiltInTemplate(db.Model): # type: ignore[name-defined] updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class PipelineCustomizedTemplate(db.Model): # type: ignore[name-defined] +class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] __tablename__ = "pipeline_customized_templates" __table_args__ = ( db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"), @@ -1187,7 +1187,7 @@ class PipelineCustomizedTemplate(db.Model): # type: ignore[name-defined] updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class Pipeline(db.Model): # type: ignore[name-defined] +class Pipeline(Base): # type: ignore[name-defined] __tablename__ = "pipelines" __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),) diff --git a/api/models/workflow.py b/api/models/workflow.py 
b/api/models/workflow.py
index 2fda5431c3..b6b56ad520 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -128,8 +128,8 @@ class Workflow(Base):
     _conversation_variables: Mapped[str] = mapped_column(
         "conversation_variables", db.Text, nullable=False, server_default="{}"
     )
-    _pipeline_variables: Mapped[str] = mapped_column(
-        "conversation_variables", db.Text, nullable=False, server_default="{}"
+    _rag_pipeline_variables: Mapped[str] = mapped_column(
+        "rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
     )
 
     @classmethod
@@ -354,10 +354,10 @@ class Workflow(Base):
     @property
     def pipeline_variables(self) -> dict[str, Sequence[Variable]]:
         # TODO: find some way to init `self._conversation_variables` when instance created.
-        if self._pipeline_variables is None:
-            self._pipeline_variables = "{}"
+        if self._rag_pipeline_variables is None:
+            self._rag_pipeline_variables = "{}"
 
-        variables_dict: dict[str, Any] = json.loads(self._pipeline_variables)
+        variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
         results = {}
         for k, v in variables_dict.items():
             results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
@@ -365,7 +365,7 @@ class Workflow(Base):
     @pipeline_variables.setter
     def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
-        self._pipeline_variables = json.dumps(
+        self._rag_pipeline_variables = json.dumps(
             {k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
             ensure_ascii=False,
         )

From e710a8402c02dd83b72c7d990334f79fadde0745 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Thu, 15 May 2025 16:44:55 +0800
Subject: [PATCH 016/155] r2

---
 .../rag_pipeline/rag_pipeline_workflow.py     |  1 +
 .../entities/datasource_entities.py           |  2 +-
 api/core/plugin/entities/plugin.py            |  2 +
 .../knowledge_index/knowledge_index_node.py   | 62 ++++++++++++++++---
 api/services/dataset_service.py               | 58 +----------------
 5 files changed, 58 insertions(+), 67 deletions(-)

diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 99d3b73d33..b348f7a796 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -469,6 +469,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
         rag_pipeline_service = RagPipelineService()
         return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
 
+
 class RagPipelineConfigApi(Resource):
     """Resource for rag pipeline configuration."""
 
diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index 6fc23e88cc..aa31a7f86a 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -4,7 +4,7 @@ from collections.abc import Mapping
 from enum import Enum
 from typing import Any, Optional, Union
 
-from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
+from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator, model_validator
 
 from core.datasource.entities.constants import DATASOURCE_SELECTOR_MODEL_IDENTITY
 from core.entities.provider_entities import ProviderConfig
diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py
index 85d4d130ba..260d4f12db 100644
--- a/api/core/plugin/entities/plugin.py
+++
b/api/core/plugin/entities/plugin.py @@ -192,10 +192,12 @@ class ToolProviderID(GenericProviderID): if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]: self.plugin_name = f"{self.provider_name}_tool" + class DatasourceProviderID(GenericProviderID): def __init__(self, value: str, is_hardcoded: bool = False) -> None: super().__init__(value, is_hardcoded) + class PluginDependency(BaseModel): class Type(enum.StrEnum): Github = PluginInstallationSource.Github.value diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 5f9ac78097..b8901e5cce 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,7 +1,13 @@ +import datetime import logging import time from typing import Any, cast +from flask_login import current_user + +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables.segments import ObjectSegment from core.workflow.entities.node_entities import NodeRunResult @@ -11,7 +17,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, Document, RateLimitLog from models.workflow import WorkflowNodeExecutionStatus -from services.dataset_service import DocumentService +from services.dataset_service import DatasetCollectionBindingService from services.feature_service import FeatureService from .entities import KnowledgeIndexNodeData @@ -109,14 +115,52 @@ class KnowledgeIndexNode(LLMNode): if not document: raise KnowledgeIndexNodeError(f"Document {node_data.document_id} not found.") - DocumentService.invoke_knowledge_index( - dataset=dataset, - document=document, - chunks=chunks, - chunk_structure=node_data.chunk_structure, - index_method=node_data.index_method, - retrieval_setting=node_data.retrieval_setting, - ) + retrieval_setting = node_data.retrieval_setting + index_method = node_data.index_method + if not dataset.indexing_technique: + if node_data.index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + raise ValueError("Indexing technique is invalid") + + dataset.indexing_technique = index_method.indexing_technique + if index_method.indexing_technique == "high_quality": + model_manager = ModelManager() + if ( + index_method.embedding_setting.embedding_model + and index_method.embedding_setting.embedding_model_provider + ): + dataset_embedding_model = index_method.embedding_setting.embedding_model + dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider + else: + embedding_model = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + ) + dataset_embedding_model = embedding_model.model + dataset_embedding_model_provider = embedding_model.provider + dataset.embedding_model = dataset_embedding_model + dataset.embedding_model_provider = dataset_embedding_model_provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + dataset_embedding_model_provider, dataset_embedding_model + ) + dataset.collection_binding_id = dataset_collection_binding.id + if not dataset.retrieval_model: + default_retrieval_model = { + "search_method": 
RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, + } + + dataset.retrieval_model = ( + retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model + ) # type: ignore + index_processor = IndexProcessorFactory(node_data.chunk_structure).init_index_processor() + index_processor.index(dataset, document, chunks) + + # update document status + document.indexing_status = "completed" + document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.commit() return { "dataset_id": dataset.id, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 02954cdb44..df31e0f7ca 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import random import time import uuid from collections import Counter -from typing import Any, Literal, Optional +from typing import Any, Optional from flask_login import current_user from sqlalchemy import func, select @@ -20,9 +20,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.plugin import ModelProviderID from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexType -from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db @@ -1516,60 +1514,6 @@ class DocumentService: return documents, batch - @staticmethod - def invoke_knowledge_index( - dataset: Dataset, - document: Document, - chunks: list[Any], - index_method: IndexMethod, - retrieval_setting: RetrievalSetting, - chunk_structure: Literal["text_model", "hierarchical_model"], - ): - if not dataset.indexing_technique: - if index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: - raise ValueError("Indexing technique is invalid") - - dataset.indexing_technique = index_method.indexing_technique - if index_method.indexing_technique == "high_quality": - model_manager = ModelManager() - if ( - index_method.embedding_setting.embedding_model - and index_method.embedding_setting.embedding_model_provider - ): - dataset_embedding_model = index_method.embedding_setting.embedding_model - dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider - else: - embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - dataset_embedding_model = embedding_model.model - dataset_embedding_model_provider = embedding_model.provider - dataset.embedding_model = dataset_embedding_model - dataset.embedding_model_provider = dataset_embedding_model_provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - dataset_embedding_model_provider, dataset_embedding_model - ) - dataset.collection_binding_id = dataset_collection_binding.id - if not dataset.retrieval_model: - default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - 
"top_k": 2, - "score_threshold_enabled": False, - } - - dataset.retrieval_model = ( - retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model - ) # type: ignore - index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() - index_processor.index(dataset, document, chunks) - - # update document status - document.indexing_status = "completed" - document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.commit() - @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size From 4ff971c8a36f30321b4fa49003c205feaecd660a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 15 May 2025 17:19:14 +0800 Subject: [PATCH 017/155] r2 --- api/controllers/console/__init__.py | 7 +++++- .../processor/paragraph_index_processor.py | 3 +-- .../processor/parent_child_index_processor.py | 3 +-- api/core/rag/models/document.py | 25 +++++++++++++++++++ .../nodes/knowledge_index/__init__.py | 4 +-- .../knowledge_index/knowledge_index_node.py | 4 +-- .../nodes/knowledge_retrieval/entities.py | 3 +-- api/fields/workflow_fields.py | 4 --- .../database/database_retrieval.py | 12 ++++----- api/services/rag_pipeline/rag_pipeline.py | 14 ++++------- .../rag_pipeline/rag_pipeline_dsl_service.py | 24 ++++++++++++++++-- 11 files changed, 71 insertions(+), 32 deletions(-) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 998ec2e3bf..c55d3fbb66 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -3,7 +3,6 @@ from flask import Blueprint from libs.external_api import ExternalApi from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi -from .datasets.rag_pipeline import data_source from .explore.audio import ChatAudioApi, ChatTextApi from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi from .explore.conversation import ( @@ -84,6 +83,12 @@ from .datasets import ( metadata, website, ) +from .datasets.rag_pipeline import ( + rag_pipeline, + rag_pipeline_datasets, + rag_pipeline_import, + rag_pipeline_workflow, +) # Import explore controllers from .explore import ( diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 43d201af73..155aae61d4 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -12,9 +12,8 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document +from core.rag.models.document import Document, GeneralStructureChunk from core.tools.utils.text_processing_utils import remove_leading_symbols -from core.workflow.nodes.knowledge_index.entities import GeneralStructureChunk from libs import helper from models.dataset import Dataset, DatasetProcessRule from services.entities.knowledge_entities.knowledge_entities import Rule diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 
ce64bb2a54..5279864441 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -13,8 +13,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import ChildDocument, Document -from core.workflow.nodes.knowledge_index.entities import ParentChildStructureChunk +from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk from extensions.ext_database import db from libs import helper from models.dataset import ChildChunk, Dataset, DocumentSegment diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 421cdc05df..52795bbadf 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -35,6 +35,31 @@ class Document(BaseModel): children: Optional[list[ChildDocument]] = None +class GeneralStructureChunk(BaseModel): + """ + General Structure Chunk. + """ + + general_chunk: list[str] + + +class ParentChildChunk(BaseModel): + """ + Parent Child Chunk. + """ + + parent_content: str + child_contents: list[str] + + +class ParentChildStructureChunk(BaseModel): + """ + Parent Child Structure Chunk. + """ + + parent_child_chunks: list[ParentChildChunk] + + class BaseDocumentTransformer(ABC): """Abstract base class for document transformation systems. diff --git a/api/core/workflow/nodes/knowledge_index/__init__.py b/api/core/workflow/nodes/knowledge_index/__init__.py index 01d59b87b2..23897a1e42 100644 --- a/api/core/workflow/nodes/knowledge_index/__init__.py +++ b/api/core/workflow/nodes/knowledge_index/__init__.py @@ -1,3 +1,3 @@ -from .knowledge_index_node import KnowledgeRetrievalNode +from .knowledge_index_node import KnowledgeIndexNode -__all__ = ["KnowledgeRetrievalNode"] +__all__ = ["KnowledgeIndexNode"] diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index b8901e5cce..f039b233a5 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,7 +1,7 @@ import datetime import logging import time -from typing import Any, cast +from typing import Any, cast, Mapping from flask_login import current_user @@ -106,7 +106,7 @@ class KnowledgeIndexNode(LLMNode): error_type=type(e).__name__, ) - def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: list[Any]) -> Any: + def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any]) -> Any: dataset = Dataset.query.filter_by(id=node_data.dataset_id).first() if not dataset: raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.") diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 17b3308a06..8c702b74ee 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -57,8 +57,7 @@ class MultipleRetrievalConfig(BaseModel): class ModelConfig(BaseModel): - """ - Model Config. 
+ provider: str name: str diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 45112d42f9..a37ae7856d 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -62,10 +62,6 @@ workflow_fields = { "tool_published": fields.Boolean, "environment_variables": fields.List(EnvironmentVariableField()), "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)), - "pipeline_variables": fields.Dict( - keys=fields.String, - values=fields.List(fields.Nested(pipeline_variable_fields)), - ), } workflow_partial_fields = { diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 10dd044493..f6ab5c9064 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -2,12 +2,12 @@ from typing import Optional from extensions.ext_database import db from models.dataset import Pipeline, PipelineBuiltInTemplate -from services.app_dsl_service import AppDslService -from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase -from services.recommend_app.recommend_app_type import RecommendAppType +from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase +from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +#from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService -class DatabasePipelineTemplateRetrieval(RecommendAppRetrievalBase): +class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval pipeline template from database """ @@ -21,7 +21,7 @@ class DatabasePipelineTemplateRetrieval(RecommendAppRetrievalBase): return result def get_type(self) -> str: - return RecommendAppType.DATABASE + return PipelineTemplateType.DATABASE @classmethod def fetch_pipeline_templates_from_db(cls, language: str) -> dict: @@ -61,5 +61,5 @@ class DatabasePipelineTemplateRetrieval(RecommendAppRetrievalBase): "name": pipeline.name, "icon": pipeline.icon, "mode": pipeline.mode, - "export_data": AppDslService.export_dsl(app_model=pipeline), + "export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline), } diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 1e6447d80f..2275c32f63 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -13,8 +13,7 @@ from sqlalchemy.orm import Session import contexts from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder -from core.repository.repository_factory import RepositoryFactory -from core.repository.workflow_node_execution_repository import OrderConfig +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.errors import WorkflowNodeRunFailedError @@ -24,6 +23,7 @@ from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from core.workflow.repository.workflow_node_execution_repository import 
OrderConfig from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -650,13 +650,9 @@ class RagPipelineService: if not workflow_run: return [] - # Use the repository to get the node executions - repository = RepositoryFactory.create_workflow_node_execution_repository( - params={ - "tenant_id": pipeline.tenant_id, - "app_id": pipeline.id, - "session_factory": db.session.get_bind(), - } + # Use the repository to get the node execution + repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=db.engine, tenant_id=pipeline.tenant_id, app_id=pipeline.id ) # Use the repository to get the node executions with ordering diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 80e7c6af0b..e50caa9756 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -25,12 +25,12 @@ from core.workflow.nodes.llm.entities import LLMNodeData from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData from core.workflow.nodes.tool.entities import ToolNodeData +from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import variable_factory from models import Account -from models.dataset import Dataset, Pipeline +from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.workflow import Workflow -from services.dataset_service import DatasetCollectionBindingService from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -306,6 +306,26 @@ class RagPipelineDslService: knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore ) + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.provider_name == knowledge_configuration.index_method.embedding_setting.embedding_provider_name, + DatasetCollectionBinding.model_name == knowledge_configuration.index_method.embedding_setting.embedding_model_name, + DatasetCollectionBinding.type == "dataset", + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) + + if not dataset_collection_binding: + dataset_collection_binding = DatasetCollectionBinding( + provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name, + model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name, + collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), + type="dataset", + ) + db.session.add(dataset_collection_binding) + db.session.commit() dataset_collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = ( From 7b0d38f7d3b30813df08ff0a9171f7d38affd48a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 16 May 2025 12:02:35 +0800 Subject: [PATCH 018/155] r2 --- .../console/datasets/rag_pipeline/rag_pipeline.py | 15 +++++++++++++-- api/services/rag_pipeline/rag_pipeline.py | 15 +++++++++++++++ 2 files changed, 
28 insertions(+), 2 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 864f3644d5..4f1dfb6391 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restful import Resource, reqparse # type: ignore # type: ignore +from flask_restful import Resource, reqparse +from sqlalchemy.orm import Session from controllers.console import api from controllers.console.wraps import ( @@ -9,6 +10,7 @@ from controllers.console.wraps import ( enterprise_license_required, setup_required, ) +from extensions.ext_database import db from libs.login import login_required from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -91,6 +93,15 @@ class CustomizedPipelineTemplateApi(Resource): RagPipelineService.delete_customized_pipeline_template(template_id) return 200 + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def post(self, template_id: str): + with Session(db.engine) as session: + dsl = RagPipelineService.export_template_rag_pipeline_dsl(template_id) + return {"data": dsl}, 200 + api.add_resource( PipelineTemplateListApi, @@ -102,5 +113,5 @@ api.add_resource( ) api.add_resource( CustomizedPipelineTemplateApi, - "/rag/pipeline/templates/", + "/rag/pipeline/customized/templates/", ) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 2275c32f63..166673130f 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -41,6 +41,7 @@ from models.workflow import ( from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory +from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService class RagPipelineService: @@ -115,6 +116,20 @@ class RagPipelineService: db.delete(customized_template) db.commit() + @classmethod + def export_template_rag_pipeline_dsl(cls, template_id: str) -> str: + """ + Export template rag pipeline dsl + """ + template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + if not template: + raise ValueError("Customized pipeline template not found.") + pipeline = db.session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() + if not pipeline: + raise ValueError("Pipeline not found.") + + return RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True) + def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: """ Get draft workflow From 613b94a6e6c0783d8844c97f17cad11a4ba91188 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 16 May 2025 13:45:47 +0800 Subject: [PATCH 019/155] r2 --- .../datasets/rag_pipeline/rag_pipeline.py | 11 ++++++++++- api/services/rag_pipeline/rag_pipeline.py | 16 ---------------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 4f1dfb6391..e674a89480 100644 --- 
a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -12,8 +12,10 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from libs.login import login_required +from models.dataset import Pipeline, PipelineCustomizedTemplate from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.rag_pipeline.rag_pipeline import RagPipelineService +from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService logger = logging.getLogger(__name__) @@ -99,7 +101,14 @@ class CustomizedPipelineTemplateApi(Resource): @enterprise_license_required def post(self, template_id: str): with Session(db.engine) as session: - dsl = RagPipelineService.export_template_rag_pipeline_dsl(template_id) + template = session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + if not template: + raise ValueError("Customized pipeline template not found.") + pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() + if not pipeline: + raise ValueError("Pipeline not found.") + + dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True) return {"data": dsl}, 200 diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 166673130f..bc2cfdeeb3 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -41,8 +41,6 @@ from models.workflow import ( from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory -from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService - class RagPipelineService: @staticmethod @@ -116,20 +114,6 @@ class RagPipelineService: db.delete(customized_template) db.commit() - @classmethod - def export_template_rag_pipeline_dsl(cls, template_id: str) -> str: - """ - Export template rag pipeline dsl - """ - template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() - if not template: - raise ValueError("Customized pipeline template not found.") - pipeline = db.session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() - if not pipeline: - raise ValueError("Pipeline not found.") - - return RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True) - def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: """ Get draft workflow From 9e72afee3c9fb3671a2b703506cd70a2e1c43de4 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 16 May 2025 14:00:35 +0800 Subject: [PATCH 020/155] r2 --- api/services/dataset_service.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index df31e0f7ca..ac45981ee5 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -244,18 +244,20 @@ class DatasetService: rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, ): # check if dataset name already exists - if Dataset.query.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first(): + if db.session.query(Dataset).filter_by(name=rag_pipeline_dataset_create_entity.name, 
tenant_id=tenant_id).first(): raise DatasetNameDuplicateError( f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." ) dataset = Dataset( + tenant_id=tenant_id, name=rag_pipeline_dataset_create_entity.name, description=rag_pipeline_dataset_create_entity.description, permission=rag_pipeline_dataset_create_entity.permission, provider="vendor", runtime_mode="rag_pipeline", icon_info=rag_pipeline_dataset_create_entity.icon_info, + created_by=current_user.id ) db.session.add(dataset) db.session.commit() @@ -267,7 +269,7 @@ class DatasetService: rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, ): # check if dataset name already exists - if Dataset.query.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first(): + if db.session.query(Dataset).filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first(): raise DatasetNameDuplicateError( f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." ) From 8bea88c8cc6b862fac02fb89b1320608ced06fc3 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 16 May 2025 17:22:17 +0800 Subject: [PATCH 021/155] r2 --- .../feature/hosted_service/__init__.py | 2 +- .../datasets/rag_pipeline/rag_pipeline.py | 4 +-- .../rag_pipeline/rag_pipeline_workflow.py | 18 +++++----- api/fields/rag_pipeline_fields.py | 1 + api/models/dataset.py | 7 +++- api/models/workflow.py | 6 ++-- api/services/dataset_service.py | 24 +++++++++++--- .../database/database_retrieval.py | 33 +++++++++++++++---- .../pipeline_template_factory.py | 2 +- api/services/rag_pipeline/rag_pipeline.py | 14 +++++--- .../rag_pipeline/rag_pipeline_dsl_service.py | 10 ++---- 11 files changed, 80 insertions(+), 41 deletions(-) diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 7633ffcf8a..3e57f24ff5 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -229,7 +229,7 @@ class HostedFetchPipelineTemplateConfig(BaseSettings): HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field( description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,", - default="remote", + default="database", ) HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field( diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index e674a89480..44296d5a31 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -38,7 +38,7 @@ class PipelineTemplateListApi(Resource): @account_initialization_required @enterprise_license_required def get(self): - type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"]) + type = request.args.get("type", default="built-in", type=str) language = request.args.get("language", default="en-US", type=str) # get pipeline templates pipeline_templates = RagPipelineService.get_pipeline_templates(type, language) @@ -107,7 +107,7 @@ class CustomizedPipelineTemplateApi(Resource): pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() if not pipeline: raise ValueError("Pipeline not found.") - + dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True) return {"data": dsl}, 200 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py 
b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index b348f7a796..c76014d0a3 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -90,11 +90,10 @@ class DraftRagPipelineApi(Resource): if "application/json" in content_type: parser = reqparse.RequestParser() parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") - parser.add_argument("features", type=dict, required=True, nullable=False, location="json") parser.add_argument("hash", type=str, required=False, location="json") parser.add_argument("environment_variables", type=list, required=False, location="json") parser.add_argument("conversation_variables", type=list, required=False, location="json") - parser.add_argument("pipeline_variables", type=dict, required=False, location="json") + parser.add_argument("rag_pipeline_variables", type=dict, required=False, location="json") args = parser.parse_args() elif "text/plain" in content_type: try: @@ -111,7 +110,7 @@ class DraftRagPipelineApi(Resource): "hash": data.get("hash"), "environment_variables": data.get("environment_variables"), "conversation_variables": data.get("conversation_variables"), - "pipeline_variables": data.get("pipeline_variables"), + "rag_pipeline_variables": data.get("rag_pipeline_variables"), } except json.JSONDecodeError: return {"message": "Invalid JSON data"}, 400 @@ -130,21 +129,20 @@ class DraftRagPipelineApi(Resource): conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] - pipeline_variables_list = args.get("pipeline_variables") or {} - pipeline_variables = { + rag_pipeline_variables_list = args.get("rag_pipeline_variables") or {} + rag_pipeline_variables = { k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v] - for k, v in pipeline_variables_list.items() + for k, v in rag_pipeline_variables_list.items() } rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, graph=args["graph"], - features=args["features"], unique_hash=args.get("hash"), account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, - pipeline_variables=pipeline_variables, + rag_pipeline_variables=rag_pipeline_variables, ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() @@ -476,7 +474,7 @@ class RagPipelineConfigApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): + def get(self, pipeline_id): return { "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, } @@ -792,5 +790,5 @@ api.add_resource( ) api.add_resource( DatasourceListApi, - "/rag/pipelines/datasources", + "/rag/pipelines/datasource-plugins", ) diff --git a/api/fields/rag_pipeline_fields.py b/api/fields/rag_pipeline_fields.py index 0bb74e3259..cedc13ed0d 100644 --- a/api/fields/rag_pipeline_fields.py +++ b/api/fields/rag_pipeline_fields.py @@ -153,6 +153,7 @@ pipeline_import_fields = { "id": fields.String, "status": fields.String, "pipeline_id": fields.String, + "dataset_id": fields.String, "current_dsl_version": fields.String, "imported_dsl_version": fields.String, "error": fields.String, diff --git a/api/models/dataset.py b/api/models/dataset.py index e60f110aef..0ed59c898f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1166,6 +1166,9 @@ class 
PipelineBuiltInTemplate(Base): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + @property + def pipeline(self): + return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first() class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] __tablename__ = "pipeline_customized_templates" @@ -1195,7 +1198,6 @@ class Pipeline(Base): # type: ignore[name-defined] tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) - mode = db.Column(db.String(255), nullable=False) workflow_id = db.Column(StringUUID, nullable=True) is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @@ -1203,3 +1205,6 @@ class Pipeline(Base): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + @property + def dataset(self): + return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first() diff --git a/api/models/workflow.py b/api/models/workflow.py index b6b56ad520..5cb413b6a6 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -352,7 +352,7 @@ class Workflow(Base): ) @property - def pipeline_variables(self) -> dict[str, Sequence[Variable]]: + def rag_pipeline_variables(self) -> dict[str, Sequence[Variable]]: # TODO: find some way to init `self._conversation_variables` when instance created. if self._rag_pipeline_variables is None: self._rag_pipeline_variables = "{}" @@ -363,8 +363,8 @@ class Workflow(Base): results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()] return results - @pipeline_variables.setter - def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None: + @rag_pipeline_variables.setter + def rag_pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None: self._rag_pipeline_variables = json.dumps( {k: {item.name: item.model_dump() for item in v} for k, v in values.items()}, ensure_ascii=False, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ac45981ee5..0f5069f052 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -40,6 +40,7 @@ from models.dataset import ( Document, DocumentSegment, ExternalKnowledgeBindings, + Pipeline, ) from models.model import UploadFile from models.source import DataSourceOauthBinding @@ -248,6 +249,15 @@ class DatasetService: raise DatasetNameDuplicateError( f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." 
) + + pipeline = Pipeline( + tenant_id=tenant_id, + name=rag_pipeline_dataset_create_entity.name, + description=rag_pipeline_dataset_create_entity.description, + created_by=current_user.id + ) + db.session.add(pipeline) + db.session.flush() dataset = Dataset( tenant_id=tenant_id, @@ -257,7 +267,8 @@ class DatasetService: provider="vendor", runtime_mode="rag_pipeline", icon_info=rag_pipeline_dataset_create_entity.icon_info, - created_by=current_user.id + created_by=current_user.id, + pipeline_id=pipeline.id ) db.session.add(dataset) db.session.commit() @@ -282,10 +293,13 @@ class DatasetService: runtime_mode="rag_pipeline", icon_info=rag_pipeline_dataset_create_entity.icon_info, ) - - if rag_pipeline_dataset_create_entity.yaml_content: - rag_pipeline_import_info: RagPipelineImportInfo = RagPipelineDslService.import_rag_pipeline( - current_user, ImportMode.YAML_CONTENT, rag_pipeline_dataset_create_entity.yaml_content, dataset + with Session(db.engine) as session: + rag_pipeline_dsl_service = RagPipelineDslService(session) + rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline( + account=current_user, + import_mode=ImportMode.YAML_CONTENT.value, + yaml_content=rag_pipeline_dataset_create_entity.yaml_content, + dataset=dataset ) return { "id": rag_pipeline_import_info.id, diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index f6ab5c9064..bda29c804c 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,10 +1,9 @@ from typing import Optional from extensions.ext_database import db -from models.dataset import Pipeline, PipelineBuiltInTemplate +from models.dataset import Dataset, Pipeline, PipelineBuiltInTemplate from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType -#from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): @@ -30,11 +29,32 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - pipeline_templates = ( - db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all() - ) + + pipeline_built_in_templates: list[PipelineBuiltInTemplate] = db.session.query(PipelineBuiltInTemplate).filter( + PipelineBuiltInTemplate.language == language + ).all() + + recommended_pipelines_results = [] + for pipeline_built_in_template in pipeline_built_in_templates: + pipeline_model: Pipeline = pipeline_built_in_template.pipeline + + recommended_pipeline_result = { + 'id': pipeline_built_in_template.id, + 'name': pipeline_built_in_template.name, + 'pipeline_id': pipeline_model.id, + 'description': pipeline_built_in_template.description, + 'icon': pipeline_built_in_template.icon, + 'copyright': pipeline_built_in_template.copyright, + 'privacy_policy': pipeline_built_in_template.privacy_policy, + 'position': pipeline_built_in_template.position, + } + dataset: Dataset = pipeline_model.dataset + if dataset: + recommended_pipeline_result['chunk_structure'] = dataset.chunk_structure + recommended_pipelines_results.append(recommended_pipeline_result) + + return {'pipeline_templates': 
recommended_pipelines_results} - return {"pipeline_templates": pipeline_templates} @classmethod def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]: @@ -43,6 +63,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param pipeline_id: Pipeline ID :return: """ + from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService # is in public recommended list pipeline_template = ( db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first() diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py index 37e40bf6a0..aa8a6298d7 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py @@ -15,7 +15,7 @@ class PipelineTemplateRetrievalFactory: return DatabasePipelineTemplateRetrieval case PipelineTemplateType.DATABASE: return DatabasePipelineTemplateRetrieval - case PipelineTemplateType.BUILT_IN: + case PipelineTemplateType.BUILTIN: return BuiltInPipelineTemplateRetrieval case _: raise ValueError(f"invalid fetch recommended apps mode: {mode}") diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index bc2cfdeeb3..f380bc32d7 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -42,6 +42,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import PipelineT from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory + class RagPipelineService: @staticmethod def get_pipeline_templates( @@ -49,7 +50,7 @@ class RagPipelineService: ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE - retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode) + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() result = retrieval_instance.get_pipeline_templates(language) if not result.get("pipeline_templates") and language != "en-US": template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval() @@ -57,7 +58,7 @@ class RagPipelineService: return result.get("pipeline_templates") else: mode = "customized" - retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode) + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() result = retrieval_instance.get_pipeline_templates(language) return result.get("pipeline_templates") @@ -200,7 +201,7 @@ class RagPipelineService: account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], - pipeline_variables: dict[str, Sequence[Variable]], + rag_pipeline_variables: dict[str, Sequence[Variable]], ) -> Workflow: """ Sync draft workflow @@ -217,15 +218,18 @@ class RagPipelineService: workflow = Workflow( tenant_id=pipeline.tenant_id, app_id=pipeline.id, + features="{}", type=WorkflowType.RAG_PIPELINE.value, version="draft", graph=json.dumps(graph), created_by=account.id, environment_variables=environment_variables, conversation_variables=conversation_variables, - pipeline_variables=pipeline_variables, + 
rag_pipeline_variables=rag_pipeline_variables, ) db.session.add(workflow) + db.session.flush() + pipeline.workflow_id = workflow.id # update draft workflow if found else: workflow.graph = json.dumps(graph) @@ -233,7 +237,7 @@ class RagPipelineService: workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables - workflow.pipeline_variables = pipeline_variables + workflow.rag_pipeline_variables = rag_pipeline_variables # commit db session changes db.session.commit() diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index e50caa9756..3664c988e5 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -516,17 +516,14 @@ class RagPipelineDslService: dependencies: Optional[list[PluginDependency]] = None, ) -> Pipeline: """Create a new app or update an existing one.""" - pipeline_data = data.get("pipeline", {}) - pipeline_mode = pipeline_data.get("mode") - if not pipeline_mode: - raise ValueError("loss pipeline mode") + pipeline_data = data.get("rag_pipeline", {}) # Set icon type - icon_type_value = icon_type or pipeline_data.get("icon_type") + icon_type_value = pipeline_data.get("icon_type") if icon_type_value in ["emoji", "link"]: icon_type = icon_type_value else: icon_type = "emoji" - icon = icon or str(pipeline_data.get("icon", "")) + icon = str(pipeline_data.get("icon", "")) if pipeline: # Update existing pipeline @@ -544,7 +541,6 @@ class RagPipelineDslService: pipeline = Pipeline() pipeline.id = str(uuid4()) pipeline.tenant_id = account.current_tenant_id - pipeline.mode = pipeline_mode.value pipeline.name = pipeline_data.get("name", "") pipeline.description = pipeline_data.get("description", "") pipeline.icon_type = icon_type From c5a2f43ceb0d869c838e3e468a524e3e5b7a9d4b Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 16 May 2025 18:42:07 +0800 Subject: [PATCH 022/155] refactor: replace BuiltinToolManageService with RagPipelineManageService for datasource management and remove unused datasource engine and related code --- .../datasets/rag_pipeline/rag_pipeline.py | 4 +- .../rag_pipeline/rag_pipeline_workflow.py | 11 +- .../datasource/__base/datasource_plugin.py | 49 +- .../datasource/__base/datasource_provider.py | 11 +- api/core/datasource/datasource_engine.py | 224 ------- api/core/datasource/datasource_manager.py | 15 +- api/core/datasource/entities/api_entities.py | 3 +- api/core/datasource/entities/constants.py | 1 - .../entities/datasource_entities.py | 250 +------- api/core/datasource/entities/values.py | 111 ---- api/core/plugin/impl/datasource.py | 97 +--- api/core/plugin/impl/tool.py | 64 +- api/core/plugin/utils/converter.py | 7 - api/core/tools/tool_manager.py | 26 - .../nodes/datasource/datasource_node.py | 263 +-------- .../workflow/nodes/datasource/entities.py | 6 +- .../knowledge_index/knowledge_index_node.py | 3 +- api/services/dataset_service.py | 547 +++++++++--------- .../rag_pipeline/rag_pipeline_dsl_service.py | 6 +- .../rag_pipeline_manage_service.py | 14 + .../tools/builtin_tools_manage_service.py | 64 +- api/services/tools/tools_transform_service.py | 110 +--- 22 files changed, 390 insertions(+), 1496 deletions(-) delete mode 100644 api/core/datasource/datasource_engine.py delete mode 100644 api/core/datasource/entities/constants.py delete mode 100644 api/core/datasource/entities/values.py create mode 100644 
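Note on the rag_pipeline_variables storage convention settled in the patch above: the workflow keeps per-node pipeline variables as JSON text in the renamed rag_pipeline_variables column and rebuilds variable objects through the property/setter pair. The standalone sketch below illustrates that round-trip under simplified assumptions; a bare Variable dataclass and a plain attribute stand in for the real SQLAlchemy mapped column and variable_factory, so these names are illustrative rather than the project's actual API.

import json
from dataclasses import asdict, dataclass
from typing import Sequence


@dataclass
class Variable:
    # Simplified stand-in for the runtime Variable entity (illustrative only).
    name: str
    value: str

    def model_dump(self) -> dict:
        return asdict(self)


class WorkflowSketch:
    """Mimics the rag_pipeline_variables property/setter pattern on a plain attribute."""

    def __init__(self) -> None:
        # In the real model this is the rag_pipeline_variables text column, server default "{}".
        self._rag_pipeline_variables: str = "{}"

    @property
    def rag_pipeline_variables(self) -> dict[str, list[Variable]]:
        data = json.loads(self._rag_pipeline_variables or "{}")
        # Stored shape: {node_id: {variable_name: mapping}}; rebuild variable objects per node.
        return {k: [Variable(**item) for item in v.values()] for k, v in data.items()}

    @rag_pipeline_variables.setter
    def rag_pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
        self._rag_pipeline_variables = json.dumps(
            {k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
            ensure_ascii=False,
        )


if __name__ == "__main__":
    wf = WorkflowSketch()
    wf.rag_pipeline_variables = {"datasource_node": [Variable(name="url", value="https://example.com")]}
    print(wf._rag_pipeline_variables)   # serialized JSON as it would land in the column
    print(wf.rag_pipeline_variables)    # deserialized back into variable objects

Keying the serialized mapping by variable name, as the setter does, keeps repeated syncs from the draft editor idempotent for a given node rather than appending duplicates.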
api/services/rag_pipeline/rag_pipeline_manage_service.py diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 44296d5a31..cc07084dea 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -101,7 +101,9 @@ class CustomizedPipelineTemplateApi(Resource): @enterprise_license_required def post(self, template_id: str): with Session(db.engine) as session: - template = session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + template = ( + session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + ) if not template: raise ValueError("Customized pipeline template not found.") pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index c76014d0a3..c67b897f81 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -43,7 +43,7 @@ from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.rag_pipeline import RagPipelineService -from services.tools.builtin_tools_manage_service import BuiltinToolManageService +from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError logger = logging.getLogger(__name__) @@ -711,14 +711,7 @@ class DatasourceListApi(Resource): tenant_id = user.current_tenant_id - return jsonable_encoder( - [ - provider.to_dict() - for provider in BuiltinToolManageService.list_rag_pipeline_datasources( - tenant_id, - ) - ] - ) + return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id)) api.add_resource( diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index 8fb89e1172..15d9e7d9ba 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -1,12 +1,9 @@ -from collections.abc import Generator -from typing import Any, Optional +from collections.abc import Mapping +from typing import Any from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, - DatasourceInvokeMessage, - DatasourceParameter, - DatasourceProviderType, ) from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -16,7 +13,6 @@ class DatasourcePlugin: tenant_id: str icon: str plugin_unique_identifier: str - runtime_parameters: Optional[list[DatasourceParameter]] entity: DatasourceEntity runtime: DatasourceRuntime @@ -33,49 +29,41 @@ class DatasourcePlugin: self.tenant_id = tenant_id self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - self.runtime_parameters = None - - def datasource_provider_type(self) -> DatasourceProviderType: - return DatasourceProviderType.RAG_PIPELINE def _invoke_first_step( self, user_id: str, datasource_parameters: dict[str, Any], - 
rag_pipeline_id: Optional[str] = None, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Mapping[str, Any]: manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) - yield from manager.invoke_first_step( + return manager.invoke_first_step( tenant_id=self.tenant_id, user_id=user_id, datasource_provider=self.entity.identity.provider, datasource_name=self.entity.identity.name, credentials=self.runtime.credentials, datasource_parameters=datasource_parameters, - rag_pipeline_id=rag_pipeline_id, ) def _invoke_second_step( self, user_id: str, datasource_parameters: dict[str, Any], - rag_pipeline_id: Optional[str] = None, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Mapping[str, Any]: manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) - yield from manager.invoke( + return manager.invoke_second_step( tenant_id=self.tenant_id, user_id=user_id, datasource_provider=self.entity.identity.provider, datasource_name=self.entity.identity.name, credentials=self.runtime.credentials, datasource_parameters=datasource_parameters, - rag_pipeline_id=rag_pipeline_id, ) def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": @@ -86,28 +74,3 @@ class DatasourcePlugin: icon=self.icon, plugin_unique_identifier=self.plugin_unique_identifier, ) - - def get_runtime_parameters( - self, - rag_pipeline_id: Optional[str] = None, - ) -> list[DatasourceParameter]: - """ - get the runtime parameters - """ - if not self.entity.has_runtime_parameters: - return self.entity.parameters - - if self.runtime_parameters is not None: - return self.runtime_parameters - - manager = PluginDatasourceManager() - self.runtime_parameters = manager.get_runtime_parameters( - tenant_id=self.tenant_id, - user_id="", - provider=self.entity.identity.provider, - datasource=self.entity.identity.name, - credentials=self.runtime.credentials, - rag_pipeline_id=rag_pipeline_id, - ) - - return self.runtime_parameters diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index ef3382b948..13804f53d9 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -2,7 +2,7 @@ from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime -from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin from core.entities.provider_entities import ProviderConfig from core.plugin.impl.tool import PluginToolManager from core.tools.errors import ToolProviderCredentialValidationError @@ -22,15 +22,6 @@ class DatasourcePluginProviderController: self.plugin_id = plugin_id self.plugin_unique_identifier = plugin_unique_identifier - @property - def provider_type(self) -> DatasourceProviderType: - """ - returns the type of the provider - - :return: type of the provider - """ - return DatasourceProviderType.RAG_PIPELINE - @property def need_credentials(self) -> bool: """ diff --git a/api/core/datasource/datasource_engine.py b/api/core/datasource/datasource_engine.py deleted file mode 100644 index c193c4c629..0000000000 --- a/api/core/datasource/datasource_engine.py +++ /dev/null @@ -1,224 +0,0 @@ -import json -from 
collections.abc import Generator, Iterable -from mimetypes import guess_type -from typing import Any, Optional, cast - -from yarl import URL - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.datasource.__base.datasource_plugin import DatasourcePlugin -from core.datasource.entities.datasource_entities import ( - DatasourceInvokeMessage, - DatasourceInvokeMessageBinary, -) -from core.file import FileType -from core.file.models import FileTransferMethod -from extensions.ext_database import db -from models.enums import CreatedByRole -from models.model import Message, MessageFile - - -class DatasourceEngine: - """ - Datasource runtime engine take care of the datasource executions. - """ - - @staticmethod - def invoke_first_step( - datasource: DatasourcePlugin, - datasource_parameters: dict[str, Any], - user_id: str, - workflow_tool_callback: DifyWorkflowCallbackHandler, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> Generator[DatasourceInvokeMessage, None, None]: - """ - Workflow invokes the datasource with the given arguments. - """ - try: - # hit the callback handler - workflow_tool_callback.on_datasource_start( - datasource_name=datasource.entity.identity.name, datasource_inputs=datasource_parameters - ) - - if datasource.runtime and datasource.runtime.runtime_parameters: - datasource_parameters = {**datasource.runtime.runtime_parameters, **datasource_parameters} - - response = datasource._invoke_first_step( - user_id=user_id, - datasource_parameters=datasource_parameters, - conversation_id=conversation_id, - app_id=app_id, - message_id=message_id, - ) - - # hit the callback handler - response = workflow_tool_callback.on_datasource_end( - datasource_name=datasource.entity.identity.name, - datasource_inputs=datasource_parameters, - datasource_outputs=response, - ) - - return response - except Exception as e: - workflow_tool_callback.on_tool_error(e) - raise e - - @staticmethod - def invoke_second_step( - datasource: DatasourcePlugin, - datasource_parameters: dict[str, Any], - user_id: str, - workflow_tool_callback: DifyWorkflowCallbackHandler, - ) -> Generator[DatasourceInvokeMessage, None, None]: - """ - Workflow invokes the datasource with the given arguments. - """ - try: - response = datasource._invoke_second_step( - user_id=user_id, - datasource_parameters=datasource_parameters, - ) - - return response - except Exception as e: - workflow_tool_callback.on_tool_error(e) - raise e - - @staticmethod - def _convert_datasource_response_to_str(datasource_response: list[DatasourceInvokeMessage]) -> str: - """ - Handle datasource response - """ - result = "" - for response in datasource_response: - if response.type == DatasourceInvokeMessage.MessageType.TEXT: - result += cast(DatasourceInvokeMessage.TextMessage, response.message).text - elif response.type == DatasourceInvokeMessage.MessageType.LINK: - result += ( - f"result link: {cast(DatasourceInvokeMessage.TextMessage, response.message).text}." - + " please tell user to check it." - ) - elif response.type in { - DatasourceInvokeMessage.MessageType.IMAGE_LINK, - DatasourceInvokeMessage.MessageType.IMAGE, - }: - result += ( - "image has been created and sent to user already, " - + "you do not need to create it, just tell the user to check it now." 
- ) - elif response.type == DatasourceInvokeMessage.MessageType.JSON: - result = json.dumps( - cast(DatasourceInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False - ) - else: - result += str(response.message) - - return result - - @staticmethod - def _extract_datasource_response_binary_and_text( - datasource_response: list[DatasourceInvokeMessage], - ) -> Generator[DatasourceInvokeMessageBinary, None, None]: - """ - Extract datasource response binary - """ - for response in datasource_response: - if response.type in { - DatasourceInvokeMessage.MessageType.IMAGE_LINK, - DatasourceInvokeMessage.MessageType.IMAGE, - }: - mimetype = None - if not response.meta: - raise ValueError("missing meta data") - if response.meta.get("mime_type"): - mimetype = response.meta.get("mime_type") - else: - try: - url = URL(cast(DatasourceInvokeMessage.TextMessage, response.message).text) - extension = url.suffix - guess_type_result, _ = guess_type(f"a{extension}") - if guess_type_result: - mimetype = guess_type_result - except Exception: - pass - - if not mimetype: - mimetype = "image/jpeg" - - yield DatasourceInvokeMessageBinary( - mimetype=response.meta.get("mime_type", "image/jpeg"), - url=cast(DatasourceInvokeMessage.TextMessage, response.message).text, - ) - elif response.type == DatasourceInvokeMessage.MessageType.BLOB: - if not response.meta: - raise ValueError("missing meta data") - - yield DatasourceInvokeMessageBinary( - mimetype=response.meta.get("mime_type", "application/octet-stream"), - url=cast(DatasourceInvokeMessage.TextMessage, response.message).text, - ) - elif response.type == DatasourceInvokeMessage.MessageType.LINK: - # check if there is a mime type in meta - if response.meta and "mime_type" in response.meta: - yield DatasourceInvokeMessageBinary( - mimetype=response.meta.get("mime_type", "application/octet-stream") - if response.meta - else "application/octet-stream", - url=cast(DatasourceInvokeMessage.TextMessage, response.message).text, - ) - - @staticmethod - def _create_message_files( - datasource_messages: Iterable[DatasourceInvokeMessageBinary], - agent_message: Message, - invoke_from: InvokeFrom, - user_id: str, - ) -> list[str]: - """ - Create message file - - :return: message file ids - """ - result = [] - - for message in datasource_messages: - if "image" in message.mimetype: - file_type = FileType.IMAGE - elif "video" in message.mimetype: - file_type = FileType.VIDEO - elif "audio" in message.mimetype: - file_type = FileType.AUDIO - elif "text" in message.mimetype or "pdf" in message.mimetype: - file_type = FileType.DOCUMENT - else: - file_type = FileType.CUSTOM - - # extract tool file id from url - tool_file_id = message.url.split("/")[-1].split(".")[0] - message_file = MessageFile( - message_id=agent_message.id, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - belongs_to="assistant", - url=message.url, - upload_file_id=tool_file_id, - created_by_role=( - CreatedByRole.ACCOUNT - if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatedByRole.END_USER - ), - created_by=user_id, - ) - - db.session.add(message_file) - db.session.commit() - db.session.refresh(message_file) - - result.append(message_file.id) - - db.session.close() - - return result diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index fa141a679a..c865b557f9 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -6,9 +6,8 @@ import contexts from 
core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.entities.common_entities import I18nObject -from core.datasource.entities.datasource_entities import DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError -from core.plugin.impl.tool import PluginToolManager +from core.plugin.impl.datasource import PluginDatasourceManager logger = logging.getLogger(__name__) @@ -36,7 +35,7 @@ class DatasourceManager: if provider in datasource_plugin_providers: return datasource_plugin_providers[provider] - manager = PluginToolManager() + manager = PluginDatasourceManager() provider_entity = manager.fetch_datasource_provider(tenant_id, provider) if not provider_entity: raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found") @@ -55,7 +54,6 @@ class DatasourceManager: @classmethod def get_datasource_runtime( cls, - provider_type: DatasourceProviderType, provider_id: str, datasource_name: str, tenant_id: str, @@ -70,18 +68,15 @@ class DatasourceManager: :return: the datasource plugin """ - if provider_type == DatasourceProviderType.RAG_PIPELINE: - return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name) - else: - raise DatasourceProviderNotFoundError(f"provider type {provider_type.value} not found") + return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name) @classmethod def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]: """ list all the datasource providers """ - manager = PluginToolManager() - provider_entities = manager.fetch_datasources(tenant_id) + manager = PluginDatasourceManager() + provider_entities = manager.fetch_datasource_providers(tenant_id) return [ DatasourcePluginProviderController( entity=provider.declaration, diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 2d42484a30..8d6bed41fa 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -4,7 +4,6 @@ from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType @@ -14,7 +13,7 @@ class DatasourceApiEntity(BaseModel): name: str # identifier label: I18nObject # label description: I18nObject - parameters: Optional[list[ToolParameter]] = None + parameters: Optional[list[DatasourceParameter]] = None labels: list[str] = Field(default_factory=list) output_schema: Optional[dict] = None diff --git a/api/core/datasource/entities/constants.py b/api/core/datasource/entities/constants.py deleted file mode 100644 index a4dbf6f11f..0000000000 --- a/api/core/datasource/entities/constants.py +++ /dev/null @@ -1 +0,0 @@ -DATASOURCE_SELECTOR_MODEL_IDENTITY = "__dify__datasource_selector__" diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index aa31a7f86a..e1bcbc323b 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -1,13 +1,9 @@ -import base64 import enum -from collections.abc import Mapping from enum 
import Enum -from typing import Any, Optional, Union +from typing import Any, Optional -from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator, model_validator +from pydantic import BaseModel, Field, ValidationInfo, field_validator -from core.datasource.entities.constants import DATASOURCE_SELECTOR_MODEL_IDENTITY -from core.entities.provider_entities import ProviderConfig from core.plugin.entities.parameters import ( PluginParameter, PluginParameterOption, @@ -17,25 +13,7 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject - - -class ToolLabelEnum(Enum): - SEARCH = "search" - IMAGE = "image" - VIDEOS = "videos" - WEATHER = "weather" - FINANCE = "finance" - DESIGN = "design" - TRAVEL = "travel" - SOCIAL = "social" - NEWS = "news" - MEDICAL = "medical" - PRODUCTIVITY = "productivity" - EDUCATION = "education" - BUSINESS = "business" - ENTERTAINMENT = "entertainment" - UTILITIES = "utilities" - OTHER = "other" +from core.tools.entities.tool_entities import ToolProviderEntity class DatasourceProviderType(enum.StrEnum): @@ -43,7 +21,9 @@ class DatasourceProviderType(enum.StrEnum): Enum class for datasource provider """ - RAG_PIPELINE = "rag_pipeline" + ONLINE_DOCUMENT = "online_document" + LOCAL_FILE = "local_file" + WEBSITE = "website" @classmethod def value_of(cls, value: str) -> "DatasourceProviderType": @@ -59,153 +39,6 @@ class DatasourceProviderType(enum.StrEnum): raise ValueError(f"invalid mode value {value}") -class ApiProviderSchemaType(Enum): - """ - Enum class for api provider schema type. - """ - - OPENAPI = "openapi" - SWAGGER = "swagger" - OPENAI_PLUGIN = "openai_plugin" - OPENAI_ACTIONS = "openai_actions" - - @classmethod - def value_of(cls, value: str) -> "ApiProviderSchemaType": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid mode value {value}") - - -class ApiProviderAuthType(Enum): - """ - Enum class for api provider auth type. - """ - - NONE = "none" - API_KEY = "api_key" - - @classmethod - def value_of(cls, value: str) -> "ApiProviderAuthType": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid mode value {value}") - - -class DatasourceInvokeMessage(BaseModel): - class TextMessage(BaseModel): - text: str - - class JsonMessage(BaseModel): - json_object: dict - - class BlobMessage(BaseModel): - blob: bytes - - class FileMessage(BaseModel): - pass - - class VariableMessage(BaseModel): - variable_name: str = Field(..., description="The name of the variable") - variable_value: Any = Field(..., description="The value of the variable") - stream: bool = Field(default=False, description="Whether the variable is streamed") - - @model_validator(mode="before") - @classmethod - def transform_variable_value(cls, values) -> Any: - """ - Only basic types and lists are allowed. 
- """ - value = values.get("variable_value") - if not isinstance(value, dict | list | str | int | float | bool): - raise ValueError("Only basic types and lists are allowed.") - - # if stream is true, the value must be a string - if values.get("stream"): - if not isinstance(value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - - return values - - @field_validator("variable_name", mode="before") - @classmethod - def transform_variable_name(cls, value: str) -> str: - """ - The variable name must be a string. - """ - if value in {"json", "text", "files"}: - raise ValueError(f"The variable name '{value}' is reserved.") - return value - - class LogMessage(BaseModel): - class LogStatus(Enum): - START = "start" - ERROR = "error" - SUCCESS = "success" - - id: str - label: str = Field(..., description="The label of the log") - parent_id: Optional[str] = Field(default=None, description="Leave empty for root log") - error: Optional[str] = Field(default=None, description="The error message") - status: LogStatus = Field(..., description="The status of the log") - data: Mapping[str, Any] = Field(..., description="Detailed log data") - metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") - - class MessageType(Enum): - TEXT = "text" - IMAGE = "image" - LINK = "link" - BLOB = "blob" - JSON = "json" - IMAGE_LINK = "image_link" - BINARY_LINK = "binary_link" - VARIABLE = "variable" - FILE = "file" - LOG = "log" - - type: MessageType = MessageType.TEXT - """ - plain text, image url or link url - """ - message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage - meta: dict[str, Any] | None = None - - @field_validator("message", mode="before") - @classmethod - def decode_blob_message(cls, v): - if isinstance(v, dict) and "blob" in v: - try: - v["blob"] = base64.b64decode(v["blob"]) - except Exception: - pass - return v - - @field_serializer("message") - def serialize_message(self, v): - if isinstance(v, self.BlobMessage): - return {"blob": base64.b64encode(v.blob).decode("utf-8")} - return v - - -class DatasourceInvokeMessageBinary(BaseModel): - mimetype: str = Field(..., description="The mimetype of the binary") - url: str = Field(..., description="The url of the binary") - file_var: Optional[dict[str, Any]] = None - - class DatasourceParameter(PluginParameter): """ Overrides type @@ -223,8 +56,6 @@ class DatasourceParameter(PluginParameter): SECRET_INPUT = PluginParameterType.SECRET_INPUT.value FILE = PluginParameterType.FILE.value FILES = PluginParameterType.FILES.value - APP_SELECTOR = PluginParameterType.APP_SELECTOR.value - MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value # deprecated, should not use. 
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value @@ -235,21 +66,13 @@ class DatasourceParameter(PluginParameter): def cast_value(self, value: Any): return cast_parameter_value(self, value) - class DatasourceParameterForm(Enum): - SCHEMA = "schema" # should be set while adding tool - FORM = "form" # should be set before invoking tool - LLM = "llm" # will be set by LLM - type: DatasourceParameterType = Field(..., description="The type of the parameter") - human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") - form: DatasourceParameterForm = Field(..., description="The form of the parameter, schema/form/llm") - llm_description: Optional[str] = None + description: I18nObject = Field(..., description="The description of the parameter") @classmethod def get_simple_instance( cls, name: str, - llm_description: str, typ: DatasourceParameterType, required: bool, options: Optional[list[str]] = None, @@ -277,30 +100,16 @@ class DatasourceParameter(PluginParameter): name=name, label=I18nObject(en_US="", zh_Hans=""), placeholder=None, - human_description=I18nObject(en_US="", zh_Hans=""), type=typ, - form=cls.ToolParameterForm.LLM, - llm_description=llm_description, required=required, options=option_objs, + description=I18nObject(en_US="", zh_Hans=""), ) def init_frontend_parameter(self, value: Any): return init_frontend_parameter(self, self.type, value) -class ToolProviderIdentity(BaseModel): - author: str = Field(..., description="The author of the tool") - name: str = Field(..., description="The name of the tool") - description: I18nObject = Field(..., description="The description of the tool") - icon: str = Field(..., description="The icon of the tool") - label: I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field( - default=[], - description="The tags of the tool", - ) - - class DatasourceIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") @@ -327,26 +136,18 @@ class DatasourceEntity(BaseModel): return v or [] -class ToolProviderEntity(BaseModel): - identity: ToolProviderIdentity - plugin_id: Optional[str] = None - credentials_schema: list[ProviderConfig] = Field(default_factory=list) +class DatasourceProviderEntity(ToolProviderEntity): + """ + Datasource provider entity + """ + + provider_type: DatasourceProviderType -class DatasourceProviderEntityWithPlugin(ToolProviderEntity): +class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity): datasources: list[DatasourceEntity] = Field(default_factory=list) -class WorkflowToolParameterConfiguration(BaseModel): - """ - Workflow tool configuration - """ - - name: str = Field(..., description="The name of the parameter") - description: str = Field(..., description="The description of the parameter") - form: DatasourceParameter.DatasourceParameterForm = Field(..., description="The form of the parameter") - - class DatasourceInvokeMeta(BaseModel): """ Datasource invoke meta @@ -394,24 +195,3 @@ class DatasourceInvokeFrom(Enum): """ RAG_PIPELINE = "rag_pipeline" - - -class DatasourceSelector(BaseModel): - dify_model_identity: str = DATASOURCE_SELECTOR_MODEL_IDENTITY - - class Parameter(BaseModel): - name: str = Field(..., description="The name of the parameter") - type: DatasourceParameter.DatasourceParameterType = Field(..., description="The type of the parameter") - required: bool = Field(..., description="Whether the parameter is 
required") - description: str = Field(..., description="The description of the parameter") - default: Optional[Union[int, float, str]] = None - options: Optional[list[PluginParameterOption]] = None - - provider_id: str = Field(..., description="The id of the provider") - datasource_name: str = Field(..., description="The name of the datasource") - datasource_description: str = Field(..., description="The description of the datasource") - datasource_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") - datasource_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm") - - def to_plugin_parameter(self) -> dict[str, Any]: - return self.model_dump() diff --git a/api/core/datasource/entities/values.py b/api/core/datasource/entities/values.py deleted file mode 100644 index f460df7e25..0000000000 --- a/api/core/datasource/entities/values.py +++ /dev/null @@ -1,111 +0,0 @@ -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum - -ICONS = { - ToolLabelEnum.SEARCH: """ - -""", # noqa: E501 - ToolLabelEnum.IMAGE: """ - -""", # noqa: E501 - ToolLabelEnum.VIDEOS: """ - -""", # noqa: E501 - ToolLabelEnum.WEATHER: """ - -""", # noqa: E501 - ToolLabelEnum.FINANCE: """ - -""", # noqa: E501 - ToolLabelEnum.DESIGN: """ - -""", # noqa: E501 - ToolLabelEnum.TRAVEL: """ - -""", # noqa: E501 - ToolLabelEnum.SOCIAL: """ - -""", # noqa: E501 - ToolLabelEnum.NEWS: """ - -""", # noqa: E501 - ToolLabelEnum.MEDICAL: """ - -""", # noqa: E501 - ToolLabelEnum.PRODUCTIVITY: """ - -""", # noqa: E501 - ToolLabelEnum.EDUCATION: """ - -""", # noqa: E501 - ToolLabelEnum.BUSINESS: """ - -""", # noqa: E501 - ToolLabelEnum.ENTERTAINMENT: """ - -""", # noqa: E501 - ToolLabelEnum.UTILITIES: """ - -""", # noqa: E501 - ToolLabelEnum.OTHER: """ - -""", # noqa: E501 -} - -default_tool_label_dict = { - ToolLabelEnum.SEARCH: ToolLabel( - name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH] - ), - ToolLabelEnum.IMAGE: ToolLabel( - name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE] - ), - ToolLabelEnum.VIDEOS: ToolLabel( - name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS] - ), - ToolLabelEnum.WEATHER: ToolLabel( - name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER] - ), - ToolLabelEnum.FINANCE: ToolLabel( - name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE] - ), - ToolLabelEnum.DESIGN: ToolLabel( - name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN] - ), - ToolLabelEnum.TRAVEL: ToolLabel( - name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL] - ), - ToolLabelEnum.SOCIAL: ToolLabel( - name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL] - ), - ToolLabelEnum.NEWS: ToolLabel( - name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS] - ), - ToolLabelEnum.MEDICAL: ToolLabel( - name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL] - ), - ToolLabelEnum.PRODUCTIVITY: ToolLabel( - name="productivity", - label=I18nObject(en_US="Productivity", zh_Hans="生产力"), - icon=ICONS[ToolLabelEnum.PRODUCTIVITY], - ), - ToolLabelEnum.EDUCATION: ToolLabel( - name="education", 
label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION] - ), - ToolLabelEnum.BUSINESS: ToolLabel( - name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS] - ), - ToolLabelEnum.ENTERTAINMENT: ToolLabel( - name="entertainment", - label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"), - icon=ICONS[ToolLabelEnum.ENTERTAINMENT], - ), - ToolLabelEnum.UTILITIES: ToolLabel( - name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES] - ), - ToolLabelEnum.OTHER: ToolLabel( - name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER] - ), -} - -default_tool_labels = [v for k, v in default_tool_label_dict.items()] -default_tool_label_name_list = [label.name for label in default_tool_labels] diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index c69fa2fe32..922e65d725 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,16 +1,16 @@ -from collections.abc import Generator -from typing import Any, Optional - -from pydantic import BaseModel +from collections.abc import Mapping +from typing import Any from core.plugin.entities.plugin import GenericProviderID, ToolProviderID -from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity +from core.plugin.entities.plugin_daemon import ( + PluginBasicBooleanResponse, + PluginDatasourceProviderEntity, +) from core.plugin.impl.base import BasePluginClient -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter class PluginDatasourceManager(BasePluginClient): - def fetch_datasource_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]: + def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]: """ Fetch datasource providers for the given tenant. """ @@ -27,7 +27,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response( "GET", f"plugin/{tenant_id}/management/datasources", - list[PluginToolProviderEntity], + list[PluginDatasourceProviderEntity], params={"page": 1, "page_size": 256}, transformer=transformer, ) @@ -36,12 +36,12 @@ class PluginDatasourceManager(BasePluginClient): provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" # override the provider name for each tool to plugin_id/provider_name - for tool in provider.declaration.tools: - tool.identity.provider = provider.declaration.identity.name + for datasource in provider.declaration.datasources: + datasource.identity.provider = provider.declaration.identity.name return response - def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: + def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: """ Fetch datasource provider for the given tenant and plugin. 
""" @@ -58,7 +58,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response( "GET", f"plugin/{tenant_id}/management/datasources", - PluginToolProviderEntity, + PluginDatasourceProviderEntity, params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id}, transformer=transformer, ) @@ -66,8 +66,8 @@ class PluginDatasourceManager(BasePluginClient): response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" # override the provider name for each tool to plugin_id/provider_name - for tool in response.declaration.tools: - tool.identity.provider = response.declaration.identity.name + for datasource in response.declaration.datasources: + datasource.identity.provider = response.declaration.identity.name return response @@ -79,7 +79,7 @@ class PluginDatasourceManager(BasePluginClient): datasource_name: str, credentials: dict[str, Any], datasource_parameters: dict[str, Any], - ) -> Generator[ToolInvokeMessage, None, None]: + ) -> Mapping[str, Any]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ @@ -88,8 +88,8 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/{online_document}/pages", - ToolInvokeMessage, + f"plugin/{tenant_id}/dispatch/datasource/first_step", + dict, data={ "user_id": user_id, "data": { @@ -104,7 +104,10 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - return response + for resp in response: + return resp + + raise Exception("No response from plugin daemon") def invoke_second_step( self, @@ -114,7 +117,7 @@ class PluginDatasourceManager(BasePluginClient): datasource_name: str, credentials: dict[str, Any], datasource_parameters: dict[str, Any], - ) -> Generator[ToolInvokeMessage, None, None]: + ) -> Mapping[str, Any]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
""" @@ -123,8 +126,8 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/invoke_second_step", - ToolInvokeMessage, + f"plugin/{tenant_id}/dispatch/datasource/second_step", + dict, data={ "user_id": user_id, "data": { @@ -139,7 +142,10 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - return response + for resp in response: + return resp + + raise Exception("No response from plugin daemon") def validate_provider_credentials( self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] @@ -151,7 +157,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/tool/validate_credentials", + f"plugin/{tenant_id}/dispatch/datasource/validate_credentials", PluginBasicBooleanResponse, data={ "user_id": user_id, @@ -170,48 +176,3 @@ class PluginDatasourceManager(BasePluginClient): return resp.result return False - - def get_runtime_parameters( - self, - tenant_id: str, - user_id: str, - provider: str, - credentials: dict[str, Any], - datasource: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> list[ToolParameter]: - """ - get the runtime parameters of the datasource - """ - datasource_provider_id = GenericProviderID(provider) - - class RuntimeParametersResponse(BaseModel): - parameters: list[ToolParameter] - - response = self._request_with_plugin_daemon_response_stream( - "POST", - f"plugin/{tenant_id}/dispatch/datasource/get_runtime_parameters", - RuntimeParametersResponse, - data={ - "user_id": user_id, - "conversation_id": conversation_id, - "app_id": app_id, - "message_id": message_id, - "data": { - "provider": datasource_provider_id.provider_name, - "datasource": datasource, - "credentials": credentials, - }, - }, - headers={ - "X-Plugin-ID": datasource_provider_id.plugin_id, - "Content-Type": "application/json", - }, - ) - - for resp in response: - return resp.parameters - - return [] diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 54f5418bb4..bb9c00005c 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -3,10 +3,9 @@ from typing import Any, Optional from pydantic import BaseModel -from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID, ToolProviderID +from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, - PluginDatasourceProviderEntity, PluginToolProviderEntity, ) from core.plugin.impl.base import BasePluginClient @@ -45,67 +44,6 @@ class PluginToolManager(BasePluginClient): return response - def fetch_datasources(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]: - """ - Fetch datasources for the given tenant. 
- """ - - def transformer(json_response: dict[str, Any]) -> dict: - for provider in json_response.get("data", []): - declaration = provider.get("declaration", {}) or {} - provider_name = declaration.get("identity", {}).get("name") - for tool in declaration.get("tools", []): - tool["identity"]["provider"] = provider_name - - return json_response - - response = self._request_with_plugin_daemon_response( - "GET", - f"plugin/{tenant_id}/management/datasources", - list[PluginToolProviderEntity], - params={"page": 1, "page_size": 256}, - transformer=transformer, - ) - - for provider in response: - provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" - - # override the provider name for each tool to plugin_id/provider_name - for tool in provider.declaration.tools: - tool.identity.provider = provider.declaration.identity.name - - return response - - def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: - """ - Fetch datasource provider for the given tenant and plugin. - """ - datasource_provider_id = DatasourceProviderID(provider) - - def transformer(json_response: dict[str, Any]) -> dict: - data = json_response.get("data") - if data: - for tool in data.get("declaration", {}).get("tools", []): - tool["identity"]["provider"] = datasource_provider_id.provider_name - - return json_response - - response = self._request_with_plugin_daemon_response( - "GET", - f"plugin/{tenant_id}/management/datasource", - PluginDatasourceProviderEntity, - params={"provider": datasource_provider_id.provider_name, "plugin_id": datasource_provider_id.plugin_id}, - transformer=transformer, - ) - - response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" - - # override the provider name for each tool to plugin_id/provider_name - for tool in response.declaration.tools: - tool.identity.provider = response.declaration.identity.name - - return response - def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: """ Fetch tool provider for the given tenant and plugin. 
diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 49bf7c308a..6876285b31 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,6 +1,5 @@ from typing import Any -from core.datasource.entities.datasource_entities import DatasourceSelector from core.file.models import File from core.tools.entities.tool_entities import ToolSelector @@ -19,10 +18,4 @@ def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, parameters[parameter_name] = [] for p in parameter: parameters[parameter_name].append(p.to_plugin_parameter()) - elif isinstance(parameter, DatasourceSelector): - parameters[parameter_name] = parameter.to_plugin_parameter() - elif isinstance(parameter, list) and all(isinstance(p, DatasourceSelector) for p in parameter): - parameters[parameter_name] = [] - for p in parameter: - parameters[parameter_name].append(p.to_plugin_parameter()) return parameters diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 682a32d26f..aa2661fe63 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Union, cast from yarl import URL import contexts -from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController @@ -496,31 +495,6 @@ class ToolManager: # get plugin providers yield from cls.list_plugin_providers(tenant_id) - @classmethod - def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]: - """ - list all the datasource providers - """ - manager = PluginToolManager() - provider_entities = manager.fetch_datasources(tenant_id) - return [ - DatasourcePluginProviderController( - entity=provider.declaration, - plugin_id=provider.plugin_id, - plugin_unique_identifier=provider.plugin_unique_identifier, - tenant_id=tenant_id, - ) - for provider in provider_entities - ] - - @classmethod - def list_builtin_datasources(cls, tenant_id: str) -> Generator[DatasourcePluginProviderController, None, None]: - """ - list all the builtin datasources - """ - # get builtin datasources - yield from cls.list_datasource_providers(tenant_id) - @classmethod def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: """ diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 8ecf66c0d6..e7d4da8426 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,35 +1,24 @@ from collections.abc import Generator, Mapping, Sequence from typing import Any, cast -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.datasource.datasource_engine import DatasourceEngine -from core.datasource.entities.datasource_entities import DatasourceInvokeMessage, DatasourceParameter -from core.datasource.errors import DatasourceInvokeError -from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer -from core.file import File, FileTransferMethod -from core.plugin.manager.exc import PluginDaemonClientSideError -from core.plugin.manager.plugin import PluginInstallationManager +from 
core.datasource.entities.datasource_entities import ( + DatasourceParameter, +) +from core.file import File +from core.plugin.impl.exc import PluginDaemonClientSideError from core.variables.segments import ArrayAnySegment from core.variables.variables import ArrayAnyVariable from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.event import RunCompletedEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser -from extensions.ext_database import db -from factories import file_factory -from models import ToolFile from models.workflow import WorkflowNodeExecutionStatus -from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .entities import DatasourceNodeData -from .exc import DatasourceNodeError, DatasourceParameterError, ToolFileError +from .exc import DatasourceNodeError, DatasourceParameterError class DatasourceNode(BaseNode[DatasourceNodeData]): @@ -49,7 +38,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): # fetch datasource icon datasource_info = { - "provider_type": node_data.provider_type.value, "provider_id": node_data.provider_id, "plugin_unique_identifier": node_data.plugin_unique_identifier, } @@ -58,8 +46,10 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): try: from core.datasource.datasource_manager import DatasourceManager - datasource_runtime = DatasourceManager.get_workflow_datasource_runtime( - self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=node_data.provider_id, + datasource_name=node_data.datasource_name, + tenant_id=self.tenant_id, ) except DatasourceNodeError as e: yield RunCompletedEvent( @@ -74,7 +64,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): return # get parameters - datasource_parameters = datasource_runtime.get_merged_runtime_parameters() or [] + datasource_parameters = datasource_runtime.entity.parameters parameters = self._generate_parameters( datasource_parameters=datasource_parameters, variable_pool=self.graph_runtime_state.variable_pool, @@ -91,15 +81,20 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) try: - message_stream = DatasourceEngine.generic_invoke( - datasource=datasource_runtime, - datasource_parameters=parameters, + # TODO: handle result + result = datasource_runtime._invoke_second_step( user_id=self.user_id, - workflow_tool_callback=DifyWorkflowCallbackHandler(), - workflow_call_depth=self.workflow_call_depth, - thread_pool_id=self.thread_pool_id, - app_id=self.app_id, - conversation_id=conversation_id.text if conversation_id else None, + datasource_parameters=parameters, + ) + except PluginDaemonClientSideError as e: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to transform datasource message: {str(e)}", + error_type=type(e).__name__, + ) ) except DatasourceNodeError as e: yield 
RunCompletedEvent( @@ -113,20 +108,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): ) return - try: - # convert datasource messages - yield from self._transform_message(message_stream, datasource_info, parameters_for_log) - except (PluginDaemonClientSideError, DatasourceInvokeError) as e: - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - error=f"Failed to transform datasource message: {str(e)}", - error_type=type(e).__name__, - ) - ) - def _generate_parameters( self, *, @@ -175,200 +156,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] - def _transform_message( - self, - messages: Generator[DatasourceInvokeMessage, None, None], - datasource_info: Mapping[str, Any], - parameters_for_log: dict[str, Any], - ) -> Generator: - """ - Convert ToolInvokeMessages into tuple[plain_text, files] - """ - # transform message and handle file storage - message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages( - messages=messages, - user_id=self.user_id, - tenant_id=self.tenant_id, - conversation_id=None, - ) - - text = "" - files: list[File] = [] - json: list[dict] = [] - - agent_logs: list[AgentLogEvent] = [] - agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {} - - variables: dict[str, Any] = {} - - for message in message_stream: - if message.type in { - DatasourceInvokeMessage.MessageType.IMAGE_LINK, - DatasourceInvokeMessage.MessageType.BINARY_LINK, - DatasourceInvokeMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) - - url = message.message.text - if message.meta: - transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - else: - transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] - - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileError(f"Tool file {tool_file_id} does not exist") - - mapping = { - "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - files.append(file) - elif message.type == DatasourceInvokeMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) - assert message.meta - - tool_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileError(f"tool file {tool_file_id} not exists") - - mapping = { - "tool_file_id": tool_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - ) - elif message.type == DatasourceInvokeMessage.MessageType.TEXT: - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) - text += message.message.text - yield RunStreamChunkEvent( - chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] - ) - elif message.type 
== DatasourceInvokeMessage.MessageType.JSON: - assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage) - if self.node_type == NodeType.AGENT: - msg_metadata = message.message.json_object.pop("execution_metadata", {}) - agent_execution_metadata = { - key: value - for key, value in msg_metadata.items() - if key in NodeRunMetadataKey.__members__.values() - } - json.append(message.message.json_object) - elif message.type == DatasourceInvokeMessage.MessageType.LINK: - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) - stream_text = f"Link: {message.message.text}\n" - text += stream_text - yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) - elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value - - yield RunStreamChunkEvent( - chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] - ) - else: - variables[variable_name] = variable_value - elif message.type == DatasourceInvokeMessage.MessageType.FILE: - assert message.meta is not None - files.append(message.meta["file"]) - elif message.type == DatasourceInvokeMessage.MessageType.LOG: - assert isinstance(message.message, DatasourceInvokeMessage.LogMessage) - if message.message.metadata: - icon = datasource_info.get("icon", "") - dict_metadata = dict(message.message.metadata) - if dict_metadata.get("provider"): - manager = PluginInstallationManager() - plugins = manager.list_plugins(self.tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] - ) - icon = current_plugin.declaration.icon - except StopIteration: - pass - try: - builtin_tool = next( - provider - for provider in BuiltinToolManageService.list_builtin_tools( - self.user_id, - self.tenant_id, - ) - if provider.name == dict_metadata["provider"] - ) - icon = builtin_tool.icon - except StopIteration: - pass - - dict_metadata["icon"] = icon - message.message.metadata = dict_metadata - agent_log = AgentLogEvent( - id=message.message.id, - node_execution_id=self.id, - parent_id=message.message.parent_id, - error=message.message.error, - status=message.message.status.value, - data=message.message.data, - label=message.message.label, - metadata=message.message.metadata, - node_id=self.node_id, - ) - - # check if the agent log is already in the list - for log in agent_logs: - if log.id == agent_log.id: - # update the log - log.data = agent_log.data - log.status = agent_log.status - log.error = agent_log.error - log.label = agent_log.label - log.metadata = agent_log.metadata - break - else: - agent_logs.append(agent_log) - - yield agent_log - - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": files, "json": json, **variables}, - metadata={ - **agent_execution_metadata, - NodeRunMetadataKey.DATASOURCE_INFO: datasource_info, - NodeRunMetadataKey.AGENT_LOG: agent_logs, - }, - inputs=parameters_for_log, - ) - ) - @classmethod def 
_extract_variable_selector_to_variable_mapping( cls, diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 66e8adc431..68aa9fa34c 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -3,17 +3,15 @@ from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from core.tools.entities.tool_entities import ToolProviderType from core.workflow.nodes.base.entities import BaseNodeData class DatasourceEntity(BaseModel): provider_id: str - provider_type: ToolProviderType provider_name: str # redundancy - tool_name: str + datasource_name: str tool_label: str # redundancy - tool_configurations: dict[str, Any] + datasource_configurations: dict[str, Any] plugin_unique_identifier: str | None = None # redundancy @field_validator("tool_configurations", mode="before") diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index f039b233a5..1fa6c20bf9 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,7 +1,8 @@ import datetime import logging import time -from typing import Any, cast, Mapping +from collections.abc import Mapping +from typing import Any, cast from flask_login import current_user diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 0f5069f052..2e3cb604de 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -245,16 +245,20 @@ class DatasetService: rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, ): # check if dataset name already exists - if db.session.query(Dataset).filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first(): + if ( + db.session.query(Dataset) + .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) + .first() + ): raise DatasetNameDuplicateError( f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." ) - + pipeline = Pipeline( tenant_id=tenant_id, name=rag_pipeline_dataset_create_entity.name, description=rag_pipeline_dataset_create_entity.description, - created_by=current_user.id + created_by=current_user.id, ) db.session.add(pipeline) db.session.flush() @@ -268,7 +272,7 @@ class DatasetService: runtime_mode="rag_pipeline", icon_info=rag_pipeline_dataset_create_entity.icon_info, created_by=current_user.id, - pipeline_id=pipeline.id + pipeline_id=pipeline.id, ) db.session.add(dataset) db.session.commit() @@ -280,7 +284,11 @@ class DatasetService: rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, ): # check if dataset name already exists - if db.session.query(Dataset).filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first(): + if ( + db.session.query(Dataset) + .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) + .first() + ): raise DatasetNameDuplicateError( f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." 
) @@ -299,7 +307,7 @@ class DatasetService: account=current_user, import_mode=ImportMode.YAML_CONTENT.value, yaml_content=rag_pipeline_dataset_create_entity.yaml_content, - dataset=dataset + dataset=dataset, ) return { "id": rag_pipeline_import_info.id, @@ -1254,281 +1262,282 @@ class DocumentService: return documents, batch - @staticmethod - def save_document_with_dataset_id( - dataset: Dataset, - knowledge_config: KnowledgeConfig, - account: Account | Any, - dataset_process_rule: Optional[DatasetProcessRule] = None, - created_from: str = "web", - ): - # check document limit - features = FeatureService.get_features(current_user.current_tenant_id) + # @staticmethod + # def save_document_with_dataset_id( + # dataset: Dataset, + # knowledge_config: KnowledgeConfig, + # account: Account | Any, + # dataset_process_rule: Optional[DatasetProcessRule] = None, + # created_from: str = "web", + # ): + # # check document limit + # features = FeatureService.get_features(current_user.current_tenant_id) - if features.billing.enabled: - if not knowledge_config.original_document_id: - count = 0 - if knowledge_config.data_source: - if knowledge_config.data_source.info_list.data_source_type == "upload_file": - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore - count = len(upload_file_list) - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": - notion_info_list = knowledge_config.data_source.info_list.notion_info_list - for notion_info in notion_info_list: # type: ignore - count = count + len(notion_info.pages) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": - website_info = knowledge_config.data_source.info_list.website_info_list - count = len(website_info.urls) # type: ignore - batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + # if features.billing.enabled: + # if not knowledge_config.original_document_id: + # count = 0 + # if knowledge_config.data_source: + # if knowledge_config.data_source.info_list.data_source_type == "upload_file": + # upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids + # # type: ignore + # count = len(upload_file_list) + # elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + # notion_info_list = knowledge_config.data_source.info_list.notion_info_list + # for notion_info in notion_info_list: # type: ignore + # count = count + len(notion_info.pages) + # elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + # website_info = knowledge_config.data_source.info_list.website_info_list + # count = len(website_info.urls) # type: ignore + # batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - if features.billing.subscription.plan == "sandbox" and count > 1: - raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") - if count > batch_upload_limit: - raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + # if features.billing.subscription.plan == "sandbox" and count > 1: + # raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") + # if count > batch_upload_limit: + # raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - DocumentService.check_documents_upload_quota(count, features) + # DocumentService.check_documents_upload_quota(count, features) - # if dataset is empty, update dataset data_source_type - if not dataset.data_source_type: - 
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore + # # if dataset is empty, update dataset data_source_type + # if not dataset.data_source_type: + # dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore - if not dataset.indexing_technique: - if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: - raise ValueError("Indexing technique is invalid") + # if not dataset.indexing_technique: + # if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + # raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = knowledge_config.indexing_technique - if knowledge_config.indexing_technique == "high_quality": - model_manager = ModelManager() - if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: - dataset_embedding_model = knowledge_config.embedding_model - dataset_embedding_model_provider = knowledge_config.embedding_model_provider - else: - embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - dataset_embedding_model = embedding_model.model - dataset_embedding_model_provider = embedding_model.provider - dataset.embedding_model = dataset_embedding_model - dataset.embedding_model_provider = dataset_embedding_model_provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - dataset_embedding_model_provider, dataset_embedding_model - ) - dataset.collection_binding_id = dataset_collection_binding.id - if not dataset.retrieval_model: - default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, - } + # dataset.indexing_technique = knowledge_config.indexing_technique + # if knowledge_config.indexing_technique == "high_quality": + # model_manager = ModelManager() + # if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: + # dataset_embedding_model = knowledge_config.embedding_model + # dataset_embedding_model_provider = knowledge_config.embedding_model_provider + # else: + # embedding_model = model_manager.get_default_model_instance( + # tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + # ) + # dataset_embedding_model = embedding_model.model + # dataset_embedding_model_provider = embedding_model.provider + # dataset.embedding_model = dataset_embedding_model + # dataset.embedding_model_provider = dataset_embedding_model_provider + # dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + # dataset_embedding_model_provider, dataset_embedding_model + # ) + # dataset.collection_binding_id = dataset_collection_binding.id + # if not dataset.retrieval_model: + # default_retrieval_model = { + # "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + # "reranking_enable": False, + # "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + # "top_k": 2, + # "score_threshold_enabled": False, + # } - dataset.retrieval_model = ( - knowledge_config.retrieval_model.model_dump() - if knowledge_config.retrieval_model - else default_retrieval_model - ) # type: ignore + # dataset.retrieval_model = ( + # knowledge_config.retrieval_model.model_dump() + # if knowledge_config.retrieval_model + # 
else default_retrieval_model + # ) # type: ignore - documents = [] - if knowledge_config.original_document_id: - document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) - documents.append(document) - batch = document.batch - else: - batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) - # save process rule - if not dataset_process_rule: - process_rule = knowledge_config.process_rule - if process_rule: - if process_rule.mode in ("custom", "hierarchical"): - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule.mode, - rules=process_rule.rules.model_dump_json() if process_rule.rules else None, - created_by=account.id, - ) - elif process_rule.mode == "automatic": - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule.mode, - rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id, - ) - else: - logging.warn( - f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" - ) - return - db.session.add(dataset_process_rule) - db.session.commit() - lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) - with redis_client.lock(lock_name, timeout=600): - position = DocumentService.get_documents_position(dataset.id) - document_ids = [] - duplicate_document_ids = [] - if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore - for file_id in upload_file_list: - file = ( - db.session.query(UploadFile) - .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) - .first() - ) + # documents = [] + # if knowledge_config.original_document_id: + # document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) + # documents.append(document) + # batch = document.batch + # else: + # batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + # # save process rule + # if not dataset_process_rule: + # process_rule = knowledge_config.process_rule + # if process_rule: + # if process_rule.mode in ("custom", "hierarchical"): + # dataset_process_rule = DatasetProcessRule( + # dataset_id=dataset.id, + # mode=process_rule.mode, + # rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + # created_by=account.id, + # ) + # elif process_rule.mode == "automatic": + # dataset_process_rule = DatasetProcessRule( + # dataset_id=dataset.id, + # mode=process_rule.mode, + # rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + # created_by=account.id, + # ) + # else: + # logging.warn( + # f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" + # ) + # return + # db.session.add(dataset_process_rule) + # db.session.commit() + # lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) + # with redis_client.lock(lock_name, timeout=600): + # position = DocumentService.get_documents_position(dataset.id) + # document_ids = [] + # duplicate_document_ids = [] + # if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore + # upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + # for file_id in upload_file_list: + # file = ( + # db.session.query(UploadFile) + # .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + # .first() + # ) - # raise error if file not found - if not file: - raise 
FileNotExistsError() + # # raise error if file not found + # if not file: + # raise FileNotExistsError() - file_name = file.name - data_source_info = { - "upload_file_id": file_id, - } - # check duplicate - if knowledge_config.duplicate: - document = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="upload_file", - enabled=True, - name=file_name, - ).first() - if document: - document.dataset_process_rule_id = dataset_process_rule.id # type: ignore - document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - document.created_from = created_from - document.doc_form = knowledge_config.doc_form - document.doc_language = knowledge_config.doc_language - document.data_source_info = json.dumps(data_source_info) - document.batch = batch - document.indexing_status = "waiting" - db.session.add(document) - documents.append(document) - duplicate_document_ids.append(document.id) - continue - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - file_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore - notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore - if not notion_info_list: - raise ValueError("No notion info list found.") - exist_page_ids = [] - exist_document = {} - documents = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="notion_import", - enabled=True, - ).all() - if documents: - for document in documents: - data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info["notion_page_id"]) - exist_document[data_source_info["notion_page_id"]] = document.id - for notion_info in notion_info_list: - workspace_id = notion_info.workspace_id - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', - ) - ).first() - if not data_source_binding: - raise ValueError("Data source binding not found.") - for page in notion_info.pages: - if page.page_id not in exist_page_ids: - data_source_info = { - "notion_workspace_id": workspace_id, - "notion_page_id": page.page_id, - "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, - "type": page.type, - } - # Truncate page name to 255 characters to prevent DB field length errors - truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - truncated_page_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - else: - exist_document.pop(page.page_id) 
- # delete not selected documents - if len(exist_document) > 0: - clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore - website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore - if not website_info: - raise ValueError("No website info list found.") - urls = website_info.urls - for url in urls: - data_source_info = { - "url": url, - "provider": website_info.provider, - "job_id": website_info.job_id, - "only_main_content": website_info.only_main_content, - "mode": "crawl", - } - if len(url) > 255: - document_name = url[:200] + "..." - else: - document_name = url - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - document_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - db.session.commit() + # file_name = file.name + # data_source_info = { + # "upload_file_id": file_id, + # } + # # check duplicate + # if knowledge_config.duplicate: + # document = Document.query.filter_by( + # dataset_id=dataset.id, + # tenant_id=current_user.current_tenant_id, + # data_source_type="upload_file", + # enabled=True, + # name=file_name, + # ).first() + # if document: + # document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + # document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + # document.created_from = created_from + # document.doc_form = knowledge_config.doc_form + # document.doc_language = knowledge_config.doc_language + # document.data_source_info = json.dumps(data_source_info) + # document.batch = batch + # document.indexing_status = "waiting" + # db.session.add(document) + # documents.append(document) + # duplicate_document_ids.append(document.id) + # continue + # document = DocumentService.build_document( + # dataset, + # dataset_process_rule.id, # type: ignore + # knowledge_config.data_source.info_list.data_source_type, # type: ignore + # knowledge_config.doc_form, + # knowledge_config.doc_language, + # data_source_info, + # created_from, + # position, + # account, + # file_name, + # batch, + # ) + # db.session.add(document) + # db.session.flush() + # document_ids.append(document.id) + # documents.append(document) + # position += 1 + # elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore + # notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + # if not notion_info_list: + # raise ValueError("No notion info list found.") + # exist_page_ids = [] + # exist_document = {} + # documents = Document.query.filter_by( + # dataset_id=dataset.id, + # tenant_id=current_user.current_tenant_id, + # data_source_type="notion_import", + # enabled=True, + # ).all() + # if documents: + # for document in documents: + # data_source_info = json.loads(document.data_source_info) + # exist_page_ids.append(data_source_info["notion_page_id"]) + # exist_document[data_source_info["notion_page_id"]] = document.id + # for notion_info in notion_info_list: + # workspace_id = notion_info.workspace_id + # data_source_binding = DataSourceOauthBinding.query.filter( + # db.and_( + # DataSourceOauthBinding.tenant_id == 
current_user.current_tenant_id, + # DataSourceOauthBinding.provider == "notion", + # DataSourceOauthBinding.disabled == False, + # DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + # ) + # ).first() + # if not data_source_binding: + # raise ValueError("Data source binding not found.") + # for page in notion_info.pages: + # if page.page_id not in exist_page_ids: + # data_source_info = { + # "notion_workspace_id": workspace_id, + # "notion_page_id": page.page_id, + # "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + # "type": page.type, + # } + # # Truncate page name to 255 characters to prevent DB field length errors + # truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" + # document = DocumentService.build_document( + # dataset, + # dataset_process_rule.id, # type: ignore + # knowledge_config.data_source.info_list.data_source_type, # type: ignore + # knowledge_config.doc_form, + # knowledge_config.doc_language, + # data_source_info, + # created_from, + # position, + # account, + # truncated_page_name, + # batch, + # ) + # db.session.add(document) + # db.session.flush() + # document_ids.append(document.id) + # documents.append(document) + # position += 1 + # else: + # exist_document.pop(page.page_id) + # # delete not selected documents + # if len(exist_document) > 0: + # clean_notion_document_task.delay(list(exist_document.values()), dataset.id) + # elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore + # website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + # if not website_info: + # raise ValueError("No website info list found.") + # urls = website_info.urls + # for url in urls: + # data_source_info = { + # "url": url, + # "provider": website_info.provider, + # "job_id": website_info.job_id, + # "only_main_content": website_info.only_main_content, + # "mode": "crawl", + # } + # if len(url) > 255: + # document_name = url[:200] + "..." 
+ # else: + # document_name = url + # document = DocumentService.build_document( + # dataset, + # dataset_process_rule.id, # type: ignore + # knowledge_config.data_source.info_list.data_source_type, # type: ignore + # knowledge_config.doc_form, + # knowledge_config.doc_language, + # data_source_info, + # created_from, + # position, + # account, + # document_name, + # batch, + # ) + # db.session.add(document) + # db.session.flush() + # document_ids.append(document.id) + # documents.append(document) + # position += 1 + # db.session.commit() - # trigger async task - if document_ids: - document_indexing_task.delay(dataset.id, document_ids) - if duplicate_document_ids: - duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + # # trigger async task + # if document_ids: + # document_indexing_task.delay(dataset.id, document_ids) + # if duplicate_document_ids: + # duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) - return documents, batch + # return documents, batch @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 3664c988e5..1c6dac55be 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -309,8 +309,10 @@ class RagPipelineDslService: dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) .filter( - DatasetCollectionBinding.provider_name == knowledge_configuration.index_method.embedding_setting.embedding_provider_name, - DatasetCollectionBinding.model_name == knowledge_configuration.index_method.embedding_setting.embedding_model_name, + DatasetCollectionBinding.provider_name + == knowledge_configuration.index_method.embedding_setting.embedding_provider_name, + DatasetCollectionBinding.model_name + == knowledge_configuration.index_method.embedding_setting.embedding_model_name, DatasetCollectionBinding.type == "dataset", ) .order_by(DatasetCollectionBinding.created_at) diff --git a/api/services/rag_pipeline/rag_pipeline_manage_service.py b/api/services/rag_pipeline/rag_pipeline_manage_service.py new file mode 100644 index 0000000000..4d8d69f913 --- /dev/null +++ b/api/services/rag_pipeline/rag_pipeline_manage_service.py @@ -0,0 +1,14 @@ +from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity +from core.plugin.impl.datasource import PluginDatasourceManager + + +class RagPipelineManageService: + @staticmethod + def list_rag_pipeline_datasources(tenant_id: str) -> list[PluginDatasourceProviderEntity]: + """ + list rag pipeline datasources + """ + + # get all builtin providers + manager = PluginDatasourceManager() + return manager.fetch_datasource_providers(tenant_id) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index daf3773309..3ccd14415d 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -5,7 +5,6 @@ from pathlib import Path from sqlalchemy.orm import Session from configs import dify_config -from core.datasource.entities.api_entities import DatasourceProviderApiEntity from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import GenericProviderID, ToolProviderID @@ -17,7 +16,7 @@ from core.tools.tool_label_manager import ToolLabelManager from 
core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ProviderConfigEncrypter from extensions.ext_database import db -from models.tools import BuiltinDatasourceProvider, BuiltinToolProvider +from models.tools import BuiltinToolProvider from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) @@ -287,67 +286,6 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) - @staticmethod - def list_rag_pipeline_datasources(tenant_id: str) -> list[DatasourceProviderApiEntity]: - """ - list rag pipeline datasources - """ - # get all builtin providers - datasource_provider_controllers = ToolManager.list_datasource_providers(tenant_id) - - with db.session.no_autoflush: - # get all user added providers - db_providers: list[BuiltinDatasourceProvider] = ( - db.session.query(BuiltinDatasourceProvider) - .filter(BuiltinDatasourceProvider.tenant_id == tenant_id) - .all() - or [] - ) - - # find provider - def find_provider(provider): - return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) - - result: list[DatasourceProviderApiEntity] = [] - - for provider_controller in datasource_provider_controllers: - try: - # handle include, exclude - if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore - data=provider_controller, - name_func=lambda x: x.identity.name, - ): - continue - - # convert provider controller to user provider - user_builtin_provider = ToolTransformService.builtin_datasource_provider_to_user_provider( - provider_controller=provider_controller, - db_provider=find_provider(provider_controller.entity.identity.name), - decrypt_credentials=True, - ) - - # add icon - ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) - - datasources = provider_controller.get_datasources() - for datasource in datasources or []: - user_builtin_provider.datasources.append( - ToolTransformService.convert_datasource_entity_to_api_entity( - tenant_id=tenant_id, - datasource=datasource, - credentials=user_builtin_provider.original_credentials, - labels=ToolLabelManager.get_tool_labels(provider_controller), - ) - ) - - result.append(user_builtin_provider) - except Exception as e: - raise e - - return result - @staticmethod def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: try: diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index e0c1ce7217..367121125b 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,11 +5,6 @@ from typing import Optional, Union, cast from yarl import URL from configs import dify_config -from core.datasource.__base.datasource_plugin import DatasourcePlugin -from core.datasource.__base.datasource_provider import DatasourcePluginProviderController -from core.datasource.__base.datasource_runtime import DatasourceRuntime -from core.datasource.entities.api_entities import DatasourceApiEntity, DatasourceProviderApiEntity -from core.datasource.entities.datasource_entities import DatasourceProviderType from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -26,7 +21,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from 
core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool -from models.tools import ApiToolProvider, BuiltinDatasourceProvider, BuiltinToolProvider, WorkflowToolProvider +from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider logger = logging.getLogger(__name__) @@ -145,64 +140,6 @@ class ToolTransformService: return result - @classmethod - def builtin_datasource_provider_to_user_provider( - cls, - provider_controller: DatasourcePluginProviderController, - db_provider: Optional[BuiltinDatasourceProvider], - decrypt_credentials: bool = True, - ) -> DatasourceProviderApiEntity: - """ - convert provider controller to user provider - """ - result = DatasourceProviderApiEntity( - id=provider_controller.entity.identity.name, - author=provider_controller.entity.identity.author, - name=provider_controller.entity.identity.name, - description=provider_controller.entity.identity.description, - icon=provider_controller.entity.identity.icon, - label=provider_controller.entity.identity.label, - type=DatasourceProviderType.RAG_PIPELINE, - masked_credentials={}, - is_team_authorization=False, - plugin_id=provider_controller.plugin_id, - plugin_unique_identifier=provider_controller.plugin_unique_identifier, - datasources=[], - ) - - # get credentials schema - schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} - - for name, value in schema.items(): - if result.masked_credentials: - result.masked_credentials[name] = "" - - # check if the provider need credentials - if not provider_controller.need_credentials: - result.is_team_authorization = True - result.allow_delete = False - elif db_provider: - result.is_team_authorization = True - - if decrypt_credentials: - credentials = db_provider.credentials - - # init tool configuration - tool_configuration = ProviderConfigEncrypter( - tenant_id=db_provider.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt(data=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) - - result.masked_credentials = masked_credentials - result.original_credentials = decrypted_credentials - - return result - @staticmethod def api_provider_to_controller( db_provider: ApiToolProvider, @@ -367,48 +304,3 @@ class ToolTransformService: parameters=tool.parameters, labels=labels or [], ) - - @staticmethod - def convert_datasource_entity_to_api_entity( - datasource: DatasourcePlugin, - tenant_id: str, - credentials: dict | None = None, - labels: list[str] | None = None, - ) -> DatasourceApiEntity: - """ - convert tool to user tool - """ - # fork tool runtime - datasource = datasource.fork_datasource_runtime( - runtime=DatasourceRuntime( - credentials=credentials or {}, - tenant_id=tenant_id, - ) - ) - - # get datasource parameters - parameters = datasource.entity.parameters or [] - # get datasource runtime parameters - runtime_parameters = datasource.get_runtime_parameters() - # override parameters - current_parameters = parameters.copy() - for runtime_parameter in runtime_parameters: - found = False - for index, parameter in 
enumerate(current_parameters): - if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: - current_parameters[index] = runtime_parameter - found = True - break - - if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: - current_parameters.append(runtime_parameter) - - return DatasourceApiEntity( - author=datasource.entity.identity.author, - name=datasource.entity.identity.name, - label=datasource.entity.identity.label, - description=datasource.entity.description.human if datasource.entity.description else I18nObject(en_US=""), - output_schema=datasource.entity.output_schema, - parameters=current_parameters, - labels=labels or [], - ) From ba52bf27c19a36a5bca1c94f9ceb8d34e98edee5 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 20 May 2025 14:57:26 +0800 Subject: [PATCH 023/155] r2 --- .../rag_pipeline/rag_pipeline_workflow.py | 27 +++++++----- .../agent_tool_callback_handler.py | 30 ------------- .../entities/datasource_entities.py | 2 +- .../nodes/knowledge_index/entities.py | 2 +- api/factories/variable_factory.py | 43 +++++++++++++++++-- api/fields/workflow_fields.py | 17 ++++++-- api/models/workflow.py | 10 ++--- api/services/rag_pipeline/rag_pipeline.py | 16 ++++--- 8 files changed, 85 insertions(+), 62 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index c67b897f81..26cd3dd90b 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -93,7 +93,7 @@ class DraftRagPipelineApi(Resource): parser.add_argument("hash", type=str, required=False, location="json") parser.add_argument("environment_variables", type=list, required=False, location="json") parser.add_argument("conversation_variables", type=list, required=False, location="json") - parser.add_argument("rag_pipeline_variables", type=dict, required=False, location="json") + parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json") args = parser.parse_args() elif "text/plain" in content_type: try: @@ -101,8 +101,8 @@ class DraftRagPipelineApi(Resource): if "graph" not in data or "features" not in data: raise ValueError("graph or features not found in data") - if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict): - raise ValueError("graph or features is not a dict") + if not isinstance(data.get("graph"), dict): + raise ValueError("graph is not a dict") args = { "graph": data.get("graph"), @@ -129,11 +129,9 @@ class DraftRagPipelineApi(Resource): conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] - rag_pipeline_variables_list = args.get("rag_pipeline_variables") or {} - rag_pipeline_variables = { - k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v] - for k, v in rag_pipeline_variables_list.items() - } + rag_pipeline_variables_list = args.get("rag_pipeline_variables") or [] + rag_pipeline_variables = [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list] + rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, @@ -634,12 +632,15 @@ class RagPipelineSecondStepApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not 
current_user.is_editor: raise Forbidden() - datasource_provider = request.args.get("datasource_provider", required=True, type=str) + node_id = request.args.get("node_id", required=True, type=str) rag_pipeline_service = RagPipelineService() - return rag_pipeline_service.get_second_step_parameters( - pipeline=pipeline, datasource_provider=datasource_provider + variables = rag_pipeline_service.get_second_step_parameters( + pipeline=pipeline, node_id=node_id ) + return { + "variables": variables, + } class RagPipelineWorkflowRunListApi(Resource): @@ -785,3 +786,7 @@ api.add_resource( DatasourceListApi, "/rag/pipelines/datasource-plugins", ) +api.add_resource( + RagPipelineSecondStepApi, + "/rag/pipelines//workflows/processing/paramters", +) diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 21bc39d440..1063e66c59 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -4,7 +4,6 @@ from typing import Any, Optional, TextIO, Union from pydantic import BaseModel from configs import dify_config -from core.datasource.entities.datasource_entities import DatasourceInvokeMessage from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.entities.tool_entities import ToolInvokeMessage @@ -114,35 +113,6 @@ class DifyAgentCallbackHandler(BaseModel): color=self.color, ) - def on_datasource_end( - self, - datasource_name: str, - datasource_inputs: Mapping[str, Any], - datasource_outputs: Iterable[DatasourceInvokeMessage] | str, - message_id: Optional[str] = None, - timer: Optional[Any] = None, - trace_manager: Optional[TraceQueueManager] = None, - ) -> None: - """Run on datasource end.""" - if dify_config.DEBUG: - print_text("\n[on_datasource_end]\n", color=self.color) - print_text("Datasource: " + datasource_name + "\n", color=self.color) - print_text("Inputs: " + str(datasource_inputs) + "\n", color=self.color) - print_text("Outputs: " + str(datasource_outputs)[:1000] + "\n", color=self.color) - print_text("\n") - - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.DATASOURCE_TRACE, - message_id=message_id, - datasource_name=datasource_name, - datasource_inputs=datasource_inputs, - datasource_outputs=datasource_outputs, - timer=timer, - ) - ) - @property def ignore_agent(self) -> bool: """Whether to ignore agent callbacks.""" diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index e1bcbc323b..25d7c1c352 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -145,7 +145,7 @@ class DatasourceProviderEntity(ToolProviderEntity): class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity): - datasources: list[DatasourceEntity] = Field(default_factory=list) + datasources: list[DatasourceEntity] = Field(default_factory=list) class DatasourceInvokeMeta(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 05661a6cc8..6b2c91a8a0 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -127,7 +127,7 @@ class GeneralStructureChunk(BaseModel): General Structure Chunk. 
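As a quick illustration of the node_id-based lookup this patch switches RagPipelineSecondStepApi to, here is a minimal, self-contained sketch of the filtering rule that RagPipelineService.get_second_step_parameters applies further down in this patch; the variable dicts and node ids are made-up examples shaped like pipeline_variable_fields.

    # Standalone sketch: keep variables bound to the requested datasource node
    # plus the ones marked as shared across nodes (hypothetical sample data).
    def filter_second_step_variables(rag_pipeline_variables: list[dict], node_id: str) -> list[dict]:
        return [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") in (node_id, "shared")
        ]

    variables = [
        {"variable": "url", "label": "URL", "belong_to_node_id": "datasource-node-1"},
        {"variable": "chunk_size", "label": "Chunk size", "belong_to_node_id": "shared"},
        {"variable": "token", "label": "Token", "belong_to_node_id": "datasource-node-2"},
    ]
    # returns the "url" and "chunk_size" entries for datasource-node-1
    print(filter_second_step_variables(variables, "datasource-node-1"))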
""" - general_chunk: list[str] + general_chunks: list[str] data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo] diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index e1db5db43d..002833d786 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -80,9 +80,9 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: - if not mapping.get("name"): - raise VariableError("missing name") - return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["name"]]) + if not mapping.get("variable"): + raise VariableError("missing variable") + return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["variable"]]) def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: @@ -123,6 +123,43 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = result.model_copy(update={"selector": selector}) return cast(Variable, result) +def _build_rag_pipeline_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: + """ + This factory function is used to create the rag pipeline variable, + not support the File type. + """ + if (type := mapping.get("type")) is None: + raise VariableError("missing type") + if (value := mapping.get("value")) is None: + raise VariableError("missing value") + # FIXME: using Any here, fix it later + result: Any + match type: + case SegmentType.STRING: + result = StringVariable.model_validate(mapping) + case SegmentType.SECRET: + result = SecretVariable.model_validate(mapping) + case SegmentType.NUMBER if isinstance(value, int): + result = IntegerVariable.model_validate(mapping) + case SegmentType.NUMBER if isinstance(value, float): + result = FloatVariable.model_validate(mapping) + case SegmentType.NUMBER if not isinstance(value, float | int): + raise VariableError(f"invalid number value {value}") + case SegmentType.OBJECT if isinstance(value, dict): + result = ObjectVariable.model_validate(mapping) + case SegmentType.ARRAY_STRING if isinstance(value, list): + result = ArrayStringVariable.model_validate(mapping) + case SegmentType.ARRAY_NUMBER if isinstance(value, list): + result = ArrayNumberVariable.model_validate(mapping) + case SegmentType.ARRAY_OBJECT if isinstance(value, list): + result = ArrayObjectVariable.model_validate(mapping) + case _: + raise VariableError(f"not supported type {type}") + if result.size > dify_config.MAX_VARIABLE_SIZE: + raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") + if not result.selector: + result = result.model_copy(update={"selector": selector}) + return cast(Variable, result) def build_segment(value: Any, /) -> Segment: if value is None: diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index a37ae7856d..0733192c4f 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -42,9 +42,19 @@ conversation_variable_fields = { pipeline_variable_fields = { "id": fields.String, - "name": fields.String, - "value_type": fields.String(attribute="value_type.value"), - "value": fields.Raw, + "label": fields.String, + "variable": fields.String, + "type": fields.String(attribute="type.value"), + "belong_to_node_id": fields.String, + "max_length": fields.Integer, + "required": fields.Boolean, + 
"default_value": fields.Raw, + "options": fields.List(fields.String), + "placeholder": fields.String, + "tooltips": fields.String, + "allowed_file_types": fields.List(fields.String), + "allow_file_extension": fields.List(fields.String), + "allow_file_upload_methods": fields.List(fields.String), } workflow_fields = { @@ -62,6 +72,7 @@ workflow_fields = { "tool_published": fields.Boolean, "environment_variables": fields.List(EnvironmentVariableField()), "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)), + "rag_pipeline_variables": fields.List(fields.Nested(pipeline_variable_fields)), } workflow_partial_fields = { diff --git a/api/models/workflow.py b/api/models/workflow.py index 5cb413b6a6..4ab59b26a6 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -352,21 +352,19 @@ class Workflow(Base): ) @property - def rag_pipeline_variables(self) -> dict[str, Sequence[Variable]]: + def rag_pipeline_variables(self) -> Sequence[Variable]: # TODO: find some way to init `self._conversation_variables` when instance created. if self._rag_pipeline_variables is None: self._rag_pipeline_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables) - results = {} - for k, v in variables_dict.items(): - results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()] + results = [variable_factory.build_pipeline_variable_from_mapping(v) for v in variables_dict.values()] return results @rag_pipeline_variables.setter - def rag_pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None: + def rag_pipeline_variables(self, values: Sequence[Variable]) -> None: self._rag_pipeline_variables = json.dumps( - {k: {item.name: item.model_dump() for item in v} for k, v in values.items()}, + {item.name: item.model_dump() for item in values}, ensure_ascii=False, ) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f380bc32d7..d2fc4d8100 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -201,7 +201,7 @@ class RagPipelineService: account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], - rag_pipeline_variables: dict[str, Sequence[Variable]], + rag_pipeline_variables: Sequence[Variable], ) -> Workflow: """ Sync draft workflow @@ -552,7 +552,7 @@ class RagPipelineService: return workflow - def get_second_step_parameters(self, pipeline: Pipeline, datasource_provider: str) -> dict: + def get_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: """ Get second step parameters of rag pipeline """ @@ -562,13 +562,15 @@ class RagPipelineService: raise ValueError("Workflow not initialized") # get second step node - pipeline_variables = workflow.pipeline_variables - if not pipeline_variables: + rag_pipeline_variables = workflow.rag_pipeline_variables + if not rag_pipeline_variables: return {} + # get datasource provider - datasource_provider_variables = pipeline_variables.get(datasource_provider, []) - shared_variables = pipeline_variables.get("shared", []) - return datasource_provider_variables + shared_variables + datasource_provider_variables = [item for item in rag_pipeline_variables + if item.get("belong_to_node_id") == node_id + or item.get("belong_to_node_id") == "shared"] + return datasource_provider_variables def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination: """ From 
a64df507f60519068805401d0ff18e8cb83386be Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 20 May 2025 15:18:33 +0800 Subject: [PATCH 024/155] r2 --- .../rag_pipeline/rag_pipeline_workflow.py | 5 +-- api/factories/variable_factory.py | 40 +------------------ api/models/workflow.py | 9 +++-- api/services/rag_pipeline/rag_pipeline.py | 2 +- .../rag_pipeline/rag_pipeline_dsl_service.py | 4 +- 5 files changed, 9 insertions(+), 51 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 26cd3dd90b..fa4130b762 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -129,9 +129,6 @@ class DraftRagPipelineApi(Resource): conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] - rag_pipeline_variables_list = args.get("rag_pipeline_variables") or [] - rag_pipeline_variables = [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list] - rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, @@ -140,7 +137,7 @@ class DraftRagPipelineApi(Resource): account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, - rag_pipeline_variables=rag_pipeline_variables, + rag_pipeline_variables=args.get("rag_pipeline_variables") or [], ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 002833d786..69a786e2f5 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -82,7 +82,7 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: if not mapping.get("variable"): raise VariableError("missing variable") - return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["variable"]]) + return mapping["variable"] def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: @@ -123,44 +123,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = result.model_copy(update={"selector": selector}) return cast(Variable, result) -def _build_rag_pipeline_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: - """ - This factory function is used to create the rag pipeline variable, - not support the File type. 
- """ - if (type := mapping.get("type")) is None: - raise VariableError("missing type") - if (value := mapping.get("value")) is None: - raise VariableError("missing value") - # FIXME: using Any here, fix it later - result: Any - match type: - case SegmentType.STRING: - result = StringVariable.model_validate(mapping) - case SegmentType.SECRET: - result = SecretVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, int): - result = IntegerVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, float): - result = FloatVariable.model_validate(mapping) - case SegmentType.NUMBER if not isinstance(value, float | int): - raise VariableError(f"invalid number value {value}") - case SegmentType.OBJECT if isinstance(value, dict): - result = ObjectVariable.model_validate(mapping) - case SegmentType.ARRAY_STRING if isinstance(value, list): - result = ArrayStringVariable.model_validate(mapping) - case SegmentType.ARRAY_NUMBER if isinstance(value, list): - result = ArrayNumberVariable.model_validate(mapping) - case SegmentType.ARRAY_OBJECT if isinstance(value, list): - result = ArrayObjectVariable.model_validate(mapping) - case _: - raise VariableError(f"not supported type {type}") - if result.size > dify_config.MAX_VARIABLE_SIZE: - raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") - if not result.selector: - result = result.model_copy(update={"selector": selector}) - return cast(Variable, result) - def build_segment(value: Any, /) -> Segment: if value is None: return NoneSegment() diff --git a/api/models/workflow.py b/api/models/workflow.py index 4ab59b26a6..f04cafe3ed 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,7 +2,7 @@ import json from collections.abc import Mapping, Sequence from datetime import UTC, datetime from enum import Enum, StrEnum -from typing import TYPE_CHECKING, Any, Optional, Self, Union +from typing import TYPE_CHECKING, Any, List, Optional, Self, Union from uuid import uuid4 if TYPE_CHECKING: @@ -331,6 +331,7 @@ class Workflow(Base): "features": self.features_dict, "environment_variables": [var.model_dump(mode="json") for var in environment_variables], "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], + "rag_pipeline_variables": [var.model_dump(mode="json") for var in self.rag_pipeline_variables], } return result @@ -358,13 +359,13 @@ class Workflow(Base): self._rag_pipeline_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables) - results = [variable_factory.build_pipeline_variable_from_mapping(v) for v in variables_dict.values()] + results = [v for v in variables_dict.values()] return results @rag_pipeline_variables.setter - def rag_pipeline_variables(self, values: Sequence[Variable]) -> None: + def rag_pipeline_variables(self, values: List[dict]) -> None: self._rag_pipeline_variables = json.dumps( - {item.name: item.model_dump() for item in values}, + {item["variable"]: item for item in values}, ensure_ascii=False, ) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index d2fc4d8100..63b5c9983c 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -201,7 +201,7 @@ class RagPipelineService: account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], - rag_pipeline_variables: Sequence[Variable], + rag_pipeline_variables: list, ) -> 
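A dependency-free sketch of the round-trip that the reworked Workflow.rag_pipeline_variables property and setter above now perform, with plain dicts standing in for the stored column value:

    import json

    # Setter side: key each variable mapping by its "variable" name.
    def dump_rag_pipeline_variables(values: list[dict]) -> str:
        return json.dumps({item["variable"]: item for item in values}, ensure_ascii=False)

    # Property side: return the stored mappings as a plain list.
    def load_rag_pipeline_variables(raw: str | None) -> list[dict]:
        return list(json.loads(raw or "{}").values())

    raw = dump_rag_pipeline_variables(
        [{"variable": "chunk_size", "type": "number", "belong_to_node_id": "shared"}]
    )
    assert load_rag_pipeline_variables(raw)[0]["variable"] == "chunk_size"

Because the JSON object is keyed by the "variable" field, two entries that share a name collapse into one, with the last write winning.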
Workflow: """ Sync draft workflow diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 1c6dac55be..19c7d37f6e 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -578,9 +578,6 @@ class RagPipelineDslService: variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) - rag_pipeline_variables = [ - variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list - ] rag_pipeline_service = RagPipelineService() current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) @@ -610,6 +607,7 @@ class RagPipelineDslService: account=account, environment_variables=environment_variables, conversation_variables=conversation_variables, + rag_pipeline_variables=rag_pipeline_variables_list, ) return pipeline From 9bafd3a2261387ab492835345c1af41f642d8ab2 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 20 May 2025 15:41:10 +0800 Subject: [PATCH 025/155] r2 --- api/services/rag_pipeline/rag_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 63b5c9983c..a7ad3109c3 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -29,7 +29,7 @@ from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore -from models.enums import CreatedByRole, WorkflowRunTriggeredFrom +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import ( Workflow, WorkflowNodeExecution, @@ -491,7 +491,7 @@ class RagPipelineService: workflow_node_execution.node_type = node_instance.node_type workflow_node_execution.title = node_instance.node_data.title workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value + workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) if run_succeeded and node_run_result: From 5fa2aca2c8d7f9edd00c32ad0240f9aae8ba818b Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Wed, 21 May 2025 20:29:59 +0800 Subject: [PATCH 026/155] feat: add oauth schema to datasource --- api/core/tools/entities/tool_entities.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 37375f4a71..9884d93e9d 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -7,6 +7,7 @@ from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator from core.entities.provider_entities import ProviderConfig +from core.plugin.entities.oauth import OAuthSchema from core.plugin.entities.parameters import ( PluginParameter, PluginParameterOption, @@ -349,6 +350,7 @@ class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity plugin_id: Optional[str] = 
None credentials_schema: list[ProviderConfig] = Field(default_factory=list) + oauth_schema: Optional[OAuthSchema] = Field(default=None, description="The oauth schema of the tool provider") class ToolProviderEntityWithPlugin(ToolProviderEntity): From 3bfc602561d8dc4cafbcc6e9a3799be0496be282 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Wed, 21 May 2025 20:36:26 +0800 Subject: [PATCH 027/155] refactor: update datasource entity structure and parameter handling - Renamed and split parameters in DatasourceEntity into first_step_parameters and second_step_parameters. - Updated validation methods for new parameter structure. - Adjusted datasource_node to reference first_step_parameters. - Cleaned up unused imports and improved type hints in workflow.py. --- .../entities/datasource_entities.py | 19 +++++++++++++------ .../nodes/datasource/datasource_node.py | 2 +- api/factories/variable_factory.py | 2 +- api/models/workflow.py | 6 +++--- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 25d7c1c352..04e6915f31 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -22,8 +22,8 @@ class DatasourceProviderType(enum.StrEnum): """ ONLINE_DOCUMENT = "online_document" - LOCAL_FILE = "local_file" WEBSITE = "website" + ONLINE_DRIVE = "online_drive" @classmethod def value_of(cls, value: str) -> "DatasourceProviderType": @@ -125,14 +125,21 @@ class DatasourceDescription(BaseModel): class DatasourceEntity(BaseModel): identity: DatasourceIdentity - parameters: list[DatasourceParameter] = Field(default_factory=list) description: Optional[DatasourceDescription] = None - output_schema: Optional[dict] = None + first_step_parameters: list[DatasourceParameter] = Field(default_factory=list) + second_step_parameters: list[DatasourceParameter] = Field(default_factory=list) + first_step_output_schema: Optional[dict] = None + second_step_output_schema: Optional[dict] = None has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") - @field_validator("parameters", mode="before") + @field_validator("first_step_parameters", mode="before") @classmethod - def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]: + def set_first_step_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]: + return v or [] + + @field_validator("second_step_parameters", mode="before") + @classmethod + def set_second_step_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]: return v or [] @@ -145,7 +152,7 @@ class DatasourceProviderEntity(ToolProviderEntity): class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity): - datasources: list[DatasourceEntity] = Field(default_factory=list) + datasources: list[DatasourceEntity] = Field(default_factory=list) class DatasourceInvokeMeta(BaseModel): diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index e7d4da8426..4e64c024c8 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -64,7 +64,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): return # get parameters - datasource_parameters = datasource_runtime.entity.parameters + datasource_parameters = datasource_runtime.entity.first_step_parameters parameters = 
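The split described in this commit message can be pictured with a reduced sketch; the classes below are simplified stand-ins for the real entities (only the new field names come from the diff, while the parameter fields and the step semantics implied by the comments are assumptions drawn from the surrounding pipeline endpoints):

    from pydantic import BaseModel, Field

    class DatasourceParameterSketch(BaseModel):
        name: str
        type: str = "string"
        required: bool = False

    class DatasourceEntitySketch(BaseModel):
        # parameters surfaced in the pipeline's first (datasource) step
        first_step_parameters: list[DatasourceParameterSketch] = Field(default_factory=list)
        # parameters surfaced in the second (processing) step
        second_step_parameters: list[DatasourceParameterSketch] = Field(default_factory=list)

    entity = DatasourceEntitySketch(
        first_step_parameters=[DatasourceParameterSketch(name="workspace_id", required=True)],
        second_step_parameters=[DatasourceParameterSketch(name="page_limit", type="number")],
    )
    # a datasource node resolves only the first-step set, mirroring datasource_node.py
    print([p.name for p in entity.first_step_parameters])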
self._generate_parameters( datasource_parameters=datasource_parameters, variable_pool=self.graph_runtime_state.variable_pool, diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 69a786e2f5..d829d57812 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -39,7 +39,6 @@ from core.variables.variables import ( from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, - PIPELINE_VARIABLE_NODE_ID, ) @@ -123,6 +122,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = result.model_copy(update={"selector": selector}) return cast(Variable, result) + def build_segment(value: Any, /) -> Segment: if value is None: return NoneSegment() diff --git a/api/models/workflow.py b/api/models/workflow.py index d5cf71841e..5cdb769209 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -3,7 +3,7 @@ import logging from collections.abc import Mapping, Sequence from datetime import UTC, datetime from enum import Enum, StrEnum -from typing import TYPE_CHECKING, Any, List, Optional, Self, Union +from typing import TYPE_CHECKING, Any, Optional, Self, Union from uuid import uuid4 from core.variables import utils as variable_utils @@ -366,11 +366,11 @@ class Workflow(Base): self._rag_pipeline_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables) - results = [v for v in variables_dict.values()] + results = list(variables_dict.values()) return results @rag_pipeline_variables.setter - def rag_pipeline_variables(self, values: List[dict]) -> None: + def rag_pipeline_variables(self, values: list[dict]) -> None: self._rag_pipeline_variables = json.dumps( {item["variable"]: item for item in values}, ensure_ascii=False, From b82b26bba51709d49cbf9eab94a986d5be680340 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 23 May 2025 00:05:57 +0800 Subject: [PATCH 028/155] r2 --- .../rag_pipeline/rag_pipeline_workflow.py | 136 ++++- api/core/app/app_config/entities.py | 2 +- api/core/app/apps/pipeline/__init__.py | 0 .../pipeline/generate_response_converter.py | 95 ++++ .../apps/pipeline/pipeline_config_manager.py | 63 +++ .../app/apps/pipeline/pipeline_generator.py | 496 ++++++++++++++++++ .../apps/pipeline/pipeline_queue_manager.py | 44 ++ api/core/app/apps/pipeline/pipeline_runner.py | 154 ++++++ api/core/app/entities/app_invoke_entities.py | 35 ++ .../datasource/__base/datasource_plugin.py | 65 +-- .../datasource/__base/datasource_provider.py | 45 +- api/core/datasource/entities/api_entities.py | 6 +- .../entities/datasource_entities.py | 114 +++- .../local_file/local_file_plugin.py | 37 ++ .../local_file/local_file_provider.py | 58 ++ .../online_document/online_document_plugin.py | 80 +++ .../online_document_provider.py | 50 ++ .../website_crawl/website_crawl_plugin.py | 63 +++ .../website_crawl/website_crawl_provider.py | 50 ++ api/core/plugin/entities/plugin_daemon.py | 1 + api/core/plugin/impl/datasource.py | 153 +++++- api/core/workflow/enums.py | 4 + .../nodes/datasource/datasource_node.py | 47 +- .../nodes/knowledge_index/entities.py | 5 - .../knowledge_index/knowledge_index_node.py | 103 +--- .../nodes/knowledge_index/template_prompts.py | 66 --- .../nodes/knowledge_retrieval/entities.py | 2 - api/factories/variable_factory.py | 2 +- ...6_1659-abb18a379e62_add_pipeline_info_2.py | 113 ++++ api/models/dataset.py | 2 + api/models/model.py | 1 + api/models/workflow.py | 6 +- api/services/dataset_service.py 
| 2 +- .../rag_pipeline/pipeline_generate_service.py | 109 ++++ .../database/database_retrieval.py | 30 +- api/services/rag_pipeline/rag_pipeline.py | 75 ++- 36 files changed, 1983 insertions(+), 331 deletions(-) create mode 100644 api/core/app/apps/pipeline/__init__.py create mode 100644 api/core/app/apps/pipeline/generate_response_converter.py create mode 100644 api/core/app/apps/pipeline/pipeline_config_manager.py create mode 100644 api/core/app/apps/pipeline/pipeline_generator.py create mode 100644 api/core/app/apps/pipeline/pipeline_queue_manager.py create mode 100644 api/core/app/apps/pipeline/pipeline_runner.py create mode 100644 api/core/datasource/local_file/local_file_plugin.py create mode 100644 api/core/datasource/local_file/local_file_provider.py create mode 100644 api/core/datasource/online_document/online_document_plugin.py create mode 100644 api/core/datasource/online_document/online_document_provider.py create mode 100644 api/core/datasource/website_crawl/website_crawl_plugin.py create mode 100644 api/core/datasource/website_crawl/website_crawl_provider.py delete mode 100644 api/core/workflow/nodes/knowledge_index/template_prompts.py create mode 100644 api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py create mode 100644 api/services/rag_pipeline/pipeline_generate_service.py diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index fa4130b762..c0406940a7 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -39,9 +39,9 @@ from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.account import Account from models.dataset import Pipeline -from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError +from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError @@ -170,7 +170,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource): args = parser.parse_args() try: - response = AppGenerateService.generate_single_iteration( + response = PipelineGenerateService.generate_single_iteration( pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True ) @@ -207,7 +207,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource): args = parser.parse_args() try: - response = AppGenerateService.generate_single_loop( + response = PipelineGenerateService.generate_single_loop( pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True ) @@ -241,11 +241,12 @@ class DraftRagPipelineRunApi(Resource): parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + parser.add_argument("datasource_info", type=list, required=True, location="json") args = parser.parse_args() try: - response = AppGenerateService.generate( + response = PipelineGenerateService.generate( pipeline=pipeline, 
user=current_user, args=args, @@ -258,7 +259,73 @@ class DraftRagPipelineRunApi(Resource): raise InvokeRateLimitHttpError(ex.description) +class PublishedRagPipelineRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline): + """ + Run published workflow + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + parser.add_argument("datasource_info", type=list, required=True, location="json") + args = parser.parse_args() + + try: + response = PipelineGenerateService.generate( + pipeline=pipeline, + user=current_user, + args=args, + invoke_from=InvokeFrom.PUBLISHED, + streaming=True, + ) + + return helper.compact_generate_response(response) + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) + + class RagPipelineDatasourceNodeRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run rag pipeline datasource + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + + rag_pipeline_service = RagPipelineService() + result = rag_pipeline_service.run_datasource_workflow_node( + pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user + ) + + return result + + +class RagPipelinePublishedNodeRunApi(Resource): @setup_required @login_required @account_initialization_required @@ -283,7 +350,7 @@ class RagPipelineDatasourceNodeRunApi(Resource): raise ValueError("missing inputs") rag_pipeline_service = RagPipelineService() - workflow_node_execution = rag_pipeline_service.run_datasource_workflow_node( + workflow_node_execution = rag_pipeline_service.run_published_workflow_node( pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user ) @@ -354,7 +421,8 @@ class PublishedRagPipelineApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + if not pipeline.is_published: + return None # fetch published workflow by pipeline rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline) @@ -397,10 +465,8 @@ class PublishedRagPipelineApi(Resource): marked_name=args.marked_name or "", marked_comment=args.marked_comment or "", ) - + pipeline.is_published = True pipeline.workflow_id = workflow.id - db.session.commit() - workflow_created_at = TimestampField().format(workflow.created_at) session.commit() @@ -617,7 +683,7 @@ class RagPipelineByIdApi(Resource): return None, 204 -class RagPipelineSecondStepApi(Resource): +class PublishedRagPipelineSecondStepApi(Resource): @setup_required @login_required @account_initialization_required @@ -632,9 +698,28 @@ class RagPipelineSecondStepApi(Resource): node_id = 
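For orientation, a hypothetical request body for the draft/published run endpoints added above; the field names follow the reqparse arguments (inputs, datasource_type, datasource_info), while the concrete values and the upload_file shape are illustrative assumptions only:

    import json

    payload = {
        "inputs": {"chunk_size": 512},
        "datasource_type": "upload_file",
        "datasource_info": [{"upload_file_id": "00000000-0000-0000-0000-000000000000"}],
    }
    print(json.dumps(payload, indent=2))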
request.args.get("node_id", required=True, type=str) rag_pipeline_service = RagPipelineService() - variables = rag_pipeline_service.get_second_step_parameters( - pipeline=pipeline, node_id=node_id - ) + variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id) + return { + "variables": variables, + } + + +class DraftRagPipelineSecondStepApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline): + """ + Get second step parameters of rag pipeline + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + node_id = request.args.get("node_id", required=True, type=str) + + rag_pipeline_service = RagPipelineService() + variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id) return { "variables": variables, } @@ -732,15 +817,21 @@ api.add_resource( RagPipelineDraftNodeRunApi, "/rag/pipelines//workflows/draft/nodes//run", ) -# api.add_resource( -# RagPipelinePublishedNodeRunApi, -# "/rag/pipelines//workflows/published/nodes//run", -# ) +api.add_resource( + RagPipelineDatasourceNodeRunApi, + "/rag/pipelines//workflows/datasource/nodes//run", +) api.add_resource( RagPipelineDraftRunIterationNodeApi, "/rag/pipelines//workflows/draft/iteration/nodes//run", ) + +api.add_resource( + RagPipelinePublishedNodeRunApi, + "/rag/pipelines//workflows/published/nodes//run", +) + api.add_resource( RagPipelineDraftRunLoopNodeApi, "/rag/pipelines//workflows/draft/loop/nodes//run", @@ -762,7 +853,6 @@ api.add_resource( DefaultRagPipelineBlockConfigApi, "/rag/pipelines//workflows/default-workflow-block-configs/", ) - api.add_resource( RagPipelineByIdApi, "/rag/pipelines//workflows/", @@ -784,6 +874,10 @@ api.add_resource( "/rag/pipelines/datasource-plugins", ) api.add_resource( - RagPipelineSecondStepApi, - "/rag/pipelines//workflows/processing/paramters", + PublishedRagPipelineSecondStepApi, + "/rag/pipelines//workflows/published/processing/paramters", +) +api.add_resource( + DraftRagPipelineSecondStepApi, + "/rag/pipelines//workflows/draft/processing/paramters", ) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 8ae52131f2..48e8ca5594 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -283,7 +283,7 @@ class AppConfig(BaseModel): tenant_id: str app_id: str app_mode: AppMode - additional_features: AppAdditionalFeatures + additional_features: Optional[AppAdditionalFeatures] = None variables: list[VariableEntity] = [] sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None diff --git a/api/core/app/apps/pipeline/__init__.py b/api/core/app/apps/pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/app/apps/pipeline/generate_response_converter.py b/api/core/app/apps/pipeline/generate_response_converter.py new file mode 100644 index 0000000000..10ec73a7d2 --- /dev/null +++ b/api/core/app/apps/pipeline/generate_response_converter.py @@ -0,0 +1,95 @@ +from collections.abc import Generator +from typing import cast + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.task_entities import ( + AppStreamResponse, + ErrorStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, + WorkflowAppBlockingResponse, + 
WorkflowAppStreamResponse, +) + + +class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): + _blocking_response_type = WorkflowAppBlockingResponse + + @classmethod + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] + """ + Convert blocking full response. + :param blocking_response: blocking response + :return: + """ + return dict(blocking_response.to_dict()) + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] + """ + Convert blocking simple response. + :param blocking_response: blocking response + :return: + """ + return cls.convert_blocking_full_response(blocking_response) + + @classmethod + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, None, None]: + """ + Convert stream full response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(WorkflowAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield "ping" + continue + + response_chunk = { + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, + } + + if isinstance(sub_stream_response, ErrorStreamResponse): + data = cls._error_to_stream_response(sub_stream_response.err) + response_chunk.update(data) + else: + response_chunk.update(sub_stream_response.to_dict()) + yield response_chunk + + @classmethod + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, None, None]: + """ + Convert stream simple response. 
+ :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(WorkflowAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield "ping" + continue + + response_chunk = { + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, + } + + if isinstance(sub_stream_response, ErrorStreamResponse): + data = cls._error_to_stream_response(sub_stream_response.err) + response_chunk.update(data) + elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): + response_chunk.update(sub_stream_response.to_ignore_detail_dict()) + else: + response_chunk.update(sub_stream_response.to_dict()) + yield response_chunk diff --git a/api/core/app/apps/pipeline/pipeline_config_manager.py b/api/core/app/apps/pipeline/pipeline_config_manager.py new file mode 100644 index 0000000000..ddf87eacbb --- /dev/null +++ b/api/core/app/apps/pipeline/pipeline_config_manager.py @@ -0,0 +1,63 @@ +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager +from models.dataset import Pipeline +from models.model import AppMode +from models.workflow import Workflow + + +class PipelineConfig(WorkflowUIBasedAppConfig): + """ + Pipeline Config Entity. 
+ """ + + pass + + +class PipelineConfigManager(BaseAppConfigManager): + @classmethod + def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig: + pipeline_config = PipelineConfig( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + app_mode=AppMode.RAG_PIPELINE, + workflow_id=workflow.id, + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + ) + + return pipeline_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + """ + Validate for pipeline config + + :param tenant_id: tenant id + :param config: app model config args + :param only_structure_validate: only validate the structure of the config + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate + ) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py new file mode 100644 index 0000000000..1e880c700c --- /dev/null +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -0,0 +1,496 @@ +import contextvars +import datetime +import json +import logging +import random +import threading +import time +import uuid +from collections.abc import Generator, Mapping +from typing import Any, Literal, Optional, Union, overload + +from flask import Flask, current_app +from pydantic import ValidationError +from sqlalchemy.orm import sessionmaker + +import contexts +from configs import dify_config +from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager +from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager +from core.app.apps.pipeline.pipeline_runner import PipelineRunner +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter +from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity, WorkflowAppGenerateEntity +from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.model_runtime.errors.invoke import InvokeAuthorizationError +from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline +from extensions.ext_database import db +from models import Account, App, 
EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom +from models.dataset import Document, Pipeline +from services.dataset_service import DocumentService + +logger = logging.getLogger(__name__) + + +class PipelineGenerator(BaseAppGenerator): + @overload + def generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + call_depth: int, + workflow_thread_pool_id: Optional[str], + ) -> Generator[Mapping | str, None, None]: ... + + @overload + def generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + call_depth: int, + workflow_thread_pool_id: Optional[str], + ) -> Mapping[str, Any]: ... + + @overload + def generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool, + call_depth: int, + workflow_thread_pool_id: Optional[str], + ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... + + def generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: + # convert to app config + pipeline_config = PipelineConfigManager.get_pipeline_config( + pipeline=pipeline, + workflow=workflow, + ) + + inputs: Mapping[str, Any] = args["inputs"] + datasource_type: str = args["datasource_type"] + datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + + for datasource_info in datasource_info_list: + workflow_run_id = str(uuid.uuid4()) + document_id = None + if invoke_from == InvokeFrom.PUBLISHED: + position = DocumentService.get_documents_position(pipeline.dataset_id) + document = self._build_document( + tenant_id=pipeline.tenant_id, + dataset_id=pipeline.dataset_id, + built_in_field_enabled=pipeline.dataset.built_in_field_enabled, + datasource_type=datasource_type, + datasource_info=datasource_info, + created_from="rag-pipeline", + position=position, + account=user, + batch=batch, + document_form=pipeline.dataset.doc_form, + ) + db.session.add(document) + db.session.commit() + document_id = document.id + # init application generate entity + application_generate_entity = RagPipelineGenerateEntity( + task_id=str(uuid.uuid4()), + pipline_config=pipeline_config, + datasource_type=datasource_type, + datasource_info=datasource_info, + dataset_id=pipeline.dataset_id, + batch=batch, + document_id=document_id, + inputs=self._prepare_user_inputs( + user_inputs=inputs, + variables=pipeline_config.variables, + tenant_id=pipeline.tenant_id, + strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + ), + files=[], + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + call_depth=call_depth, + workflow_run_id=workflow_run_id, + ) + + contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + 
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + return self._generate( + pipeline=pipeline, + workflow=workflow, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=invoke_from, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + + def _generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + application_generate_entity: RagPipelineGenerateEntity, + invoke_from: InvokeFrom, + workflow_node_execution_repository: WorkflowNodeExecutionRepository, + streaming: bool = True, + workflow_thread_pool_id: Optional[str] = None, + ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: + """ + Generate App response. + + :param app_model: App + :param workflow: Workflow + :param user: account or end user + :param application_generate_entity: application generate entity + :param invoke_from: invoke from source + :param workflow_node_execution_repository: repository for workflow node execution + :param streaming: is stream + :param workflow_thread_pool_id: workflow thread pool id + """ + # init queue manager + queue_manager = PipelineQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + app_mode=pipeline.mode, + ) + + # new thread + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": contextvars.copy_context(), + "workflow_thread_pool_id": workflow_thread_pool_id, + }, + ) + + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + workflow_node_execution_repository=workflow_node_execution_repository, + stream=streaming, + ) + + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + + def single_iteration_generate( + self, + app_model: App, + workflow: Workflow, + node_id: str, + user: Account | EndUser, + args: Mapping[str, Any], + streaming: bool = True, + ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: + """ + Generate App response. 
+ + :param app_model: App + :param workflow: Workflow + :param node_id: the node id + :param user: account or end user + :param args: request args + :param streaming: is streamed + """ + if not node_id: + raise ValueError("node_id is required") + + if args.get("inputs") is None: + raise ValueError("inputs is required") + + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + inputs={}, + files=[], + user_id=user.id, + stream=streaming, + invoke_from=InvokeFrom.DEBUGGER, + extras={"auto_generate_conversation_name": False}, + single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( + node_id=node_id, inputs=args["inputs"] + ), + workflow_run_id=str(uuid.uuid4()), + ) + contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + ) + + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + invoke_from=InvokeFrom.DEBUGGER, + application_generate_entity=application_generate_entity, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + ) + + def single_loop_generate( + self, + app_model: App, + workflow: Workflow, + node_id: str, + user: Account | EndUser, + args: Mapping[str, Any], + streaming: bool = True, + ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: + """ + Generate App response. 
+ + :param app_model: App + :param workflow: Workflow + :param node_id: the node id + :param user: account or end user + :param args: request args + :param streaming: is streamed + """ + if not node_id: + raise ValueError("node_id is required") + + if args.get("inputs") is None: + raise ValueError("inputs is required") + + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + inputs={}, + files=[], + user_id=user.id, + stream=streaming, + invoke_from=InvokeFrom.DEBUGGER, + extras={"auto_generate_conversation_name": False}, + single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), + workflow_run_id=str(uuid.uuid4()), + ) + contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + ) + + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + invoke_from=InvokeFrom.DEBUGGER, + application_generate_entity=application_generate_entity, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + ) + + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: RagPipelineGenerateEntity, + queue_manager: AppQueueManager, + context: contextvars.Context, + workflow_thread_pool_id: Optional[str] = None, + ) -> None: + """ + Generate worker in a new thread. 
+ :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param workflow_thread_pool_id: workflow thread pool id + :return: + """ + for var, val in context.items(): + var.set(val) + with flask_app.app_context(): + try: + # workflow app + runner = PipelineRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + + runner.run() + except GenerateTaskStoppedError: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except ValueError as e: + if dify_config.DEBUG: + logger.exception("Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.close() + + def _handle_response( + self, + application_generate_entity: RagPipelineGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + workflow_node_execution_repository: WorkflowNodeExecutionRepository, + stream: bool = False, + ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: + """ + Handle response. + :param application_generate_entity: application generate entity + :param workflow: workflow + :param queue_manager: queue manager + :param user: account or end user + :param stream: is stream + :param workflow_node_execution_repository: optional repository for workflow node execution + :return: + """ + # init generate task pipeline + generate_task_pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + stream=stream, + workflow_node_execution_repository=workflow_node_execution_repository, + ) + + try: + return generate_task_pipeline.process() + except ValueError as e: + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error + raise GenerateTaskStoppedError() + else: + logger.exception( + f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}" + ) + raise e + + def _build_document( + self, + tenant_id: str, + dataset_id: str, + built_in_field_enabled: bool, + datasource_type: str, + datasource_info: Mapping[str, Any], + created_from: str, + position: int, + account: Account, + batch: str, + document_form: str, + ): + if datasource_type == "local_file": + name = datasource_info["name"] + elif datasource_type == "online_document": + name = datasource_info["page_title"] + elif datasource_type == "website_crawl": + name = datasource_info["title"] + else: + raise ValueError(f"Unsupported datasource type: {datasource_type}") + + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=position, + data_source_type=datasource_type, + data_source_info=json.dumps(datasource_info), + batch=batch, + name=name, + created_from=created_from, + created_by=account.id, + doc_form=document_form, + ) + doc_metadata = {} + if built_in_field_enabled: + doc_metadata = { + BuiltInField.document_name: name, + BuiltInField.uploader: account.name, + 
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), + BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), + BuiltInField.source: datasource_type, + } + if doc_metadata: + document.doc_metadata = doc_metadata + return document diff --git a/api/core/app/apps/pipeline/pipeline_queue_manager.py b/api/core/app/apps/pipeline/pipeline_queue_manager.py new file mode 100644 index 0000000000..d0aeac8a9c --- /dev/null +++ b/api/core/app/apps/pipeline/pipeline_queue_manager.py @@ -0,0 +1,44 @@ +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueErrorEvent, + QueueMessageEndEvent, + QueueStopEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowSucceededEvent, + WorkflowQueueMessage, +) + + +class PipelineQueueManager(AppQueueManager): + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: + super().__init__(task_id, user_id, invoke_from) + + self._app_mode = app_mode + + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ + message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event) + + self._q.put(message) + + if isinstance( + event, + QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent + | QueueWorkflowPartialSuccessEvent, + ): + self.stop_listen() + + if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py new file mode 100644 index 0000000000..1395a47d88 --- /dev/null +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -0,0 +1,154 @@ +import logging +from typing import Optional, cast + +from configs import dify_config +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import ( + InvokeFrom, + RagPipelineGenerateEntity, +) +from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db +from models.dataset import Pipeline +from models.enums import UserFrom +from models.model import EndUser +from models.workflow import Workflow, WorkflowType + +logger = logging.getLogger(__name__) + + +class PipelineRunner(WorkflowBasedAppRunner): + """ + Pipeline Application Runner + """ + + def __init__( + self, + application_generate_entity: RagPipelineGenerateEntity, + queue_manager: AppQueueManager, + workflow_thread_pool_id: Optional[str] = None, + ) -> None: + """ + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param workflow_thread_pool_id: workflow thread pool id + """ + self.application_generate_entity = application_generate_entity + self.queue_manager = queue_manager + self.workflow_thread_pool_id = 
workflow_thread_pool_id + + def run(self) -> None: + """ + Run application + """ + app_config = self.application_generate_entity.app_config + app_config = cast(PipelineConfig, app_config) + + user_id = None + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = self.application_generate_entity.user_id + + pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first() + if not pipeline: + raise ValueError("Pipeline not found") + + workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") + + db.session.close() + + workflow_callbacks: list[WorkflowCallback] = [] + if dify_config.DEBUG: + workflow_callbacks.append(WorkflowLoggingCallback()) + + # if only single iteration run is requested + if self.application_generate_entity.single_iteration_run: + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs, + ) + elif self.application_generate_entity.single_loop_run: + # if only single loop run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( + workflow=workflow, + node_id=self.application_generate_entity.single_loop_run.node_id, + user_inputs=self.application_generate_entity.single_loop_run.inputs, + ) + else: + inputs = self.application_generate_entity.inputs + files = self.application_generate_entity.files + + # Create a variable pool. 
+ system_inputs = { + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.APP_ID: app_config.app_id, + SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, + SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, + SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id, + SystemVariableKey.BATCH: self.application_generate_entity.batch, + SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id, + } + + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=[], + ) + + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) + + # RUN WORKFLOW + workflow_entry = WorkflowEntry( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), + graph=graph, + graph_config=workflow.graph_dict, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, + variable_pool=variable_pool, + thread_pool_id=self.workflow_thread_pool_id, + ) + + generator = workflow_entry.run(callbacks=workflow_callbacks) + + for event in generator: + self._handle_event(workflow_entry, event) + + def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id + ) + .first() + ) + + # return workflow + return workflow diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 56e6b46a60..d730704f48 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -21,6 +21,7 @@ class InvokeFrom(Enum): WEB_APP = "web-app" EXPLORE = "explore" DEBUGGER = "debugger" + PUBLISHED = "published" @classmethod def value_of(cls, value: str): @@ -226,3 +227,37 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): inputs: dict single_loop_run: Optional[SingleLoopRunEntity] = None + + +class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): + """ + RAG Pipeline Application Generate Entity. + """ + + # app config + pipline_config: WorkflowUIBasedAppConfig + datasource_type: str + datasource_info: Mapping[str, Any] + dataset_id: str + batch: str + document_id: str + + class SingleIterationRunEntity(BaseModel): + """ + Single Iteration Run Entity. + """ + + node_id: str + inputs: dict + + single_iteration_run: Optional[SingleIterationRunEntity] = None + + class SingleLoopRunEntity(BaseModel): + """ + Single Loop Run Entity. 
+ """ + + node_id: str + inputs: dict + + single_loop_run: Optional[SingleLoopRunEntity] = None diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index 15d9e7d9ba..d8681b6491 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -1,18 +1,13 @@ -from collections.abc import Mapping -from typing import Any +from abc import ABC, abstractmethod from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, + DatasourceProviderType, ) -from core.plugin.impl.datasource import PluginDatasourceManager -from core.plugin.utils.converter import convert_parameters_to_plugin_format -class DatasourcePlugin: - tenant_id: str - icon: str - plugin_unique_identifier: str +class DatasourcePlugin(ABC): entity: DatasourceEntity runtime: DatasourceRuntime @@ -20,57 +15,19 @@ class DatasourcePlugin: self, entity: DatasourceEntity, runtime: DatasourceRuntime, - tenant_id: str, - icon: str, - plugin_unique_identifier: str, ) -> None: self.entity = entity self.runtime = runtime - self.tenant_id = tenant_id - self.icon = icon - self.plugin_unique_identifier = plugin_unique_identifier - def _invoke_first_step( - self, - user_id: str, - datasource_parameters: dict[str, Any], - ) -> Mapping[str, Any]: - manager = PluginDatasourceManager() - - datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) - - return manager.invoke_first_step( - tenant_id=self.tenant_id, - user_id=user_id, - datasource_provider=self.entity.identity.provider, - datasource_name=self.entity.identity.name, - credentials=self.runtime.credentials, - datasource_parameters=datasource_parameters, - ) - - def _invoke_second_step( - self, - user_id: str, - datasource_parameters: dict[str, Any], - ) -> Mapping[str, Any]: - manager = PluginDatasourceManager() - - datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) - - return manager.invoke_second_step( - tenant_id=self.tenant_id, - user_id=user_id, - datasource_provider=self.entity.identity.provider, - datasource_name=self.entity.identity.name, - credentials=self.runtime.credentials, - datasource_parameters=datasource_parameters, - ) + @abstractmethod + def datasource_provider_type(self) -> DatasourceProviderType: + """ + returns the type of the datasource provider + """ + return DatasourceProviderType.LOCAL_FILE def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": - return DatasourcePlugin( - entity=self.entity, + return self.__class__( + entity=self.entity.model_copy(), runtime=runtime, - tenant_id=self.tenant_id, - icon=self.icon, - plugin_unique_identifier=self.plugin_unique_identifier, ) diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index 13804f53d9..1544270d7a 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -1,26 +1,19 @@ +from abc import ABC, abstractmethod from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime -from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType from core.entities.provider_entities 
import ProviderConfig from core.plugin.impl.tool import PluginToolManager from core.tools.errors import ToolProviderCredentialValidationError -class DatasourcePluginProviderController: +class DatasourcePluginProviderController(ABC): entity: DatasourceProviderEntityWithPlugin - tenant_id: str - plugin_id: str - plugin_unique_identifier: str - def __init__( - self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str - ) -> None: + def __init__(self, entity: DatasourceProviderEntityWithPlugin) -> None: self.entity = entity - self.tenant_id = tenant_id - self.plugin_id = plugin_id - self.plugin_unique_identifier = plugin_unique_identifier @property def need_credentials(self) -> bool: @@ -44,29 +37,19 @@ class DatasourcePluginProviderController: ): raise ToolProviderCredentialValidationError("Invalid credentials") - def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.LOCAL_FILE + + @abstractmethod + def get_datasource(self, datasource_name: str) -> DatasourcePlugin: """ return datasource with given name """ - datasource_entity = next( - ( - datasource_entity - for datasource_entity in self.entity.datasources - if datasource_entity.identity.name == datasource_name - ), - None, - ) - - if not datasource_entity: - raise ValueError(f"Datasource with name {datasource_name} not found") - - return DatasourcePlugin( - entity=datasource_entity, - runtime=DatasourceRuntime(tenant_id=self.tenant_id), - tenant_id=self.tenant_id, - icon=self.entity.identity.icon, - plugin_unique_identifier=self.plugin_unique_identifier, - ) + pass def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore """ diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 8d6bed41fa..3b224c9e64 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -28,13 +28,13 @@ class DatasourceProviderApiEntity(BaseModel): description: I18nObject icon: str | dict label: I18nObject # label - type: ToolProviderType + type: str masked_credentials: Optional[dict] = None original_credentials: Optional[dict] = None is_team_authorization: bool = False allow_delete: bool = True - plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool") - plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") + plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource") + plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource") datasources: list[DatasourceApiEntity] = Field(default_factory=list) labels: list[str] = Field(default_factory=list) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 25d7c1c352..7b3fadfee8 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -23,7 +23,7 @@ class DatasourceProviderType(enum.StrEnum): ONLINE_DOCUMENT = "online_document" LOCAL_FILE = "local_file" - WEBSITE = "website" + WEBSITE_CRAWL = "website_crawl" @classmethod def value_of(cls, value: str) -> "DatasourceProviderType": @@ -111,10 +111,10 @@ class DatasourceParameter(PluginParameter): class DatasourceIdentity(BaseModel): 
- author: str = Field(..., description="The author of the tool") - name: str = Field(..., description="The name of the tool") - label: I18nObject = Field(..., description="The label of the tool") - provider: str = Field(..., description="The provider of the tool") + author: str = Field(..., description="The author of the datasource") + name: str = Field(..., description="The name of the datasource") + label: I18nObject = Field(..., description="The label of the datasource") + provider: str = Field(..., description="The provider of the datasource") icon: Optional[str] = None @@ -145,7 +145,7 @@ class DatasourceProviderEntity(ToolProviderEntity): class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity): - datasources: list[DatasourceEntity] = Field(default_factory=list) + datasources: list[DatasourceEntity] = Field(default_factory=list) class DatasourceInvokeMeta(BaseModel): @@ -195,3 +195,105 @@ class DatasourceInvokeFrom(Enum): """ RAG_PIPELINE = "rag_pipeline" + + +class GetOnlineDocumentPagesRequest(BaseModel): + """ + Get online document pages request + """ + + tenant_id: str = Field(..., description="The tenant id") + + +class OnlineDocumentPageIcon(BaseModel): + """ + Online document page icon + """ + + type: str = Field(..., description="The type of the icon") + url: str = Field(..., description="The url of the icon") + + +class OnlineDocumentPage(BaseModel): + """ + Online document page + """ + + page_id: str = Field(..., description="The page id") + page_title: str = Field(..., description="The page title") + page_icon: Optional[OnlineDocumentPageIcon] = Field(None, description="The page icon") + type: str = Field(..., description="The type of the page") + last_edited_time: str = Field(..., description="The last edited time") + + +class OnlineDocumentInfo(BaseModel): + """ + Online document info + """ + + workspace_id: str = Field(..., description="The workspace id") + workspace_name: str = Field(..., description="The workspace name") + workspace_icon: str = Field(..., description="The workspace icon") + total: int = Field(..., description="The total number of documents") + pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") + + +class GetOnlineDocumentPagesResponse(BaseModel): + """ + Get online document pages response + """ + + result: list[OnlineDocumentInfo] + + +class GetOnlineDocumentPageContentRequest(BaseModel): + """ + Get online document page content request + """ + + online_document_info_list: list[OnlineDocumentInfo] + + +class OnlineDocumentPageContent(BaseModel): + """ + Online document page content + """ + + page_id: str = Field(..., description="The page id") + content: str = Field(..., description="The content of the page") + + +class GetOnlineDocumentPageContentResponse(BaseModel): + """ + Get online document page content response + """ + + result: list[OnlineDocumentPageContent] + + +class GetWebsiteCrawlRequest(BaseModel): + """ + Get website crawl request + """ + + url: str = Field(..., description="The url of the website") + crawl_parameters: dict = Field(..., description="The crawl parameters") + + +class WebSiteInfo(BaseModel): + """ + Website info + """ + + source_url: str = Field(..., description="The url of the website") + markdown: str = Field(..., description="The markdown of the website") + title: str = Field(..., description="The title of the website") + description: str = Field(..., description="The description of the website") + + +class GetWebsiteCrawlResponse(BaseModel): + """ + Get website crawl 
response + """ + + result: list[WebSiteInfo] diff --git a/api/core/datasource/local_file/local_file_plugin.py b/api/core/datasource/local_file/local_file_plugin.py new file mode 100644 index 0000000000..a9dced1186 --- /dev/null +++ b/api/core/datasource/local_file/local_file_plugin.py @@ -0,0 +1,37 @@ +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, +) + + +class LocalFileDatasourcePlugin(DatasourcePlugin): + tenant_id: str + icon: str + plugin_unique_identifier: str + + def __init__( + self, + entity: DatasourceEntity, + runtime: DatasourceRuntime, + tenant_id: str, + icon: str, + plugin_unique_identifier: str, + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier + + def datasource_provider_type(self) -> DatasourceProviderType: + return DatasourceProviderType.LOCAL_FILE + + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": + return DatasourcePlugin( + entity=self.entity, + runtime=runtime, + tenant_id=self.tenant_id, + icon=self.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/datasource/local_file/local_file_provider.py b/api/core/datasource/local_file/local_file_provider.py new file mode 100644 index 0000000000..79f885dda5 --- /dev/null +++ b/api/core/datasource/local_file/local_file_provider.py @@ -0,0 +1,58 @@ +from typing import Any + +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin + + +class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController): + entity: DatasourceProviderEntityWithPlugin + tenant_id: str + plugin_id: str + plugin_unique_identifier: str + + def __init__( + self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + ) -> None: + super().__init__(entity) + self.tenant_id = tenant_id + self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.LOCAL_FILE + + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + """ + validate the credentials of the provider + """ + pass + + def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore + """ + return datasource with given name + """ + datasource_entity = next( + ( + datasource_entity + for datasource_entity in self.entity.datasources + if datasource_entity.identity.name == datasource_name + ), + None, + ) + + if not datasource_entity: + raise ValueError(f"Datasource with name {datasource_name} not found") + + return LocalFileDatasourcePlugin( + entity=datasource_entity, + runtime=DatasourceRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/datasource/online_document/online_document_plugin.py 
b/api/core/datasource/online_document/online_document_plugin.py new file mode 100644 index 0000000000..197d85ef59 --- /dev/null +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -0,0 +1,80 @@ +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, + GetOnlineDocumentPageContentRequest, + GetOnlineDocumentPageContentResponse, + GetOnlineDocumentPagesRequest, + GetOnlineDocumentPagesResponse, +) +from core.plugin.impl.datasource import PluginDatasourceManager + + +class OnlineDocumentDatasourcePlugin(DatasourcePlugin): + tenant_id: str + icon: str + plugin_unique_identifier: str + entity: DatasourceEntity + runtime: DatasourceRuntime + + def __init__( + self, + entity: DatasourceEntity, + runtime: DatasourceRuntime, + tenant_id: str, + icon: str, + plugin_unique_identifier: str, + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier + + def _get_online_document_pages( + self, + user_id: str, + datasource_parameters: GetOnlineDocumentPagesRequest, + provider_type: str, + ) -> GetOnlineDocumentPagesResponse: + manager = PluginDatasourceManager() + + return manager.get_online_document_pages( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def _get_online_document_page_content( + self, + user_id: str, + datasource_parameters: GetOnlineDocumentPageContentRequest, + provider_type: str, + ) -> GetOnlineDocumentPageContentResponse: + manager = PluginDatasourceManager() + + return manager.get_online_document_page_content( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def datasource_provider_type(self) -> DatasourceProviderType: + return DatasourceProviderType.ONLINE_DOCUMENT + + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": + return DatasourcePlugin( + entity=self.entity, + runtime=runtime, + tenant_id=self.tenant_id, + icon=self.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/datasource/online_document/online_document_provider.py b/api/core/datasource/online_document/online_document_provider.py new file mode 100644 index 0000000000..06572880b8 --- /dev/null +++ b/api/core/datasource/online_document/online_document_provider.py @@ -0,0 +1,50 @@ +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType + + +class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController): + entity: DatasourceProviderEntityWithPlugin + tenant_id: str + plugin_id: str + plugin_unique_identifier: str + + def __init__( + self, entity: DatasourceProviderEntityWithPlugin, plugin_id: 
str, plugin_unique_identifier: str, tenant_id: str + ) -> None: + super().__init__(entity) + self.tenant_id = tenant_id + self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.ONLINE_DOCUMENT + + def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore + """ + return datasource with given name + """ + datasource_entity = next( + ( + datasource_entity + for datasource_entity in self.entity.datasources + if datasource_entity.identity.name == datasource_name + ), + None, + ) + + if not datasource_entity: + raise ValueError(f"Datasource with name {datasource_name} not found") + + return DatasourcePlugin( + entity=datasource_entity, + runtime=DatasourceRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py new file mode 100644 index 0000000000..8454d1636e --- /dev/null +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -0,0 +1,63 @@ +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, + GetWebsiteCrawlRequest, + GetWebsiteCrawlResponse, +) +from core.plugin.impl.datasource import PluginDatasourceManager +from core.plugin.utils.converter import convert_parameters_to_plugin_format + + +class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): + tenant_id: str + icon: str + plugin_unique_identifier: str + entity: DatasourceEntity + runtime: DatasourceRuntime + + def __init__( + self, + entity: DatasourceEntity, + runtime: DatasourceRuntime, + tenant_id: str, + icon: str, + plugin_unique_identifier: str, + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier + + def _get_website_crawl( + self, + user_id: str, + datasource_parameters: GetWebsiteCrawlRequest, + provider_type: str, + ) -> GetWebsiteCrawlResponse: + manager = PluginDatasourceManager() + + datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) + + return manager.invoke_first_step( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def datasource_provider_type(self) -> DatasourceProviderType: + return DatasourceProviderType.WEBSITE_CRAWL + + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": + return DatasourcePlugin( + entity=self.entity, + runtime=runtime, + tenant_id=self.tenant_id, + icon=self.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py new file mode 100644 index 0000000000..9c6bcdb7c2 --- /dev/null +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -0,0 +1,50 @@ +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from 
core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType + + +class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController): + entity: DatasourceProviderEntityWithPlugin + tenant_id: str + plugin_id: str + plugin_unique_identifier: str + + def __init__( + self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + ) -> None: + super().__init__(entity) + self.tenant_id = tenant_id + self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.WEBSITE_CRAWL + + def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore + """ + return datasource with given name + """ + datasource_entity = next( + ( + datasource_entity + for datasource_entity in self.entity.datasources + if datasource_entity.identity.name == datasource_name + ), + None, + ) + + if not datasource_entity: + raise ValueError(f"Datasource with name {datasource_name} not found") + + return DatasourcePlugin( + entity=datasource_entity, + runtime=DatasourceRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 90086173fa..3b0defbb08 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -52,6 +52,7 @@ class PluginDatasourceProviderEntity(BaseModel): provider: str plugin_unique_identifier: str plugin_id: str + author: str declaration: DatasourceProviderEntityWithPlugin diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 922e65d725..ebe08bd7eb 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,6 +1,14 @@ -from collections.abc import Mapping from typing import Any +from core.datasource.entities.api_entities import DatasourceProviderApiEntity +from core.datasource.entities.datasource_entities import ( + GetOnlineDocumentPageContentRequest, + GetOnlineDocumentPageContentResponse, + GetOnlineDocumentPagesRequest, + GetOnlineDocumentPagesResponse, + GetWebsiteCrawlRequest, + GetWebsiteCrawlResponse, +) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, @@ -10,7 +18,7 @@ from core.plugin.impl.base import BasePluginClient class PluginDatasourceManager(BasePluginClient): - def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]: + def fetch_datasource_providers(self, tenant_id: str) -> list[DatasourceProviderApiEntity]: """ Fetch datasource providers for the given tenant. 
""" @@ -19,27 +27,27 @@ class PluginDatasourceManager(BasePluginClient): for provider in json_response.get("data", []): declaration = provider.get("declaration", {}) or {} provider_name = declaration.get("identity", {}).get("name") - for tool in declaration.get("tools", []): - tool["identity"]["provider"] = provider_name + for datasource in declaration.get("datasources", []): + datasource["identity"]["provider"] = provider_name return json_response - response = self._request_with_plugin_daemon_response( - "GET", - f"plugin/{tenant_id}/management/datasources", - list[PluginDatasourceProviderEntity], - params={"page": 1, "page_size": 256}, - transformer=transformer, - ) + # response = self._request_with_plugin_daemon_response( + # "GET", + # f"plugin/{tenant_id}/management/datasources", + # list[PluginDatasourceProviderEntity], + # params={"page": 1, "page_size": 256}, + # transformer=transformer, + # ) - for provider in response: - provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" + # for provider in response: + # provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" - # override the provider name for each tool to plugin_id/provider_name - for datasource in provider.declaration.datasources: - datasource.identity.provider = provider.declaration.identity.name + # # override the provider name for each tool to plugin_id/provider_name + # for datasource in provider.declaration.datasources: + # datasource.identity.provider = provider.declaration.identity.name - return response + return [DatasourceProviderApiEntity(**self._get_local_file_datasource_provider())] def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: """ @@ -71,15 +79,16 @@ class PluginDatasourceManager(BasePluginClient): return response - def invoke_first_step( + def get_website_crawl( self, tenant_id: str, user_id: str, datasource_provider: str, datasource_name: str, credentials: dict[str, Any], - datasource_parameters: dict[str, Any], - ) -> Mapping[str, Any]: + datasource_parameters: GetWebsiteCrawlRequest, + provider_type: str, + ) -> GetWebsiteCrawlResponse: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ @@ -88,8 +97,8 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/first_step", - dict, + f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_website_crawl", + GetWebsiteCrawlResponse, data={ "user_id": user_id, "data": { @@ -109,15 +118,16 @@ class PluginDatasourceManager(BasePluginClient): raise Exception("No response from plugin daemon") - def invoke_second_step( + def get_online_document_pages( self, tenant_id: str, user_id: str, datasource_provider: str, datasource_name: str, credentials: dict[str, Any], - datasource_parameters: dict[str, Any], - ) -> Mapping[str, Any]: + datasource_parameters: GetOnlineDocumentPagesRequest, + provider_type: str, + ) -> GetOnlineDocumentPagesResponse: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
""" @@ -126,8 +136,47 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/second_step", - dict, + f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_online_document_pages", + GetOnlineDocumentPagesResponse, + data={ + "user_id": user_id, + "data": { + "provider": datasource_provider_id.provider_name, + "datasource": datasource_name, + "credentials": credentials, + "datasource_parameters": datasource_parameters, + }, + }, + headers={ + "X-Plugin-ID": datasource_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + + raise Exception("No response from plugin daemon") + + def get_online_document_page_content( + self, + tenant_id: str, + user_id: str, + datasource_provider: str, + datasource_name: str, + credentials: dict[str, Any], + datasource_parameters: GetOnlineDocumentPageContentRequest, + provider_type: str, + ) -> GetOnlineDocumentPageContentResponse: + """ + Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. + """ + + datasource_provider_id = GenericProviderID(datasource_provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_online_document_page_content", + GetOnlineDocumentPageContentResponse, data={ "user_id": user_id, "data": { @@ -176,3 +225,53 @@ class PluginDatasourceManager(BasePluginClient): return resp.result return False + + def _get_local_file_datasource_provider(self) -> dict[str, Any]: + return { + "id": "langgenius/file/file", + "author": "langgenius", + "name": "langgenius/file/file", + "plugin_id": "langgenius/file", + "plugin_unique_identifier": "langgenius/file:0.0.1@dify", + "description": { + "zh_Hans": "File", + "en_US": "File", + "pt_BR": "File", + "ja_JP": "File" + }, + "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", + "label": { + "zh_Hans": "File", + "en_US": "File", + "pt_BR": "File", + "ja_JP": "File" + }, + "type": "datasource", + "team_credentials": {}, + "is_team_authorization": False, + "allow_delete": True, + "datasources": [{ + "author": "langgenius", + "name": "upload_file", + "label": { + "en_US": "File", + "zh_Hans": "File", + "pt_BR": "File", + "ja_JP": "File" + }, + "description": { + "en_US": "File", + "zh_Hans": "File", + "pt_BR": "File", + "ja_JP": "File." 
+ }, + "parameters": [], + "labels": [ + "search" + ], + "output_schema": None + }], + "labels": [ + "search" + ] + } diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 9642efa1a5..34d17c880a 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -14,3 +14,7 @@ class SystemVariableKey(StrEnum): APP_ID = "app_id" WORKFLOW_ID = "workflow_id" WORKFLOW_RUN_ID = "workflow_run_id" + # RAG Pipeline + DOCUMENT_ID = "document_id" + BATCH = "batch" + DATASET_ID = "dataset_id" diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index e7d4da8426..d25784b781 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -3,7 +3,11 @@ from typing import Any, cast from core.datasource.entities.datasource_entities import ( DatasourceParameter, + DatasourceProviderType, + GetWebsiteCrawlResponse, ) +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.file import File from core.plugin.impl.exc import PluginDaemonClientSideError from core.variables.segments import ArrayAnySegment @@ -77,15 +81,44 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): for_log=True, ) - # get conversation id - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - try: # TODO: handle result - result = datasource_runtime._invoke_second_step( - user_id=self.user_id, - datasource_parameters=parameters, - ) + if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + result = datasource_runtime._get_online_document_page_content( + user_id=self.user_id, + datasource_parameters=parameters, + provider_type=node_data.provider_type, + ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "result": result.result.model_dump(), + "datasource_type": datasource_runtime.datasource_provider_type(), + }, + ) + elif datasource_runtime.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL: + datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) + result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + user_id=self.user_id, + datasource_parameters=parameters, + provider_type=node_data.provider_type, + ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "result": result.result.model_dump(), + "datasource_type": datasource_runtime.datasource_provider_type(), + }, + ) + else: + raise DatasourceNodeError( + f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}" + ) except PluginDaemonClientSideError as e: yield RunCompletedEvent( run_result=NodeRunResult( diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 6b2c91a8a0..0d0da757d5 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -155,9 +155,4 @@ class KnowledgeIndexNodeData(BaseNodeData): """ type: str = "knowledge-index" - dataset_id: 
str - document_id: str index_chunk_variable_selector: list[str] - chunk_structure: Literal["general", "parent-child"] - index_method: IndexMethod - retrieval_setting: RetrievalSetting diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 1fa6c20bf9..dac541621a 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,25 +1,19 @@ import datetime import logging -import time from collections.abc import Mapping from typing import Any, cast -from flask_login import current_user - -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables.segments import ObjectSegment from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.enums import NodeType from core.workflow.nodes.llm.node import LLMNode from extensions.ext_database import db -from extensions.ext_redis import redis_client -from models.dataset import Dataset, Document, RateLimitLog +from models.dataset import Dataset, Document from models.workflow import WorkflowNodeExecutionStatus -from services.dataset_service import DatasetCollectionBindingService -from services.feature_service import FeatureService from .entities import KnowledgeIndexNodeData from .exc import ( @@ -43,8 +37,9 @@ class KnowledgeIndexNode(LLMNode): def _run(self) -> NodeRunResult: # type: ignore node_data = cast(KnowledgeIndexNodeData, self.node_data) + variable_pool = self.graph_runtime_state.variable_pool # extract variables - variable = self.graph_runtime_state.variable_pool.get(node_data.index_chunk_variable_selector) + variable = variable_pool.get(node_data.index_chunk_variable_selector) if not isinstance(variable, ObjectSegment): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -57,34 +52,9 @@ class KnowledgeIndexNode(LLMNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required." 
) - # check rate limit - if self.tenant_id: - knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) - if knowledge_rate_limit.enabled: - current_time = int(time.time() * 1000) - key = f"rate_limit_{self.tenant_id}" - redis_client.zadd(key, {current_time: current_time}) - redis_client.zremrangebyscore(key, 0, current_time - 60000) - request_count = redis_client.zcard(key) - if request_count > knowledge_rate_limit.limit: - # add ratelimit record - rate_limit_log = RateLimitLog( - tenant_id=self.tenant_id, - subscription_plan=knowledge_rate_limit.subscription_plan, - operation="knowledge", - ) - db.session.add(rate_limit_log) - db.session.commit() - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error="Sorry, you have reached the knowledge base request rate limit of your subscription.", - error_type="RateLimitExceeded", - ) - # retrieve knowledge try: - results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks) + results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool) outputs = {"result": results} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs @@ -107,54 +77,26 @@ class KnowledgeIndexNode(LLMNode): error_type=type(e).__name__, ) - def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any]) -> Any: - dataset = Dataset.query.filter_by(id=node_data.dataset_id).first() + def _invoke_knowledge_index( + self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], variable_pool: VariablePool + ) -> Any: + dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) + if not dataset_id: + raise KnowledgeIndexNodeError("Dataset ID is required.") + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if not document_id: + raise KnowledgeIndexNodeError("Document ID is required.") + batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) + if not batch: + raise KnowledgeIndexNodeError("Batch is required.") + dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: - raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.") + raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") - document = Document.query.filter_by(id=node_data.document_id).first() + document = Document.query.filter_by(id=document_id).first() if not document: - raise KnowledgeIndexNodeError(f"Document {node_data.document_id} not found.") + raise KnowledgeIndexNodeError(f"Document {document_id} not found.") - retrieval_setting = node_data.retrieval_setting - index_method = node_data.index_method - if not dataset.indexing_technique: - if node_data.index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: - raise ValueError("Indexing technique is invalid") - - dataset.indexing_technique = index_method.indexing_technique - if index_method.indexing_technique == "high_quality": - model_manager = ModelManager() - if ( - index_method.embedding_setting.embedding_model - and index_method.embedding_setting.embedding_model_provider - ): - dataset_embedding_model = index_method.embedding_setting.embedding_model - dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider - else: - embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - dataset_embedding_model = embedding_model.model - 
dataset_embedding_model_provider = embedding_model.provider - dataset.embedding_model = dataset_embedding_model - dataset.embedding_model_provider = dataset_embedding_model_provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - dataset_embedding_model_provider, dataset_embedding_model - ) - dataset.collection_binding_id = dataset_collection_binding.id - if not dataset.retrieval_model: - default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, - } - - dataset.retrieval_model = ( - retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model - ) # type: ignore index_processor = IndexProcessorFactory(node_data.chunk_structure).init_index_processor() index_processor.index(dataset, document, chunks) @@ -166,6 +108,7 @@ class KnowledgeIndexNode(LLMNode): return { "dataset_id": dataset.id, "dataset_name": dataset.name, + "batch": batch, "document_id": document.id, "document_name": document.name, "created_at": document.created_at, diff --git a/api/core/workflow/nodes/knowledge_index/template_prompts.py b/api/core/workflow/nodes/knowledge_index/template_prompts.py deleted file mode 100644 index 7abd55d798..0000000000 --- a/api/core/workflow/nodes/knowledge_index/template_prompts.py +++ /dev/null @@ -1,66 +0,0 @@ -METADATA_FILTER_SYSTEM_PROMPT = """ - ### Job Description', - You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value - ### Task - Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator". - ### Format - The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. - ### Constraint - DO NOT include anything other than the JSON array in your response. 
-""" # noqa: E501 - -METADATA_FILTER_USER_PROMPT_1 = """ - { "input_text": "I want to know which company’s email address test@example.com is?", - "metadata_fields": ["filename", "email", "phone", "address"] - } -""" - -METADATA_FILTER_ASSISTANT_PROMPT_1 = """ -```json - {"metadata_map": [ - {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="} - ] - } -``` -""" - -METADATA_FILTER_USER_PROMPT_2 = """ - {"input_text": "What are the movies with a score of more than 9 in 2024?", - "metadata_fields": ["name", "year", "rating", "country"]} -""" - -METADATA_FILTER_ASSISTANT_PROMPT_2 = """ -```json - {"metadata_map": [ - {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, - {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}, - ]} -``` -""" - -METADATA_FILTER_USER_PROMPT_3 = """ - '{{"input_text": "{input_text}",', - '"metadata_fields": {metadata_fields}}}' -""" - -METADATA_FILTER_COMPLETION_PROMPT = """ -### Job Description -You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value -### Task -# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator". -### Format -The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. -### Constraint -DO NOT include anything other than the JSON array in your response. -### Example -Here is the chat example between human and assistant, inside XML tags. 
- -User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}} -Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}} -User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}} -Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}} - -### User Input -{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}} -### Assistant Output -""" # noqa: E501 diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 8c702b74ee..00448d2a9b 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -57,8 +57,6 @@ class MultipleRetrievalConfig(BaseModel): class ModelConfig(BaseModel): - - provider: str name: str mode: str diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 69a786e2f5..d829d57812 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -39,7 +39,6 @@ from core.variables.variables import ( from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, - PIPELINE_VARIABLE_NODE_ID, ) @@ -123,6 +122,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = result.model_copy(update={"selector": selector}) return cast(Variable, result) + def build_segment(value: Any, /) -> Segment: if value is None: return NoneSegment() diff --git a/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py b/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py new file mode 100644 index 0000000000..18e90e49dc --- /dev/null +++ b/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py @@ -0,0 +1,113 @@ +"""add_pipeline_info_2 + +Revision ID: abb18a379e62 +Revises: b35c3db83d09 +Create Date: 2025-05-16 16:59:16.423127 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'abb18a379e62' +down_revision = 'b35c3db83d09' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('component_failure_stats') + op.drop_table('reliability_data') + op.drop_table('maintenance') + op.drop_table('operational_data') + op.drop_table('component_failure') + op.drop_table('tool_providers') + op.drop_table('safety_data') + op.drop_table('incident_data') + with op.batch_alter_table('pipelines', schema=None) as batch_op: + batch_op.drop_column('mode') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('pipelines', schema=None) as batch_op: + batch_op.add_column(sa.Column('mode', sa.VARCHAR(length=255), autoincrement=False, nullable=False)) + + op.create_table('incident_data', + sa.Column('IncidentID', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('IncidentDescription', sa.TEXT(), autoincrement=False, nullable=False), + sa.Column('IncidentDate', sa.DATE(), autoincrement=False, nullable=False), + sa.Column('Consequences', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('ResponseActions', sa.TEXT(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('IncidentID', name='incident_data_pkey') + ) + op.create_table('safety_data', + sa.Column('SafetyID', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('SafetyInspectionDate', sa.DATE(), autoincrement=False, nullable=False), + sa.Column('SafetyFindings', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('SafetyIncidentDescription', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('ComplianceStatus', sa.VARCHAR(length=50), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('SafetyID', name='safety_data_pkey') + ) + op.create_table('tool_providers', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), + sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False), + sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False), + sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') + ) + op.create_table('component_failure', + sa.Column('FailureID', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('Date', sa.DATE(), autoincrement=False, nullable=False), + sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('RepairAction', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('FailureID', name='component_failure_pkey'), + sa.UniqueConstraint('Date', 'Component', 'FailureMode', 'Cause', 'Technician', name='unique_failure_entry') + ) + op.create_table('operational_data', + sa.Column('OperationID', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('CraneUsage', sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column('LoadWeight', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), + sa.Column('LoadFrequency', sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column('EnvironmentalConditions', sa.TEXT(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('OperationID', name='operational_data_pkey') + ) + op.create_table('maintenance', + sa.Column('MaintenanceID', sa.INTEGER(), autoincrement=True, 
nullable=False), + sa.Column('MaintenanceType', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('MaintenanceDate', sa.DATE(), autoincrement=False, nullable=False), + sa.Column('ServiceDescription', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('PartsReplaced', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('MaintenanceID', name='maintenance_pkey') + ) + op.create_table('reliability_data', + sa.Column('ComponentID', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('ComponentName', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), + sa.Column('FailureRate', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('ComponentID', name='reliability_data_pkey') + ) + op.create_table('component_failure_stats', + sa.Column('StatID', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('PossibleAction', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('Probability', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), + sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('StatID', name='component_failure_stats_pkey') + ) + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 0ed59c898f..22703771d5 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1170,6 +1170,7 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] def pipeline(self): return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first() + class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] __tablename__ = "pipeline_customized_templates" __table_args__ = ( @@ -1205,6 +1206,7 @@ class Pipeline(Base): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + @property def dataset(self): return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first() diff --git a/api/models/model.py b/api/models/model.py index ee79fbd6b5..e088c2e537 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -52,6 +52,7 @@ class AppMode(StrEnum): ADVANCED_CHAT = "advanced-chat" AGENT_CHAT = "agent-chat" CHANNEL = "channel" + RAG_PIPELINE = "rag-pipeline" @classmethod def value_of(cls, value: str) -> "AppMode": diff --git a/api/models/workflow.py b/api/models/workflow.py index d5cf71841e..038648fc8e 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -3,7 +3,7 @@ import logging from collections.abc import Mapping, Sequence from datetime import UTC, datetime from enum import Enum, StrEnum -from typing import TYPE_CHECKING, Any, List, Optional, Self, Union +from typing import TYPE_CHECKING, Any, Optional, Self, Union from uuid import uuid4 from core.variables import utils as variable_utils @@ -43,7 +43,7 @@ class WorkflowType(Enum): WORKFLOW = 
"workflow" CHAT = "chat" - RAG_PIPELINE = "rag_pipeline" + RAG_PIPELINE = "rag-pipeline" @classmethod def value_of(cls, value: str) -> "WorkflowType": @@ -370,7 +370,7 @@ class Workflow(Base): return results @rag_pipeline_variables.setter - def rag_pipeline_variables(self, values: List[dict]) -> None: + def rag_pipeline_variables(self, values: list[dict]) -> None: self._rag_pipeline_variables = json.dumps( {item["variable"]: item for item in values}, ensure_ascii=False, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 81db03033f..8a87964276 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1550,7 +1550,7 @@ class DocumentService: @staticmethod def build_document( dataset: Dataset, - process_rule_id: str, + process_rule_id: str | None, data_source_type: str, document_form: str, document_language: str, diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py new file mode 100644 index 0000000000..089519dd0d --- /dev/null +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -0,0 +1,109 @@ +from collections.abc import Mapping +from typing import Any, Union + +from configs import dify_config +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.pipeline.pipeline_generator import PipelineGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom +from models.dataset import Pipeline +from models.model import Account, App, AppMode, EndUser +from models.workflow import Workflow +from services.rag_pipeline.rag_pipeline import RagPipelineService + + +class PipelineGenerateService: + @classmethod + def generate( + cls, + pipeline: Pipeline, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ): + """ + Pipeline Content Generate + :param pipeline: pipeline + :param user: user + :param args: args + :param invoke_from: invoke from + :param streaming: streaming + :return: + """ + try: + workflow = cls._get_workflow(pipeline, invoke_from) + return PipelineGenerator.convert_to_event_stream( + PipelineGenerator().generate( + pipeline=pipeline, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=streaming, + call_depth=0, + workflow_thread_pool_id=None, + ), + ) + + except Exception: + raise + + @staticmethod + def _get_max_active_requests(app_model: App) -> int: + max_active_requests = app_model.max_active_requests + if max_active_requests is None: + max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS) + return max_active_requests + + @classmethod + def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): + if app_model.mode == AppMode.ADVANCED_CHAT.value: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator.convert_to_event_stream( + AdvancedChatAppGenerator().single_iteration_generate( + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + ) + ) + elif app_model.mode == AppMode.WORKFLOW.value: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator.convert_to_event_stream( + WorkflowAppGenerator().single_iteration_generate( + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + ) + ) + 
else: + raise ValueError(f"Invalid app mode {app_model.mode}") + + @classmethod + def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True): + workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER) + return WorkflowAppGenerator.convert_to_event_stream( + WorkflowAppGenerator().single_loop_generate( + app_model=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + ) + ) + + @classmethod + def _get_workflow(cls, pipeline: Pipeline, invoke_from: InvokeFrom) -> Workflow: + """ + Get workflow + :param pipeline: pipeline + :param invoke_from: invoke from + :return: + """ + rag_pipeline_service = RagPipelineService() + if invoke_from == InvokeFrom.DEBUGGER: + # fetch draft workflow by app_model + workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + + if not workflow: + raise ValueError("Workflow not initialized") + else: + # fetch published workflow by app_model + workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline) + + if not workflow: + raise ValueError("Workflow not published") + + return workflow diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index bda29c804c..11071d82e7 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -29,32 +29,31 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - - pipeline_built_in_templates: list[PipelineBuiltInTemplate] = db.session.query(PipelineBuiltInTemplate).filter( - PipelineBuiltInTemplate.language == language - ).all() + + pipeline_built_in_templates: list[PipelineBuiltInTemplate] = ( + db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all() + ) recommended_pipelines_results = [] for pipeline_built_in_template in pipeline_built_in_templates: pipeline_model: Pipeline = pipeline_built_in_template.pipeline recommended_pipeline_result = { - 'id': pipeline_built_in_template.id, - 'name': pipeline_built_in_template.name, - 'pipeline_id': pipeline_model.id, - 'description': pipeline_built_in_template.description, - 'icon': pipeline_built_in_template.icon, - 'copyright': pipeline_built_in_template.copyright, - 'privacy_policy': pipeline_built_in_template.privacy_policy, - 'position': pipeline_built_in_template.position, + "id": pipeline_built_in_template.id, + "name": pipeline_built_in_template.name, + "pipeline_id": pipeline_model.id, + "description": pipeline_built_in_template.description, + "icon": pipeline_built_in_template.icon, + "copyright": pipeline_built_in_template.copyright, + "privacy_policy": pipeline_built_in_template.privacy_policy, + "position": pipeline_built_in_template.position, } dataset: Dataset = pipeline_model.dataset if dataset: - recommended_pipeline_result['chunk_structure'] = dataset.chunk_structure + recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure recommended_pipelines_results.append(recommended_pipeline_result) - return {'pipeline_templates': recommended_pipelines_results} - + return {"pipeline_templates": recommended_pipelines_results} @classmethod def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]: @@ -64,6 +63,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :return: """ from 
services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService + # is in public recommended list pipeline_template = ( db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first() diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index a7ad3109c3..a0a890aee7 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -3,7 +3,7 @@ import threading import time from collections.abc import Callable, Generator, Sequence from datetime import UTC, datetime -from typing import Any, Literal, Optional +from typing import Any, Optional from uuid import uuid4 from flask_login import current_user @@ -46,7 +46,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi class RagPipelineService: @staticmethod def get_pipeline_templates( - type: Literal["built-in", "customized"] = "built-in", language: str = "en-US" + type: str = "built-in", language: str = "en-US" ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE @@ -358,11 +358,11 @@ class RagPipelineService: return workflow_node_execution - def run_datasource_workflow_node( + def run_published_workflow_node( self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account ) -> WorkflowNodeExecution: """ - Run published workflow datasource + Run published workflow node """ # fetch published workflow by app_model published_workflow = self.get_published_workflow(pipeline=pipeline) @@ -393,6 +393,41 @@ class RagPipelineService: return workflow_node_execution + def run_datasource_workflow_node( + self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account + ) -> dict: + """ + Run published workflow datasource + """ + # fetch published workflow by app_model + published_workflow = self.get_published_workflow(pipeline=pipeline) + if not published_workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + start_at = time.perf_counter() + + datasource_node_data = published_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {}) + if not datasource_node_data: + raise ValueError("Datasource node data not found") + from core.datasource.datasource_manager import DatasourceManager + + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=datasource_node_data.get("provider_id"), + datasource_name=datasource_node_data.get("datasource_name"), + tenant_id=pipeline.tenant_id, + ) + result = datasource_runtime._invoke_first_step( + inputs=user_inputs, + provider_type=datasource_node_data.get("provider_type"), + user_id=account.id, + ) + + return { + "result": result, + "provider_type": datasource_node_data.get("provider_type"), + } + def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: @@ -552,7 +587,7 @@ class RagPipelineService: return workflow - def get_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: + def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: """ Get second step parameters of rag pipeline """ @@ -567,9 +602,33 @@ class RagPipelineService: return {} # get datasource provider - datasource_provider_variables = [item for item in rag_pipeline_variables - if item.get("belong_to_node_id") == node_id - or item.get("belong_to_node_id") == 
"shared"] + datasource_provider_variables = [ + item + for item in rag_pipeline_variables + if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" + ] + return datasource_provider_variables + + def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: + """ + Get second step parameters of rag pipeline + """ + + workflow = self.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Workflow not initialized") + + # get second step node + rag_pipeline_variables = workflow.rag_pipeline_variables + if not rag_pipeline_variables: + return {} + + # get datasource provider + datasource_provider_variables = [ + item + for item in rag_pipeline_variables + if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" + ] return datasource_provider_variables def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination: From 4300ebc8aa245bc175f933945c6269ed945039a3 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 23 May 2025 15:10:16 +0800 Subject: [PATCH 029/155] fix: remove provide type --- api/core/plugin/impl/datasource.py | 56 +++++++++--------------------- 1 file changed, 17 insertions(+), 39 deletions(-) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index ebe08bd7eb..80d868c1af 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -97,7 +97,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_website_crawl", + f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl", GetWebsiteCrawlResponse, data={ "user_id": user_id, @@ -136,7 +136,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_online_document_pages", + f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages", GetOnlineDocumentPagesResponse, data={ "user_id": user_id, @@ -175,7 +175,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_online_document_page_content", + f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content", GetOnlineDocumentPageContentResponse, data={ "user_id": user_id, @@ -233,45 +233,23 @@ class PluginDatasourceManager(BasePluginClient): "name": "langgenius/file/file", "plugin_id": "langgenius/file", "plugin_unique_identifier": "langgenius/file:0.0.1@dify", - "description": { - "zh_Hans": "File", - "en_US": "File", - "pt_BR": "File", - "ja_JP": "File" - }, + "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", - "label": { - "zh_Hans": "File", - "en_US": "File", - "pt_BR": "File", - "ja_JP": "File" - }, + "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, "type": "datasource", "team_credentials": {}, "is_team_authorization": False, "allow_delete": True, - "datasources": [{ - "author": "langgenius", - "name": "upload_file", - "label": { - "en_US": "File", - "zh_Hans": "File", - "pt_BR": "File", - "ja_JP": 
"File" - }, - "description": { - "en_US": "File", - "zh_Hans": "File", - "pt_BR": "File", - "ja_JP": "File." - }, - "parameters": [], - "labels": [ - "search" - ], - "output_schema": None - }], - "labels": [ - "search" - ] + "datasources": [ + { + "author": "langgenius", + "name": "upload_file", + "label": {"en_US": "File", "zh_Hans": "File", "pt_BR": "File", "ja_JP": "File"}, + "description": {"en_US": "File", "zh_Hans": "File", "pt_BR": "File", "ja_JP": "File."}, + "parameters": [], + "labels": ["search"], + "output_schema": None, + } + ], + "labels": ["search"], } From a7d5f2f53bcebdc16187afba62d45da7e55b4616 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 23 May 2025 15:10:56 +0800 Subject: [PATCH 030/155] apply ruff --- api/core/datasource/entities/api_entities.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 3b224c9e64..81771719ea 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -5,7 +5,6 @@ from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType class DatasourceApiEntity(BaseModel): From 4460d96e5843956d9ab0f78ebb3780383ddc9bf1 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 23 May 2025 15:11:40 +0800 Subject: [PATCH 031/155] feat: add oauth schema --- api/core/plugin/entities/oauth.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 api/core/plugin/entities/oauth.py diff --git a/api/core/plugin/entities/oauth.py b/api/core/plugin/entities/oauth.py new file mode 100644 index 0000000000..d284b82728 --- /dev/null +++ b/api/core/plugin/entities/oauth.py @@ -0,0 +1,21 @@ +from collections.abc import Sequence + +from pydantic import BaseModel, Field + +from core.entities.provider_entities import ProviderConfig + + +class OAuthSchema(BaseModel): + """ + OAuth schema + """ + + client_schema: Sequence[ProviderConfig] = Field( + default_factory=list, + description="client schema like client_id, client_secret, etc.", + ) + + credentials_schema: Sequence[ProviderConfig] = Field( + default_factory=list, + description="credentials schema like access_token, refresh_token, etc.", + ) From a49942b9495d9ec052ac265a47c13b70b4bfe402 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 23 May 2025 15:12:31 +0800 Subject: [PATCH 032/155] fix: rename first_step_parameters --- api/core/workflow/nodes/datasource/datasource_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 4f26ec546e..d25784b781 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -68,7 +68,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): return # get parameters - datasource_parameters = datasource_runtime.entity.first_step_parameters + datasource_parameters = datasource_runtime.entity.parameters parameters = self._generate_parameters( datasource_parameters=datasource_parameters, variable_pool=self.graph_runtime_state.variable_pool, From 64d997fdb044976c935bcffd982a74bd86354d8a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 23 May 2025 15:55:41 +0800 
Subject: [PATCH 033/155] r2 --- .../rag_pipeline/rag_pipeline_workflow.py | 67 +++++++------------ .../app/apps/pipeline/pipeline_generator.py | 7 +- .../datasource/__base/datasource_provider.py | 20 +----- api/core/datasource/datasource_manager.py | 62 ++++++++++------- .../entities/datasource_entities.py | 9 +-- .../local_file/local_file_plugin.py | 9 --- .../local_file/local_file_provider.py | 4 +- .../online_document/online_document_plugin.py | 9 --- .../online_document_provider.py | 10 ++- .../website_crawl/website_crawl_plugin.py | 14 +--- .../website_crawl/website_crawl_provider.py | 10 ++- api/core/tools/entities/tool_entities.py | 2 - .../nodes/datasource/datasource_node.py | 57 +++++++++------- api/models/workflow.py | 2 +- .../rag_pipeline/pipeline_generate_service.py | 27 +++----- api/services/rag_pipeline/rag_pipeline.py | 65 +++++++++++++----- 16 files changed, 176 insertions(+), 198 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index c0406940a7..bdd40fcabe 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -8,6 +8,7 @@ from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from models.model import EndUser import services from configs import dify_config from controllers.console import api @@ -44,7 +45,6 @@ from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService -from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError logger = logging.getLogger(__name__) @@ -243,6 +243,7 @@ class DraftRagPipelineRunApi(Resource): parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("datasource_type", type=str, required=True, location="json") parser.add_argument("datasource_info", type=list, required=True, location="json") + parser.add_argument("start_node_id", type=str, required=True, location="json") args = parser.parse_args() try: @@ -313,13 +314,20 @@ class RagPipelineDatasourceNodeRunApi(Resource): parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") args = parser.parse_args() inputs = args.get("inputs") + if inputs == None: + raise ValueError("missing inputs") rag_pipeline_service = RagPipelineService() result = rag_pipeline_service.run_datasource_workflow_node( - pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=args.get("datasource_type"), ) return result @@ -648,40 +656,6 @@ class RagPipelineByIdApi(Resource): return workflow - @setup_required - @login_required - @account_initialization_required - @get_rag_pipeline - def delete(self, pipeline: Pipeline, workflow_id: str): - """ - Delete workflow - """ - # Check permission - if not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): - raise Forbidden() - - 
rag_pipeline_service = RagPipelineService() - - # Create a session and manage the transaction - with Session(db.engine) as session: - try: - rag_pipeline_service.delete_workflow( - session=session, workflow_id=workflow_id, tenant_id=pipeline.tenant_id - ) - # Commit the transaction in the controller - session.commit() - except WorkflowInUseError as e: - abort(400, description=str(e)) - except DraftWorkflowDeletionError as e: - abort(400, description=str(e)) - except ValueError as e: - raise NotFound(str(e)) - - return None, 204 - class PublishedRagPipelineSecondStepApi(Resource): @setup_required @@ -695,8 +669,12 @@ class PublishedRagPipelineSecondStepApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - node_id = request.args.get("node_id", required=True, type=str) - + parser = reqparse.RequestParser() + parser.add_argument("node_id", type=str, required=True, location="args") + args = parser.parse_args() + node_id = args.get("node_id") + if not node_id: + raise ValueError("Node ID is required") rag_pipeline_service = RagPipelineService() variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id) return { @@ -716,7 +694,12 @@ class DraftRagPipelineSecondStepApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - node_id = request.args.get("node_id", required=True, type=str) + parser = reqparse.RequestParser() + parser.add_argument("node_id", type=str, required=True, location="args") + args = parser.parse_args() + node_id = args.get("node_id") + if not node_id: + raise ValueError("Node ID is required") rag_pipeline_service = RagPipelineService() variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id) @@ -777,9 +760,11 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource): run_id = str(run_id) rag_pipeline_service = RagPipelineService() + user = cast("Account | EndUser", current_user) node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions( pipeline=pipeline, run_id=run_id, + user=user, ) return {"data": node_executions} @@ -875,9 +860,9 @@ api.add_resource( ) api.add_resource( PublishedRagPipelineSecondStepApi, - "/rag/pipelines//workflows/published/processing/paramters", + "/rag/pipelines//workflows/published/processing/parameters", ) api.add_resource( DraftRagPipelineSecondStepApi, - "/rag/pipelines//workflows/draft/processing/paramters", + "/rag/pipelines//workflows/draft/processing/parameters", ) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 1e880c700c..c1aa9747d2 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -99,6 +99,7 @@ class PipelineGenerator(BaseAppGenerator): ) inputs: Mapping[str, Any] = args["inputs"] + start_node_id: str = args["start_node_id"] datasource_type: str = args["datasource_type"] datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) @@ -118,7 +119,7 @@ class PipelineGenerator(BaseAppGenerator): position=position, account=user, batch=batch, - document_form=pipeline.dataset.doc_form, + document_form=pipeline.dataset.chunk_structure, ) db.session.add(document) db.session.commit() @@ -231,7 +232,7 @@ class 
PipelineGenerator(BaseAppGenerator): def single_iteration_generate( self, - app_model: App, + pipeline: Pipeline, workflow: Workflow, node_id: str, user: Account | EndUser, @@ -255,7 +256,7 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("inputs is required") # convert to app config - app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) # init application generate entity application_generate_entity = WorkflowAppGenerateEntity( diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index 1544270d7a..bae39dc8c7 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin -from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType from core.entities.provider_entities import ProviderConfig from core.plugin.impl.tool import PluginToolManager @@ -11,9 +10,11 @@ from core.tools.errors import ToolProviderCredentialValidationError class DatasourcePluginProviderController(ABC): entity: DatasourceProviderEntityWithPlugin + tenant_id: str - def __init__(self, entity: DatasourceProviderEntityWithPlugin) -> None: + def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None: self.entity = entity + self.tenant_id = tenant_id @property def need_credentials(self) -> bool: @@ -51,21 +52,6 @@ class DatasourcePluginProviderController(ABC): """ pass - def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore - """ - get all datasources - """ - return [ - DatasourcePlugin( - entity=datasource_entity, - runtime=DatasourceRuntime(tenant_id=self.tenant_id), - tenant_id=self.tenant_id, - icon=self.entity.identity.icon, - plugin_unique_identifier=self.plugin_unique_identifier, - ) - for datasource_entity in self.entity.datasources - ] - def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ validate the format of the credentials of the provider and set the default value if needed diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index c865b557f9..8c74aeb320 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -6,7 +6,11 @@ import contexts from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.entities.common_entities import I18nObject +from core.datasource.entities.datasource_entities import DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError +from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController +from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController +from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController from core.plugin.impl.datasource import PluginDatasourceManager logger = logging.getLogger(__name__) @@ -19,7 +23,9 @@ class DatasourceManager: _builtin_tools_labels: dict[str, Union[I18nObject, 
None]] = {} @classmethod - def get_datasource_plugin_provider(cls, provider: str, tenant_id: str) -> DatasourcePluginProviderController: + def get_datasource_plugin_provider( + cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType + ) -> DatasourcePluginProviderController: """ get the datasource plugin provider """ @@ -40,12 +46,30 @@ class DatasourceManager: if not provider_entity: raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found") - controller = DatasourcePluginProviderController( - entity=provider_entity.declaration, - plugin_id=provider_entity.plugin_id, - plugin_unique_identifier=provider_entity.plugin_unique_identifier, - tenant_id=tenant_id, - ) + match (datasource_type): + case DatasourceProviderType.ONLINE_DOCUMENT: + controller = OnlineDocumentDatasourcePluginProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + case DatasourceProviderType.WEBSITE_CRAWL: + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + case DatasourceProviderType.LOCAL_FILE: + controller = LocalFileDatasourcePluginProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + case _: + raise ValueError(f"Unsupported datasource type: {datasource_type}") datasource_plugin_providers[provider] = controller @@ -57,6 +81,7 @@ class DatasourceManager: provider_id: str, datasource_name: str, tenant_id: str, + datasource_type: DatasourceProviderType, ) -> DatasourcePlugin: """ get the datasource runtime @@ -68,21 +93,10 @@ class DatasourceManager: :return: the datasource plugin """ - return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name) + return cls.get_datasource_plugin_provider( + provider_id, + tenant_id, + datasource_type, + ).get_datasource(datasource_name) + - @classmethod - def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]: - """ - list all the datasource providers - """ - manager = PluginDatasourceManager() - provider_entities = manager.fetch_datasource_providers(tenant_id) - return [ - DatasourcePluginProviderController( - entity=provider.declaration, - plugin_id=provider.plugin_id, - plugin_unique_identifier=provider.plugin_unique_identifier, - tenant_id=tenant_id, - ) - for provider in provider_entities - ] diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 7b3fadfee8..e9f73d3c18 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -251,7 +251,7 @@ class GetOnlineDocumentPageContentRequest(BaseModel): Get online document page content request """ - online_document_info_list: list[OnlineDocumentInfo] + online_document_info: OnlineDocumentInfo class OnlineDocumentPageContent(BaseModel): @@ -259,6 +259,7 @@ class OnlineDocumentPageContent(BaseModel): Online document page content """ + workspace_id: str = Field(..., description="The workspace id") page_id: str = Field(..., description="The page id") content: str = Field(..., description="The content of the page") @@ -268,7 +269,7 @@ 
class GetOnlineDocumentPageContentResponse(BaseModel): Get online document page content response """ - result: list[OnlineDocumentPageContent] + result: OnlineDocumentPageContent class GetWebsiteCrawlRequest(BaseModel): @@ -286,7 +287,7 @@ class WebSiteInfo(BaseModel): """ source_url: str = Field(..., description="The url of the website") - markdown: str = Field(..., description="The markdown of the website") + content: str = Field(..., description="The content of the website") title: str = Field(..., description="The title of the website") description: str = Field(..., description="The description of the website") @@ -296,4 +297,4 @@ class GetWebsiteCrawlResponse(BaseModel): Get website crawl response """ - result: list[WebSiteInfo] + result: WebSiteInfo diff --git a/api/core/datasource/local_file/local_file_plugin.py b/api/core/datasource/local_file/local_file_plugin.py index a9dced1186..45f4777f44 100644 --- a/api/core/datasource/local_file/local_file_plugin.py +++ b/api/core/datasource/local_file/local_file_plugin.py @@ -26,12 +26,3 @@ class LocalFileDatasourcePlugin(DatasourcePlugin): def datasource_provider_type(self) -> DatasourceProviderType: return DatasourceProviderType.LOCAL_FILE - - def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": - return DatasourcePlugin( - entity=self.entity, - runtime=runtime, - tenant_id=self.tenant_id, - icon=self.icon, - plugin_unique_identifier=self.plugin_unique_identifier, - ) diff --git a/api/core/datasource/local_file/local_file_provider.py b/api/core/datasource/local_file/local_file_provider.py index 79f885dda5..b2b6f51dd3 100644 --- a/api/core/datasource/local_file/local_file_provider.py +++ b/api/core/datasource/local_file/local_file_provider.py @@ -8,15 +8,13 @@ from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlug class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController): entity: DatasourceProviderEntityWithPlugin - tenant_id: str plugin_id: str plugin_unique_identifier: str def __init__( self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str ) -> None: - super().__init__(entity) - self.tenant_id = tenant_id + super().__init__(entity, tenant_id) self.plugin_id = plugin_id self.plugin_unique_identifier = plugin_unique_identifier diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index 197d85ef59..07d7a25160 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -69,12 +69,3 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): def datasource_provider_type(self) -> DatasourceProviderType: return DatasourceProviderType.ONLINE_DOCUMENT - - def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": - return DatasourcePlugin( - entity=self.entity, - runtime=runtime, - tenant_id=self.tenant_id, - icon=self.icon, - plugin_unique_identifier=self.plugin_unique_identifier, - ) diff --git a/api/core/datasource/online_document/online_document_provider.py b/api/core/datasource/online_document/online_document_provider.py index 06572880b8..a128b479f4 100644 --- a/api/core/datasource/online_document/online_document_provider.py +++ b/api/core/datasource/online_document/online_document_provider.py @@ -1,20 +1,18 @@ -from core.datasource.__base.datasource_plugin import DatasourcePlugin from 
core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController): entity: DatasourceProviderEntityWithPlugin - tenant_id: str plugin_id: str plugin_unique_identifier: str def __init__( self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str ) -> None: - super().__init__(entity) - self.tenant_id = tenant_id + super().__init__(entity, tenant_id) self.plugin_id = plugin_id self.plugin_unique_identifier = plugin_unique_identifier @@ -25,7 +23,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC """ return DatasourceProviderType.ONLINE_DOCUMENT - def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore + def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore """ return datasource with given name """ @@ -41,7 +39,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC if not datasource_entity: raise ValueError(f"Datasource with name {datasource_name} not found") - return DatasourcePlugin( + return OnlineDocumentDatasourcePlugin( entity=datasource_entity, runtime=DatasourceRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index 8454d1636e..5f92551198 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import ( GetWebsiteCrawlResponse, ) from core.plugin.impl.datasource import PluginDatasourceManager -from core.plugin.utils.converter import convert_parameters_to_plugin_format class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): @@ -38,9 +37,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): ) -> GetWebsiteCrawlResponse: manager = PluginDatasourceManager() - datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) - - return manager.invoke_first_step( + return manager.get_website_crawl( tenant_id=self.tenant_id, user_id=user_id, datasource_provider=self.entity.identity.provider, @@ -52,12 +49,3 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): def datasource_provider_type(self) -> DatasourceProviderType: return DatasourceProviderType.WEBSITE_CRAWL - - def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": - return DatasourcePlugin( - entity=self.entity, - runtime=runtime, - tenant_id=self.tenant_id, - icon=self.icon, - plugin_unique_identifier=self.plugin_unique_identifier, - ) diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index 9c6bcdb7c2..95f05fcee0 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -1,20 +1,18 @@ -from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from 
core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController): entity: DatasourceProviderEntityWithPlugin - tenant_id: str plugin_id: str plugin_unique_identifier: str def __init__( self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str ) -> None: - super().__init__(entity) - self.tenant_id = tenant_id + super().__init__(entity, tenant_id) self.plugin_id = plugin_id self.plugin_unique_identifier = plugin_unique_identifier @@ -25,7 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon """ return DatasourceProviderType.WEBSITE_CRAWL - def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore + def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore """ return datasource with given name """ @@ -41,7 +39,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon if not datasource_entity: raise ValueError(f"Datasource with name {datasource_name} not found") - return DatasourcePlugin( + return WebsiteCrawlDatasourcePlugin( entity=datasource_entity, runtime=DatasourceRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 9884d93e9d..37375f4a71 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -7,7 +7,6 @@ from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator from core.entities.provider_entities import ProviderConfig -from core.plugin.entities.oauth import OAuthSchema from core.plugin.entities.parameters import ( PluginParameter, PluginParameterOption, @@ -350,7 +349,6 @@ class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity plugin_id: Optional[str] = None credentials_schema: list[ProviderConfig] = Field(default_factory=list) - oauth_schema: Optional[OAuthSchema] = Field(default=None, description="The oauth schema of the tool provider") class ToolProviderEntityWithPlugin(ToolProviderEntity): diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index d25784b781..612c5a5a74 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -4,6 +4,9 @@ from typing import Any, cast from core.datasource.entities.datasource_entities import ( DatasourceParameter, DatasourceProviderType, + GetOnlineDocumentPageContentRequest, + GetOnlineDocumentPageContentResponse, + GetWebsiteCrawlRequest, GetWebsiteCrawlResponse, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin @@ -54,6 +57,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): provider_id=node_data.provider_id, datasource_name=node_data.datasource_name, tenant_id=self.tenant_id, + datasource_type=DatasourceProviderType(node_data.provider_type), ) except DatasourceNodeError as e: yield RunCompletedEvent( @@ -82,38 +86,43 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): ) try: - # TODO: handle result if 
datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - result = datasource_runtime._get_online_document_page_content( - user_id=self.user_id, - datasource_parameters=parameters, - provider_type=node_data.provider_type, + online_document_result: GetOnlineDocumentPageContentResponse = ( + datasource_runtime._get_online_document_page_content( + user_id=self.user_id, + datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), + provider_type=datasource_runtime.datasource_provider_type(), + ) ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - "result": result.result.model_dump(), - "datasource_type": datasource_runtime.datasource_provider_type, - }, + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "online_document": online_document_result.result.model_dump(), + "datasource_type": datasource_runtime.datasource_provider_type, + }, + ) ) elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( user_id=self.user_id, - datasource_parameters=parameters, - provider_type=node_data.provider_type, + datasource_parameters=GetWebsiteCrawlRequest(**parameters), + provider_type=datasource_runtime.datasource_provider_type(), ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - "result": result.result.model_dump(), - "datasource_type": datasource_runtime.datasource_provider_type, - }, + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "website": website_crawl_result.result.model_dump(), + "datasource_type": datasource_runtime.datasource_provider_type, + }, + ) ) else: raise DatasourceNodeError( diff --git a/api/models/workflow.py b/api/models/workflow.py index 13ef16442c..b428b1e5db 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -360,7 +360,7 @@ class Workflow(Base): ) @property - def rag_pipeline_variables(self) -> Sequence[Variable]: + def rag_pipeline_variables(self) -> list[dict]: # TODO: find some way to init `self._conversation_variables` when instance created. 
if self._rag_pipeline_variables is None: self._rag_pipeline_variables = "{}" diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 089519dd0d..14594be351 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -2,12 +2,11 @@ from collections.abc import Mapping from typing import Any, Union from configs import dify_config -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from models.dataset import Pipeline -from models.model import Account, App, AppMode, EndUser +from models.model import Account, App, EndUser from models.workflow import Workflow from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -57,23 +56,15 @@ class PipelineGenerateService: return max_active_requests @classmethod - def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): - if app_model.mode == AppMode.ADVANCED_CHAT.value: - workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( - AdvancedChatAppGenerator().single_iteration_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming - ) + def generate_single_iteration( + cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True + ): + workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER) + return PipelineGenerator.convert_to_event_stream( + PipelineGenerator().single_iteration_generate( + pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming ) - elif app_model.mode == AppMode.WORKFLOW.value: - workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( - WorkflowAppGenerator().single_iteration_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming - ) - ) - else: - raise ValueError(f"Invalid app mode {app_model.mode}") + ) @classmethod def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True): diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index a0a890aee7..bf582b9d27 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -3,7 +3,7 @@ import threading import time from collections.abc import Callable, Generator, Sequence from datetime import UTC, datetime -from typing import Any, Optional +from typing import Any, Optional, cast from uuid import uuid4 from flask_login import current_user @@ -12,6 +12,9 @@ from sqlalchemy.orm import Session import contexts from configs import dify_config +from core.datasource.entities.datasource_entities import DatasourceProviderType, GetOnlineDocumentPagesRequest, GetOnlineDocumentPagesResponse, GetWebsiteCrawlRequest, GetWebsiteCrawlResponse +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.model_runtime.utils.encoders import jsonable_encoder from 
core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable @@ -30,6 +33,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.model import EndUser from models.workflow import ( Workflow, WorkflowNodeExecution, @@ -394,8 +398,8 @@ class RagPipelineService: return workflow_node_execution def run_datasource_workflow_node( - self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account - ) -> WorkflowNodeExecution: + self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str + ) -> dict: """ Run published workflow datasource """ @@ -416,17 +420,36 @@ class RagPipelineService: provider_id=datasource_node_data.get("provider_id"), datasource_name=datasource_node_data.get("datasource_name"), tenant_id=pipeline.tenant_id, + datasource_type=DatasourceProviderType(datasource_type), ) - result = datasource_runtime._invoke_first_step( - inputs=user_inputs, - provider_type=datasource_node_data.get("provider_type"), - user_id=account.id, - ) + if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: GetOnlineDocumentPagesResponse = ( + datasource_runtime._get_online_document_pages( + user_id=account.id, + datasource_parameters=GetOnlineDocumentPagesRequest(tenant_id=pipeline.tenant_id), + provider_type=datasource_runtime.datasource_provider_type(), + ) + ) + return { + "result": [page.model_dump() for page in online_document_result.result], + "provider_type": datasource_node_data.get("provider_type"), + } + + elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL: + datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) + website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + user_id=account.id, + datasource_parameters=GetWebsiteCrawlRequest(**user_inputs), + provider_type=datasource_runtime.datasource_provider_type(), + ) + return { + "result": website_crawl_result.result.model_dump(), + "provider_type": datasource_node_data.get("provider_type"), + } + else: + raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") - return { - "result": result, - "provider_type": datasource_node_data.get("provider_type"), - } def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] @@ -587,7 +610,7 @@ class RagPipelineService: return workflow - def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: + def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: """ Get second step parameters of rag pipeline """ @@ -599,7 +622,7 @@ class RagPipelineService: # get second step node rag_pipeline_variables = workflow.rag_pipeline_variables if not rag_pipeline_variables: - return {} + return [] # get datasource provider datasource_provider_variables = [ @@ -609,7 +632,7 @@ class RagPipelineService: ] return datasource_provider_variables - def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: + def 
get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: """ Get second step parameters of rag pipeline """ @@ -621,7 +644,7 @@ class RagPipelineService: # get second step node rag_pipeline_variables = workflow.rag_pipeline_variables if not rag_pipeline_variables: - return {} + return [] # get datasource provider datasource_provider_variables = [ @@ -702,6 +725,7 @@ class RagPipelineService: self, pipeline: Pipeline, run_id: str, + user: Account | EndUser, ) -> list[WorkflowNodeExecution]: """ Get workflow run node execution list @@ -716,11 +740,16 @@ class RagPipelineService: # Use the repository to get the node execution repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, tenant_id=pipeline.tenant_id, app_id=pipeline.id + session_factory=db.engine, + app_id=pipeline.id, + user=user, + triggered_from=None ) # Use the repository to get the node executions with ordering order_config = OrderConfig(order_by=["index"], order_direction="desc") node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) + # Convert domain models to database models + workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions] - return list(node_executions) + return workflow_node_executions From 42fcda3dc85601c97662dfa4f63eaa84ca79e38e Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 23 May 2025 17:11:56 +0800 Subject: [PATCH 034/155] r2 --- .../datasets/rag_pipeline/rag_pipeline.py | 6 ++-- .../entities/datasource_entities.py | 34 +++++++++++-------- .../online_document/online_document_plugin.py | 3 +- .../website_crawl/website_crawl_plugin.py | 3 +- api/core/plugin/entities/plugin_daemon.py | 1 - api/core/plugin/impl/datasource.py | 8 ++--- api/services/rag_pipeline/rag_pipeline.py | 14 ++++---- 7 files changed, 37 insertions(+), 32 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index cc07084dea..8c5f91cb7f 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -50,8 +50,8 @@ class PipelineTemplateDetailApi(Resource): @login_required @account_initialization_required @enterprise_license_required - def get(self, pipeline_id: str): - pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id) + def get(self, template_id: str): + pipeline_template = RagPipelineService.get_pipeline_template_detail(template_id) return pipeline_template, 200 @@ -120,7 +120,7 @@ api.add_resource( ) api.add_resource( PipelineTemplateDetailApi, - "/rag/pipeline/templates/", + "/rag/pipeline/templates/", ) api.add_resource( CustomizedPipelineTemplateApi, diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index e9f73d3c18..22ec5c3d23 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -4,6 +4,8 @@ from typing import Any, Optional from pydantic import BaseModel, Field, ValidationInfo, field_validator +from core.entities.provider_entities import ProviderConfig +from core.plugin.entities.oauth import OAuthSchema from core.plugin.entities.parameters import ( PluginParameter, PluginParameterOption, @@ -13,7 +15,7 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities 
import I18nObject -from core.tools.entities.tool_entities import ToolProviderEntity +from core.tools.entities.tool_entities import ToolLabelEnum, ToolProviderEntity class DatasourceProviderType(enum.StrEnum): @@ -118,29 +120,36 @@ class DatasourceIdentity(BaseModel): icon: Optional[str] = None -class DatasourceDescription(BaseModel): - human: I18nObject = Field(..., description="The description presented to the user") - llm: str = Field(..., description="The description presented to the LLM") - - class DatasourceEntity(BaseModel): identity: DatasourceIdentity parameters: list[DatasourceParameter] = Field(default_factory=list) - description: Optional[DatasourceDescription] = None + description: I18nObject = Field(..., description="The label of the datasource") output_schema: Optional[dict] = None - has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") @field_validator("parameters", mode="before") @classmethod def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]: return v or [] +class DatasourceProviderIdentity(BaseModel): + author: str = Field(..., description="The author of the tool") + name: str = Field(..., description="The name of the tool") + description: I18nObject = Field(..., description="The description of the tool") + icon: str = Field(..., description="The icon of the tool") + label: I18nObject = Field(..., description="The label of the tool") + tags: Optional[list[ToolLabelEnum]] = Field( + default=[], + description="The tags of the tool", + ) -class DatasourceProviderEntity(ToolProviderEntity): + +class DatasourceProviderEntity(BaseModel): """ Datasource provider entity """ - + identity: DatasourceProviderIdentity + credentials_schema: list[ProviderConfig] = Field(default_factory=list) + oauth_schema: Optional[OAuthSchema] = None provider_type: DatasourceProviderType @@ -202,7 +211,6 @@ class GetOnlineDocumentPagesRequest(BaseModel): Get online document pages request """ - tenant_id: str = Field(..., description="The tenant id") class OnlineDocumentPageIcon(BaseModel): @@ -276,8 +284,6 @@ class GetWebsiteCrawlRequest(BaseModel): """ Get website crawl request """ - - url: str = Field(..., description="The url of the website") crawl_parameters: dict = Field(..., description="The crawl parameters") @@ -297,4 +303,4 @@ class GetWebsiteCrawlResponse(BaseModel): Get website crawl response """ - result: WebSiteInfo + result: list[WebSiteInfo] diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index 07d7a25160..7809ac2a89 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,3 +1,4 @@ +from typing import Any, Mapping from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( @@ -34,7 +35,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): def _get_online_document_pages( self, user_id: str, - datasource_parameters: GetOnlineDocumentPagesRequest, + datasource_parameters: Mapping[str, Any], provider_type: str, ) -> GetOnlineDocumentPagesResponse: manager = PluginDatasourceManager() diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index 5f92551198..e657fceb9c 100644 --- 
a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -1,3 +1,4 @@ +from typing import Any, Mapping from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( @@ -32,7 +33,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): def _get_website_crawl( self, user_id: str, - datasource_parameters: GetWebsiteCrawlRequest, + datasource_parameters: Mapping[str, Any], provider_type: str, ) -> GetWebsiteCrawlResponse: manager = PluginDatasourceManager() diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 3b0defbb08..90086173fa 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -52,7 +52,6 @@ class PluginDatasourceProviderEntity(BaseModel): provider: str plugin_unique_identifier: str plugin_id: str - author: str declaration: DatasourceProviderEntityWithPlugin diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 80d868c1af..430a9a6c01 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,12 +1,10 @@ -from typing import Any +from typing import Any, Mapping from core.datasource.entities.api_entities import DatasourceProviderApiEntity from core.datasource.entities.datasource_entities import ( GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, - GetOnlineDocumentPagesRequest, GetOnlineDocumentPagesResponse, - GetWebsiteCrawlRequest, GetWebsiteCrawlResponse, ) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID @@ -86,7 +84,7 @@ class PluginDatasourceManager(BasePluginClient): datasource_provider: str, datasource_name: str, credentials: dict[str, Any], - datasource_parameters: GetWebsiteCrawlRequest, + datasource_parameters: Mapping[str, Any], provider_type: str, ) -> GetWebsiteCrawlResponse: """ @@ -125,7 +123,7 @@ class PluginDatasourceManager(BasePluginClient): datasource_provider: str, datasource_name: str, credentials: dict[str, Any], - datasource_parameters: GetOnlineDocumentPagesRequest, + datasource_parameters: Mapping[str, Any], provider_type: str, ) -> GetOnlineDocumentPagesResponse: """ diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index bf582b9d27..3bee0538ab 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -67,15 +67,15 @@ class RagPipelineService: return result.get("pipeline_templates") @classmethod - def get_pipeline_template_detail(cls, pipeline_id: str) -> Optional[dict]: + def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]: """ Get pipeline template detail. 
- :param pipeline_id: pipeline id + :param template_id: template id :return: """ mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE - retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode) - result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(pipeline_id) + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() + result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) return result @classmethod @@ -427,7 +427,7 @@ class RagPipelineService: online_document_result: GetOnlineDocumentPagesResponse = ( datasource_runtime._get_online_document_pages( user_id=account.id, - datasource_parameters=GetOnlineDocumentPagesRequest(tenant_id=pipeline.tenant_id), + datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), ) ) @@ -440,11 +440,11 @@ class RagPipelineService: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( user_id=account.id, - datasource_parameters=GetWebsiteCrawlRequest(**user_inputs), + datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), ) return { - "result": website_crawl_result.result.model_dump(), + "result": [result.model_dump() for result in website_crawl_result.result], "provider_type": datasource_node_data.get("provider_type"), } else: From 70d2c781763091462d556d0c187414c20301132a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 23 May 2025 17:13:09 +0800 Subject: [PATCH 035/155] r2 --- .../datasource/online_document/online_document_plugin.py | 5 +++-- api/core/datasource/website_crawl/website_crawl_plugin.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index 7809ac2a89..b809ad6abf 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,4 +1,6 @@ -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any + from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( @@ -6,7 +8,6 @@ from core.datasource.entities.datasource_entities import ( DatasourceProviderType, GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, - GetOnlineDocumentPagesRequest, GetOnlineDocumentPagesResponse, ) from core.plugin.impl.datasource import PluginDatasourceManager diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index e657fceb9c..b1b6489197 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -1,10 +1,11 @@ -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any + from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, DatasourceProviderType, - GetWebsiteCrawlRequest, GetWebsiteCrawlResponse, ) from core.plugin.impl.datasource import PluginDatasourceManager From 
a15bf8e8fe468cee6d5b95e45ebac4eda05ee2d8 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 23 May 2025 17:35:26 +0800 Subject: [PATCH 036/155] remove output schema --- api/core/datasource/entities/datasource_entities.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 22ec5c3d23..6a9fc5d9f9 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -15,7 +15,7 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolLabelEnum, ToolProviderEntity +from core.tools.entities.tool_entities import ToolLabelEnum class DatasourceProviderType(enum.StrEnum): @@ -124,13 +124,13 @@ class DatasourceEntity(BaseModel): identity: DatasourceIdentity parameters: list[DatasourceParameter] = Field(default_factory=list) description: I18nObject = Field(..., description="The label of the datasource") - output_schema: Optional[dict] = None @field_validator("parameters", mode="before") @classmethod def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]: return v or [] + class DatasourceProviderIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") @@ -142,11 +142,12 @@ class DatasourceProviderIdentity(BaseModel): description="The tags of the tool", ) - + class DatasourceProviderEntity(BaseModel): """ Datasource provider entity """ + identity: DatasourceProviderIdentity credentials_schema: list[ProviderConfig] = Field(default_factory=list) oauth_schema: Optional[OAuthSchema] = None @@ -212,7 +213,6 @@ class GetOnlineDocumentPagesRequest(BaseModel): """ - class OnlineDocumentPageIcon(BaseModel): """ Online document page icon @@ -284,6 +284,7 @@ class GetWebsiteCrawlRequest(BaseModel): """ Get website crawl request """ + crawl_parameters: dict = Field(..., description="The crawl parameters") From 6123f1ab2162418e21957d4dcaefde0d025036ef Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 23 May 2025 19:22:50 +0800 Subject: [PATCH 037/155] refactor: reorganize imports and fix datasource endpoint URL --- .../console/datasets/rag_pipeline/rag_pipeline_workflow.py | 2 +- api/core/plugin/impl/datasource.py | 5 +++-- api/services/rag_pipeline/rag_pipeline.py | 6 +++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index bdd40fcabe..b25a362674 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -8,7 +8,6 @@ from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound -from models.model import EndUser import services from configs import dify_config from controllers.console import api @@ -40,6 +39,7 @@ from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.account import Account from models.dataset import Pipeline +from models.model import EndUser from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import 
InvokeRateLimitError from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 430a9a6c01..aa8b1ad4d6 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from core.datasource.entities.api_entities import DatasourceProviderApiEntity from core.datasource.entities.datasource_entities import ( @@ -63,7 +64,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response( "GET", - f"plugin/{tenant_id}/management/datasources", + f"plugin/{tenant_id}/management/datasource", PluginDatasourceProviderEntity, params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id}, transformer=transformer, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 3bee0538ab..461e694a1b 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -12,7 +12,11 @@ from sqlalchemy.orm import Session import contexts from configs import dify_config -from core.datasource.entities.datasource_entities import DatasourceProviderType, GetOnlineDocumentPagesRequest, GetOnlineDocumentPagesResponse, GetWebsiteCrawlRequest, GetWebsiteCrawlResponse +from core.datasource.entities.datasource_entities import ( + DatasourceProviderType, + GetOnlineDocumentPagesResponse, + GetWebsiteCrawlResponse, +) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.model_runtime.utils.encoders import jsonable_encoder From 6d547447d3723b91b858c309b297170041e4ac9b Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 23 May 2025 19:30:48 +0800 Subject: [PATCH 038/155] r2 --- .../rag_pipeline/rag_pipeline_datasets.py | 2 +- .../rag_pipeline/rag_pipeline_workflow.py | 9 ++- .../app/apps/pipeline/pipeline_generator.py | 72 ++++++++++++------- api/core/app/apps/pipeline/pipeline_runner.py | 51 ++++++++++++- api/core/app/entities/app_invoke_entities.py | 8 +-- api/core/workflow/enums.py | 2 + .../workflow/graph_engine/entities/graph.py | 5 +- .../nodes/datasource/datasource_node.py | 35 +++++---- api/services/dataset_service.py | 7 +- api/services/rag_pipeline/rag_pipeline.py | 4 +- .../rag_pipeline/rag_pipeline_dsl_service.py | 34 +++++---- 11 files changed, 157 insertions(+), 72 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index 6676deb63a..1a4e9240b6 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -161,7 +161,7 @@ class CreateEmptyRagPipelineDatasetApi(Resource): args = parser.parse_args() dataset = DatasetService.create_empty_rag_pipeline_dataset( tenant_id=current_user.current_tenant_id, - rag_pipeline_dataset_create_entity=args, + rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args), ) return marshal(dataset, dataset_detail_fields), 201 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 
bdd40fcabe..f6238bf143 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -8,7 +8,6 @@ from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound -from models.model import EndUser import services from configs import dify_config from controllers.console import api @@ -40,6 +39,7 @@ from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.account import Account from models.dataset import Pipeline +from models.model import EndUser from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService @@ -242,7 +242,7 @@ class DraftRagPipelineRunApi(Resource): parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("datasource_type", type=str, required=True, location="json") - parser.add_argument("datasource_info", type=list, required=True, location="json") + parser.add_argument("datasource_info_list", type=list, required=True, location="json") parser.add_argument("start_node_id", type=str, required=True, location="json") args = parser.parse_args() @@ -320,6 +320,9 @@ class RagPipelineDatasourceNodeRunApi(Resource): inputs = args.get("inputs") if inputs == None: raise ValueError("missing inputs") + datasource_type = args.get("datasource_type") + if datasource_type == None: + raise ValueError("missing datasource_type") rag_pipeline_service = RagPipelineService() result = rag_pipeline_service.run_datasource_workflow_node( @@ -327,7 +330,7 @@ class RagPipelineDatasourceNodeRunApi(Resource): node_id=node_id, user_inputs=inputs, account=current_user, - datasource_type=args.get("datasource_type"), + datasource_type=datasource_type, ) return result diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index c1aa9747d2..ccc227f3f4 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -32,6 +32,7 @@ from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerat from extensions.ext_database import db from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, Pipeline +from models.model import AppMode from services.dataset_service import DocumentService logger = logging.getLogger(__name__) @@ -91,7 +92,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: bool = True, call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, - ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: + ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: # convert to app config pipeline_config = PipelineConfigManager.get_pipeline_config( pipeline=pipeline, @@ -107,19 +108,23 @@ class PipelineGenerator(BaseAppGenerator): for datasource_info in datasource_info_list: workflow_run_id = str(uuid.uuid4()) document_id = None + dataset = pipeline.dataset + if not dataset: + raise ValueError("Dataset not found") if invoke_from == InvokeFrom.PUBLISHED: + position = DocumentService.get_documents_position(pipeline.dataset_id) position = DocumentService.get_documents_position(pipeline.dataset_id) 
document = self._build_document( tenant_id=pipeline.tenant_id, dataset_id=pipeline.dataset_id, - built_in_field_enabled=pipeline.dataset.built_in_field_enabled, + built_in_field_enabled=dataset.built_in_field_enabled, datasource_type=datasource_type, datasource_info=datasource_info, created_from="rag-pipeline", position=position, account=user, batch=batch, - document_form=pipeline.dataset.chunk_structure, + document_form=dataset.chunk_structure, ) db.session.add(document) db.session.commit() @@ -127,10 +132,12 @@ class PipelineGenerator(BaseAppGenerator): # init application generate entity application_generate_entity = RagPipelineGenerateEntity( task_id=str(uuid.uuid4()), - pipline_config=pipeline_config, + app_config=pipeline_config, + pipeline_config=pipeline_config, datasource_type=datasource_type, datasource_info=datasource_info, - dataset_id=pipeline.dataset_id, + dataset_id=dataset.id, + start_node_id=start_node_id, batch=batch, document_id=document_id, inputs=self._prepare_user_inputs( @@ -160,17 +167,28 @@ class PipelineGenerator(BaseAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - - return self._generate( - pipeline=pipeline, - workflow=workflow, - user=user, - application_generate_entity=application_generate_entity, - invoke_from=invoke_from, - workflow_node_execution_repository=workflow_node_execution_repository, - streaming=streaming, - workflow_thread_pool_id=workflow_thread_pool_id, - ) + if invoke_from == InvokeFrom.DEBUGGER: + return self._generate( + pipeline=pipeline, + workflow=workflow, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=invoke_from, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + else: + self._generate( + pipeline=pipeline, + workflow=workflow, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=invoke_from, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + workflow_thread_pool_id=workflow_thread_pool_id, + ) def _generate( self, @@ -201,7 +219,7 @@ class PipelineGenerator(BaseAppGenerator): task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, - app_mode=pipeline.mode, + app_mode=AppMode.RAG_PIPELINE, ) # new thread @@ -256,12 +274,18 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("inputs is required") # convert to app config - app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) + pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) # init application generate entity - application_generate_entity = WorkflowAppGenerateEntity( + application_generate_entity = RagPipelineGenerateEntity( task_id=str(uuid.uuid4()), - app_config=app_config, + app_config=pipeline_config, + pipeline_config=pipeline_config, + datasource_type=args["datasource_type"], + datasource_info=args["datasource_info"], + dataset_id=pipeline.dataset_id, + batch=args["batch"], + document_id=args["document_id"], inputs={}, files=[], user_id=user.id, @@ -288,7 +312,7 @@ class PipelineGenerator(BaseAppGenerator): ) return self._generate( - app_model=app_model, + pipeline=pipeline, workflow=workflow, user=user, invoke_from=InvokeFrom.DEBUGGER, @@ -299,7 +323,7 @@ class 
PipelineGenerator(BaseAppGenerator):
     def single_loop_generate(
         self,
-        app_model: App,
+        pipeline: Pipeline,
         workflow: Workflow,
         node_id: str,
         user: Account | EndUser,
         args: Mapping[str, Any],
         streaming: bool = True,
@@ -323,7 +347,7 @@ class PipelineGenerator(BaseAppGenerator):
             raise ValueError("inputs is required")
 
         # convert to app config
-        app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
+        app_config = WorkflowAppConfigManager.get_app_config(pipeline=pipeline, workflow=workflow)
 
         # init application generate entity
         application_generate_entity = WorkflowAppGenerateEntity(
@@ -353,7 +377,7 @@ class PipelineGenerator(BaseAppGenerator):
         )
 
         return self._generate(
-            app_model=app_model,
+            pipeline=pipeline,
             workflow=workflow,
             user=user,
             invoke_from=InvokeFrom.DEBUGGER,
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py
index 1395a47d88..80b724dd20 100644
--- a/api/core/app/apps/pipeline/pipeline_runner.py
+++ b/api/core/app/apps/pipeline/pipeline_runner.py
@@ -1,5 +1,6 @@
 import logging
-from typing import Optional, cast
+from collections.abc import Mapping
+from typing import Any, Optional, cast
 
 from configs import dify_config
 from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -12,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
 from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from models.dataset import Pipeline
@@ -100,6 +102,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
             SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
             SystemVariableKey.BATCH: self.application_generate_entity.batch,
             SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
+            SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
+            SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
         }
 
         variable_pool = VariablePool(
@@ -110,7 +114,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
         )
 
         # init graph
-        graph = self._init_graph(graph_config=workflow.graph_dict)
+        graph = self._init_rag_pipeline_graph(
+            graph_config=workflow.graph_dict,
+            start_node_id=self.application_generate_entity.start_node_id,
+        )
 
         # RUN WORKFLOW
         workflow_entry = WorkflowEntry(
@@ -152,3 +159,43 @@ class PipelineRunner(WorkflowBasedAppRunner):
 
         # return workflow
         return workflow
+
+    def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
+        """
+        Init pipeline graph
+        """
+        if "nodes" not in graph_config or "edges" not in graph_config:
+            raise ValueError("nodes or edges not found in workflow graph")
+
+        if not isinstance(graph_config.get("nodes"), list):
+            raise ValueError("nodes in workflow graph must be a list")
+
+        if not isinstance(graph_config.get("edges"), list):
+            raise ValueError("edges in workflow graph must be a list")
+        nodes = graph_config.get("nodes", [])
+        edges = graph_config.get("edges", [])
+        real_run_nodes = []
+        real_edges = []
+        exclude_node_ids = []
+        for node in nodes:
+            node_id = node.get("id")
+            node_type = node.get("data", {}).get("type", "")
+            if node_type == "datasource":
+                if start_node_id != node_id:
+                    exclude_node_ids.append(node_id)
+                    continue
+            real_run_nodes.append(node)
+        for edge in edges:
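+            # Keep only edges that do not originate from an excluded datasource node,
+            # so the resulting graph contains just the branch of the selected start node.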
+            if edge.get("source") in exclude_node_ids :
+                continue
+            real_edges.append(edge)
+        graph_config = dict(graph_config)
+        graph_config["nodes"] = real_run_nodes
+        graph_config["edges"] = real_edges
+        # init graph
+        graph = Graph.init(graph_config=graph_config)
+
+        if not graph:
+            raise ValueError("graph not found in workflow")
+
+        return graph
\ No newline at end of file
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index d730704f48..4565d37d5b 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -233,14 +233,14 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
     """
     RAG Pipeline Application Generate Entity.
     """
-
-    # app config
-    pipline_config: WorkflowUIBasedAppConfig
+    # pipeline config
+    pipeline_config: WorkflowUIBasedAppConfig
     datasource_type: str
     datasource_info: Mapping[str, Any]
     dataset_id: str
     batch: str
-    document_id: str
+    document_id: Optional[str] = None
+    start_node_id: Optional[str] = None
 
 class SingleIterationRunEntity(BaseModel):
     """
diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py
index 34d17c880a..0e210c1389 100644
--- a/api/core/workflow/enums.py
+++ b/api/core/workflow/enums.py
@@ -18,3 +18,5 @@ class SystemVariableKey(StrEnum):
     DOCUMENT_ID = "document_id"
     BATCH = "batch"
     DATASET_ID = "dataset_id"
+    DATASOURCE_TYPE = "datasource_type"
+    DATASOURCE_INFO = "datasource_info"
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index 8e5b1e7142..7062fc4565 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -121,6 +121,8 @@ class Graph(BaseModel):
         # fetch nodes that have no predecessor node
         root_node_configs = []
         all_node_id_config_mapping: dict[str, dict] = {}
+
+
         for node_config in node_configs:
             node_id = node_config.get("id")
             if not node_id:
@@ -140,7 +142,8 @@ class Graph(BaseModel):
             (
                 node_config.get("id")
                 for node_config in root_node_configs
-                if node_config.get("data", {}).get("type", "") == NodeType.START.value
+                if node_config.get("data", {}).get("type", "") == NodeType.START.value
+                or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
             ),
             None,
         )
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index 612c5a5a74..f5e34f5998 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -6,11 +6,8 @@ from core.datasource.entities.datasource_entities import (
     DatasourceProviderType,
     GetOnlineDocumentPageContentRequest,
     GetOnlineDocumentPageContentResponse,
-    GetWebsiteCrawlRequest,
-    GetWebsiteCrawlResponse,
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
-from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
 from core.file import File
 from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.variables.segments import ArrayAnySegment
@@ -42,22 +39,23 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         """
         node_data = cast(DatasourceNodeData, self.node_data)
-
-        # fetch datasource icon
-        datasource_info = {
-            "provider_id": node_data.provider_id,
-            "plugin_unique_identifier": node_data.plugin_unique_identifier,
-        }
+        variable_pool = self.graph_runtime_state.variable_pool
 
         # get datasource runtime
         try:
             from
core.datasource.datasource_manager import DatasourceManager + datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value]) + + datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value]) + if datasource_type is None: + raise DatasourceNodeError("Datasource type is not set") + datasource_runtime = DatasourceManager.get_datasource_runtime( provider_id=node_data.provider_id, datasource_name=node_data.datasource_name, tenant_id=self.tenant_id, - datasource_type=DatasourceProviderType(node_data.provider_type), + datasource_type=DatasourceProviderType(datasource_type), ) except DatasourceNodeError as e: yield RunCompletedEvent( @@ -75,12 +73,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): datasource_parameters = datasource_runtime.entity.parameters parameters = self._generate_parameters( datasource_parameters=datasource_parameters, - variable_pool=self.graph_runtime_state.variable_pool, + variable_pool=variable_pool, node_data=self.node_data, ) parameters_for_log = self._generate_parameters( datasource_parameters=datasource_parameters, - variable_pool=self.graph_runtime_state.variable_pool, + variable_pool=variable_pool, node_data=self.node_data, for_log=True, ) @@ -106,20 +104,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): }, ) ) - elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL: - datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( - user_id=self.user_id, - datasource_parameters=GetWebsiteCrawlRequest(**parameters), - provider_type=datasource_runtime.datasource_provider_type(), + elif ( + datasource_runtime.datasource_provider_type in ( + DatasourceProviderType.WEBSITE_CRAWL, + DatasourceProviderType.LOCAL_FILE, ) + ): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, outputs={ - "website": website_crawl_result.result.model_dump(), + "website": datasource_info, "datasource_type": datasource_runtime.datasource_provider_type, }, ) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 8a87964276..62a16c56ce 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import random import time import uuid from collections import Counter -from typing import Any, Optional +from typing import Any, Optional, cast from flask_login import current_user from sqlalchemy import func, select @@ -298,13 +298,14 @@ class DatasetService: description=rag_pipeline_dataset_create_entity.description, permission=rag_pipeline_dataset_create_entity.permission, provider="vendor", - runtime_mode="rag_pipeline", + runtime_mode="rag-pipeline", icon_info=rag_pipeline_dataset_create_entity.icon_info, ) with Session(db.engine) as session: rag_pipeline_dsl_service = RagPipelineDslService(session) + account = cast(Account, current_user) rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline( - account=current_user, + account=account, import_mode=ImportMode.YAML_CONTENT.value, yaml_content=rag_pipeline_dataset_create_entity.yaml_content, dataset=dataset, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 3bee0538ab..08bb10b5d4 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ 
b/api/services/rag_pipeline/rag_pipeline.py @@ -59,12 +59,12 @@ class RagPipelineService: if not result.get("pipeline_templates") and language != "en-US": template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval() result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US") - return result.get("pipeline_templates") + return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])] else: mode = "customized" retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() result = retrieval_instance.get_pipeline_templates(language) - return result.get("pipeline_templates") + return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])] @classmethod def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]: diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 19c7d37f6e..acd364f6cd 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -97,11 +97,6 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus: class RagPipelinePendingData(BaseModel): import_mode: str yaml_content: str - name: str | None - description: str | None - icon_type: str | None - icon: str | None - icon_background: str | None pipeline_id: str | None @@ -302,10 +297,6 @@ class RagPipelineDslService: dataset.runtime_mode = "rag_pipeline" dataset.chunk_structure = knowledge_configuration.chunk_structure if knowledge_configuration.index_method.indexing_technique == "high_quality": - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore - knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore - ) dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) .filter( @@ -445,10 +436,28 @@ class RagPipelineDslService: dataset.runtime_mode = "rag_pipeline" dataset.chunk_structure = knowledge_configuration.chunk_structure if knowledge_configuration.index_method.indexing_technique == "high_quality": - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore - knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.provider_name + == knowledge_configuration.index_method.embedding_setting.embedding_provider_name, + DatasetCollectionBinding.model_name + == knowledge_configuration.index_method.embedding_setting.embedding_model_name, + DatasetCollectionBinding.type == "dataset", + ) + .order_by(DatasetCollectionBinding.created_at) + .first() ) + + if not dataset_collection_binding: + dataset_collection_binding = DatasetCollectionBinding( + provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name, + model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name, + collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), + type="dataset", + ) + db.session.add(dataset_collection_binding) + db.session.commit() dataset_collection_binding_id = dataset_collection_binding.id 
dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = ( @@ -602,7 +611,6 @@ class RagPipelineDslService: rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, graph=workflow_data.get("graph", {}), - features=workflow_data.get("features", {}), unique_hash=unique_hash, account=account, environment_variables=environment_variables, From ec1c4efca94d06fb25364a4ab53b9e9dbf954779 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Sun, 25 May 2025 23:09:01 +0800 Subject: [PATCH 039/155] r2 --- .../datasource/__base/datasource_plugin.py | 2 +- .../datasource/__base/datasource_provider.py | 4 +- .../local_file/local_file_plugin.py | 2 +- .../online_document/online_document_plugin.py | 2 +- .../website_crawl/website_crawl_plugin.py | 2 +- .../website_crawl/website_crawl_provider.py | 2 +- api/core/plugin/impl/datasource.py | 69 +++++++--- .../workflow/nodes/datasource/__init__.py | 2 +- .../nodes/datasource/datasource_node.py | 130 ++++++++++-------- .../workflow/nodes/datasource/entities.py | 27 +--- .../knowledge_index/knowledge_index_node.py | 5 +- api/core/workflow/nodes/node_mapping.py | 10 ++ 12 files changed, 147 insertions(+), 110 deletions(-) diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index d8681b6491..5a13d17843 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -20,7 +20,7 @@ class DatasourcePlugin(ABC): self.runtime = runtime @abstractmethod - def datasource_provider_type(self) -> DatasourceProviderType: + def datasource_provider_type(self) -> str: """ returns the type of the datasource provider """ diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index bae39dc8c7..045ca64872 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -9,10 +9,10 @@ from core.tools.errors import ToolProviderCredentialValidationError class DatasourcePluginProviderController(ABC): - entity: DatasourceProviderEntityWithPlugin + entity: DatasourceProviderEntityWithPlugin | None tenant_id: str - def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None: + def __init__(self, entity: DatasourceProviderEntityWithPlugin | None, tenant_id: str) -> None: self.entity = entity self.tenant_id = tenant_id diff --git a/api/core/datasource/local_file/local_file_plugin.py b/api/core/datasource/local_file/local_file_plugin.py index 45f4777f44..82da10d663 100644 --- a/api/core/datasource/local_file/local_file_plugin.py +++ b/api/core/datasource/local_file/local_file_plugin.py @@ -24,5 +24,5 @@ class LocalFileDatasourcePlugin(DatasourcePlugin): self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - def datasource_provider_type(self) -> DatasourceProviderType: + def datasource_provider_type(self) -> str: return DatasourceProviderType.LOCAL_FILE diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index b809ad6abf..f94031656e 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -69,5 +69,5 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): provider_type=provider_type, ) - def datasource_provider_type(self) -> DatasourceProviderType: + def datasource_provider_type(self) -> str: return 
DatasourceProviderType.ONLINE_DOCUMENT diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index b1b6489197..e8256b3282 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -49,5 +49,5 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): provider_type=provider_type, ) - def datasource_provider_type(self) -> DatasourceProviderType: + def datasource_provider_type(self) -> str: return DatasourceProviderType.WEBSITE_CRAWL diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index 95f05fcee0..d9043702d2 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -10,7 +10,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon plugin_unique_identifier: str def __init__( - self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + self, entity: DatasourceProviderEntityWithPlugin | None, plugin_id: str, plugin_unique_identifier: str, tenant_id: str ) -> None: super().__init__(entity, tenant_id) self.plugin_id = plugin_id diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index aa8b1ad4d6..645e067e4c 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -6,7 +6,7 @@ from core.datasource.entities.datasource_entities import ( GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, GetOnlineDocumentPagesResponse, - GetWebsiteCrawlResponse, + GetWebsiteCrawlResponse, DatasourceProviderEntity, ) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import ( @@ -17,7 +17,7 @@ from core.plugin.impl.base import BasePluginClient class PluginDatasourceManager(BasePluginClient): - def fetch_datasource_providers(self, tenant_id: str) -> list[DatasourceProviderApiEntity]: + def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]: """ Fetch datasource providers for the given tenant. """ @@ -46,12 +46,15 @@ class PluginDatasourceManager(BasePluginClient): # for datasource in provider.declaration.datasources: # datasource.identity.provider = provider.declaration.identity.name - return [DatasourceProviderApiEntity(**self._get_local_file_datasource_provider())] + return [PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())] def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: """ Fetch datasource provider for the given tenant and plugin. 
""" + if provider == "langgenius/file/file": + return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) + tool_provider_id = ToolProviderID(provider) def transformer(json_response: dict[str, Any]) -> dict: @@ -218,6 +221,7 @@ class PluginDatasourceManager(BasePluginClient): "X-Plugin-ID": tool_provider_id.plugin_id, "Content-Type": "application/json", }, + ) for resp in response: @@ -228,27 +232,48 @@ class PluginDatasourceManager(BasePluginClient): def _get_local_file_datasource_provider(self) -> dict[str, Any]: return { "id": "langgenius/file/file", - "author": "langgenius", - "name": "langgenius/file/file", "plugin_id": "langgenius/file", + "provider": "langgenius", "plugin_unique_identifier": "langgenius/file:0.0.1@dify", - "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, - "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", - "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, - "type": "datasource", - "team_credentials": {}, - "is_team_authorization": False, - "allow_delete": True, - "datasources": [ - { + "declaration": { + "identity": { "author": "langgenius", - "name": "upload_file", - "label": {"en_US": "File", "zh_Hans": "File", "pt_BR": "File", "ja_JP": "File"}, - "description": {"en_US": "File", "zh_Hans": "File", "pt_BR": "File", "ja_JP": "File."}, + "name": "langgenius/file/file", + "label": { + "zh_Hans": "File", + "en_US": "File", + "pt_BR": "File", + "ja_JP": "File" + }, + "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", + "description": { + "zh_Hans": "File", + "en_US": "File", + "pt_BR": "File", + "ja_JP": "File" + } + }, + "credentials_schema": [], + "provider_type": "local_file", + "datasources": [{ + "identity": { + "author": "langgenius", + "name": "local_file", + "provider": "langgenius", + "label": { + "zh_Hans": "File", + "en_US": "File", + "pt_BR": "File", + "ja_JP": "File" + } + }, "parameters": [], - "labels": ["search"], - "output_schema": None, - } - ], - "labels": ["search"], + "description": { + "zh_Hans": "File", + "en_US": "File", + "pt_BR": "File", + "ja_JP": "File" + } + }] + } } diff --git a/api/core/workflow/nodes/datasource/__init__.py b/api/core/workflow/nodes/datasource/__init__.py index cee9e5a895..f6ec44cb77 100644 --- a/api/core/workflow/nodes/datasource/__init__.py +++ b/api/core/workflow/nodes/datasource/__init__.py @@ -1,3 +1,3 @@ -from .tool_node import ToolNode +from .datasource_node import DatasourceNode __all__ = ["DatasourceNode"] diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index f5e34f5998..198e167341 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -40,14 +40,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): node_data = cast(DatasourceNodeData, self.node_data) variable_pool = self.graph_runtime_state.variable_pool - + datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value]) + if not datasource_type: + raise DatasourceNodeError("Datasource type is not set") + datasource_type = datasource_type.value + datasource_info = variable_pool.get(["sys", 
SystemVariableKey.DATASOURCE_INFO.value]) + if not datasource_info: + raise DatasourceNodeError("Datasource info is not set") + datasource_info = datasource_info.value # get datasource runtime try: from core.datasource.datasource_manager import DatasourceManager - datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value]) - datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value]) if datasource_type is None: raise DatasourceNodeError("Datasource type is not set") @@ -84,47 +89,55 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): ) try: - if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: - datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: GetOnlineDocumentPageContentResponse = ( - datasource_runtime._get_online_document_page_content( - user_id=self.user_id, - datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), - provider_type=datasource_runtime.datasource_provider_type(), + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: GetOnlineDocumentPageContentResponse = ( + datasource_runtime._get_online_document_page_content( + user_id=self.user_id, + datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), + provider_type=datasource_type, + ) ) - ) - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - "online_document": online_document_result.result.model_dump(), - "datasource_type": datasource_runtime.datasource_provider_type, - }, + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "online_document": online_document_result.result.model_dump(), + "datasource_type": datasource_type, + }, + ) ) - ) - elif ( - datasource_runtime.datasource_provider_type in ( - DatasourceProviderType.WEBSITE_CRAWL, - DatasourceProviderType.LOCAL_FILE, - ) - ): - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - "website": datasource_info, - "datasource_type": datasource_runtime.datasource_provider_type, - }, + case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "website": datasource_info, + "datasource_type": datasource_type, + }, + ) + ) + case DatasourceProviderType.LOCAL_FILE: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "file": datasource_info, + "datasource_type": datasource_runtime.datasource_provider_type, + }, + ) + ) + case _: + raise DatasourceNodeError( + f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}" ) - ) - else: - raise DatasourceNodeError( - f"Unsupported datasource provider: 
{datasource_runtime.datasource_provider_type}" - ) except PluginDaemonClientSideError as e: yield RunCompletedEvent( run_result=NodeRunResult( @@ -170,23 +183,24 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters} result: dict[str, Any] = {} - for parameter_name in node_data.datasource_parameters: - parameter = datasource_parameters_dictionary.get(parameter_name) - if not parameter: - result[parameter_name] = None - continue - datasource_input = node_data.datasource_parameters[parameter_name] - if datasource_input.type == "variable": - variable = variable_pool.get(datasource_input.value) - if variable is None: - raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist") - parameter_value = variable.value - elif datasource_input.type in {"mixed", "constant"}: - segment_group = variable_pool.convert_template(str(datasource_input.value)) - parameter_value = segment_group.log if for_log else segment_group.text - else: - raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'") - result[parameter_name] = parameter_value + if node_data.datasource_parameters: + for parameter_name in node_data.datasource_parameters: + parameter = datasource_parameters_dictionary.get(parameter_name) + if not parameter: + result[parameter_name] = None + continue + datasource_input = node_data.datasource_parameters[parameter_name] + if datasource_input.type == "variable": + variable = variable_pool.get(datasource_input.value) + if variable is None: + raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist") + parameter_value = variable.value + elif datasource_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(datasource_input.value)) + parameter_value = segment_group.log if for_log else segment_group.text + else: + raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'") + result[parameter_name] = parameter_value return result diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 68aa9fa34c..212184bb81 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Union +from typing import Any, Literal, Union, Optional from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo @@ -9,30 +9,17 @@ from core.workflow.nodes.base.entities import BaseNodeData class DatasourceEntity(BaseModel): provider_id: str provider_name: str # redundancy - datasource_name: str - tool_label: str # redundancy - datasource_configurations: dict[str, Any] + provider_type: str + datasource_name: Optional[str] = "local_file" + datasource_configurations: dict[str, Any] | None = None plugin_unique_identifier: str | None = None # redundancy - @field_validator("tool_configurations", mode="before") - @classmethod - def validate_tool_configurations(cls, value, values: ValidationInfo): - if not isinstance(value, dict): - raise ValueError("tool_configurations must be a dictionary") - - for key in values.data.get("tool_configurations", {}): - value = values.data.get("tool_configurations", {}).get(key) - if not isinstance(value, str | int | float | bool): - raise ValueError(f"{key} must be a string") - - return value - class DatasourceNodeData(BaseNodeData, DatasourceEntity): class 
DatasourceInput(BaseModel): # TODO: check this type - value: Union[Any, list[str]] - type: Literal["mixed", "variable", "constant"] + value: Optional[Union[Any, list[str]]] = None + type: Optional[Literal["mixed", "variable", "constant"]] = None @field_validator("type", mode="before") @classmethod @@ -51,4 +38,4 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity): raise ValueError("value must be a string, int, float, or bool") return typ - datasource_parameters: dict[str, DatasourceInput] + datasource_parameters: dict[str, DatasourceInput] | None = None diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index dac541621a..803ecc765f 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -19,6 +19,7 @@ from .entities import KnowledgeIndexNodeData from .exc import ( KnowledgeIndexNodeError, ) +from ..base import BaseNode logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ default_retrieval_model = { } -class KnowledgeIndexNode(LLMNode): +class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): _node_data_cls = KnowledgeIndexNodeData # type: ignore _node_type = NodeType.KNOWLEDGE_INDEX @@ -44,7 +45,7 @@ class KnowledgeIndexNode(LLMNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, - error="Query variable is not object type.", + error="Index chunk variable is not object type.", ) chunks = variable.value variables = {"chunks": chunks} diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index 1f1be59542..e328c20096 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -4,12 +4,14 @@ from core.workflow.nodes.agent.agent_node import AgentNode from core.workflow.nodes.answer import AnswerNode from core.workflow.nodes.base import BaseNode from core.workflow.nodes.code import CodeNode +from core.workflow.nodes.datasource.datasource_node import DatasourceNode from core.workflow.nodes.document_extractor import DocumentExtractorNode from core.workflow.nodes.end import EndNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.http_request import HttpRequestNode from core.workflow.nodes.if_else import IfElseNode from core.workflow.nodes.iteration import IterationNode, IterationStartNode +from core.workflow.nodes.knowledge_index import KnowledgeIndexNode from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode from core.workflow.nodes.list_operator import ListOperatorNode from core.workflow.nodes.llm import LLMNode @@ -119,4 +121,12 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { LATEST_VERSION: AgentNode, "1": AgentNode, }, + NodeType.DATASOURCE: { + LATEST_VERSION: DatasourceNode, + "1": DatasourceNode, + }, + NodeType.KNOWLEDGE_INDEX: { + LATEST_VERSION: KnowledgeIndexNode, + "1": KnowledgeIndexNode, + }, } From 665ffbdc10f177075321d5a4f583c0038e6cfdc2 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 26 May 2025 14:49:59 +0800 Subject: [PATCH 040/155] r2 --- .../console/auth/data_source_oauth.py | 24 ++++ .../rag_pipeline/rag_pipeline_workflow.py | 10 +- .../app/apps/pipeline/pipeline_generator.py | 38 +++--- api/core/app/apps/pipeline/pipeline_runner.py | 1 + api/core/plugin/impl/datasource.py | 26 ++++- api/core/workflow/enums.py | 1 + .../nodes/datasource/datasource_node.py | 109 
++++++++---------- .../workflow/nodes/datasource/entities.py | 2 +- .../knowledge_index/knowledge_index_node.py | 18 ++- api/services/dataset_service.py | 4 +- .../rag_pipeline_entities.py | 4 +- 11 files changed, 143 insertions(+), 94 deletions(-) diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 1049f864c3..8da29093fd 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -8,6 +8,7 @@ from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api +from core.plugin.impl.datasource import PluginDatasourceManager from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -109,7 +110,30 @@ class OAuthDataSourceSync(Resource): return {"result": "success"}, 200 +class DatasourcePluginOauthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, datasource_type, datasource_name): + # Check user role first + if not current_user.is_editor: + raise Forbidden() + + # get all builtin providers + manager = PluginDatasourceManager() + # Fix: use correct method name or implement the missing method + try: + providers = manager.get_providers() # or whatever the correct method is + # Filter by datasource_type and datasource_name if needed + oauth_config = {} # Build appropriate OAuth URL response + return oauth_config + except AttributeError: + # Method doesn't exist, return empty response or implement + return {"oauth_url": None, "supported": False} + + api.add_resource(OAuthDataSource, "/oauth/data-source/") api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/") api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/") api.add_resource(OAuthDataSourceSync, "/oauth/data-source///sync") +api.add_resource(DatasourcePluginOauthApi, "/oauth/plugin/datasource//") \ No newline at end of file diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index f6238bf143..63fe4b7f87 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -280,6 +280,8 @@ class PublishedRagPipelineRunApi(Resource): parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("datasource_type", type=str, required=True, location="json") parser.add_argument("datasource_info", type=list, required=True, location="json") + parser.add_argument("start_node_id", type=str, required=True, location="json") + parser.add_argument("is_preview", type=bool, required=True, location="json", default=False) args = parser.parse_args() try: @@ -287,7 +289,7 @@ class PublishedRagPipelineRunApi(Resource): pipeline=pipeline, user=current_user, args=args, - invoke_from=InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED, streaming=True, ) @@ -469,6 +471,7 @@ class PublishedRagPipelineApi(Resource): rag_pipeline_service = RagPipelineService() with Session(db.engine) as session: + pipeline = session.merge(pipeline) workflow = rag_pipeline_service.publish_workflow( session=session, pipeline=pipeline, @@ -478,6 +481,7 @@ class PublishedRagPipelineApi(Resource): ) pipeline.is_published = True pipeline.workflow_id = workflow.id + session.add(pipeline) workflow_created_at = 
TimestampField().format(workflow.created_at) session.commit() @@ -797,6 +801,10 @@ api.add_resource( DraftRagPipelineRunApi, "/rag/pipelines//workflows/draft/run", ) +api.add_resource( + PublishedRagPipelineRunApi, + "/rag/pipelines//workflows/published/run", +) api.add_resource( RagPipelineTaskStopApi, "/rag/pipelines//workflow-runs/tasks//stop", diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index ccc227f3f4..55a8b96d06 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -92,7 +92,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: bool = True, call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, - ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: + ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: # convert to app config pipeline_config = PipelineConfigManager.get_pipeline_config( pipeline=pipeline, @@ -108,23 +108,24 @@ class PipelineGenerator(BaseAppGenerator): for datasource_info in datasource_info_list: workflow_run_id = str(uuid.uuid4()) document_id = None - dataset = pipeline.dataset - if not dataset: - raise ValueError("Dataset not found") + + # Add null check for dataset + if not pipeline.dataset: + raise ValueError("Pipeline dataset is required") + if invoke_from == InvokeFrom.PUBLISHED: - position = DocumentService.get_documents_position(pipeline.dataset_id) position = DocumentService.get_documents_position(pipeline.dataset_id) document = self._build_document( tenant_id=pipeline.tenant_id, dataset_id=pipeline.dataset_id, - built_in_field_enabled=dataset.built_in_field_enabled, + built_in_field_enabled=pipeline.dataset.built_in_field_enabled, datasource_type=datasource_type, datasource_info=datasource_info, created_from="rag-pipeline", position=position, account=user, batch=batch, - document_form=dataset.chunk_structure, + document_form=pipeline.dataset.chunk_structure, ) db.session.add(document) db.session.commit() @@ -136,7 +137,7 @@ class PipelineGenerator(BaseAppGenerator): pipeline_config=pipeline_config, datasource_type=datasource_type, datasource_info=datasource_info, - dataset_id=dataset.id, + dataset_id=pipeline.dataset.id, start_node_id=start_node_id, batch=batch, document_id=document_id, @@ -274,27 +275,24 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("inputs is required") # convert to app config - pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) + app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) - # init application generate entity + # init application generate entity - use RagPipelineGenerateEntity instead application_generate_entity = RagPipelineGenerateEntity( task_id=str(uuid.uuid4()), - app_config=pipeline_config, - pipeline_config=pipeline_config, - datasource_type=args["datasource_type"], - datasource_info=args["datasource_info"], + app_config=app_config, + pipeline_config=app_config, + datasource_type=args.get("datasource_type", ""), + datasource_info=args.get("datasource_info", {}), dataset_id=pipeline.dataset_id, - batch=args["batch"], - document_id=args["document_id"], + batch=args.get("batch", ""), + document_id=args.get("document_id"), inputs={}, files=[], user_id=user.id, stream=streaming, invoke_from=InvokeFrom.DEBUGGER, - extras={"auto_generate_conversation_name": False}, - single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( 
- node_id=node_id, inputs=args["inputs"] - ), + call_depth=0, workflow_run_id=str(uuid.uuid4()), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 80b724dd20..dd9eade0a5 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -104,6 +104,7 @@ class PipelineRunner(WorkflowBasedAppRunner): SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id, SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type, SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info, + SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from, } variable_pool = VariablePool( diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 645e067e4c..2ad5bdcfef 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,12 +1,11 @@ from collections.abc import Mapping from typing import Any -from core.datasource.entities.api_entities import DatasourceProviderApiEntity from core.datasource.entities.datasource_entities import ( GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, GetOnlineDocumentPagesResponse, - GetWebsiteCrawlResponse, DatasourceProviderEntity, + GetWebsiteCrawlResponse, ) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import ( @@ -228,7 +227,30 @@ class PluginDatasourceManager(BasePluginClient): return resp.result return False + + def get_provider_oauth_url(self, datasource_type: str, datasource_name: str, provider: str) -> str: + """ + get the oauth url of the provider + """ + tool_provider_id = GenericProviderID(provider) + response = self._request_with_plugin_daemon_response_stream( + "GET", + f"plugin/datasource/oauth", + PluginBasicBooleanResponse, + params={"page": 1, "page_size": 256}, + headers={ + "X-Plugin-ID": tool_provider_id.plugin_id, + "Content-Type": "application/json", + }, + + ) + + for resp in response: + return resp.result + + return False + def _get_local_file_datasource_provider(self) -> dict[str, Any]: return { "id": "langgenius/file/file", diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 0e210c1389..778fcc94b7 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -20,3 +20,4 @@ class SystemVariableKey(StrEnum): DATASET_ID = "dataset_id" DATASOURCE_TYPE = "datasource_type" DATASOURCE_INFO = "datasource_info" + INVOKE_FROM = "invoke_from" diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 198e167341..55d7ee5ccb 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -17,7 +17,6 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event import RunCompletedEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser from models.workflow import WorkflowNodeExecutionStatus @@ -33,7 +32,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): _node_data_cls = DatasourceNodeData _node_type = NodeType.DATASOURCE - def _run(self) -> Generator: + def _run(self) 
-> NodeRunResult: """ Run the datasource node """ @@ -58,21 +57,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): datasource_runtime = DatasourceManager.get_datasource_runtime( provider_id=node_data.provider_id, - datasource_name=node_data.datasource_name, + datasource_name=node_data.datasource_name or "", tenant_id=self.tenant_id, datasource_type=DatasourceProviderType(datasource_type), ) except DatasourceNodeError as e: - yield RunCompletedEvent( - run_result=NodeRunResult( + return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, error=f"Failed to get datasource runtime: {str(e)}", error_type=type(e).__name__, ) - ) - return + # get parameters datasource_parameters = datasource_runtime.entity.parameters @@ -99,66 +96,55 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): provider_type=datasource_type, ) ) - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - "online_document": online_document_result.result.model_dump(), - "datasource_type": datasource_type, - }, - ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "online_document": online_document_result.result.model_dump(), + "datasource_type": datasource_type, + }, ) case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE: - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ "website": datasource_info, "datasource_type": datasource_type, - }, - ) + }, ) case DatasourceProviderType.LOCAL_FILE: - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ "file": datasource_info, "datasource_type": datasource_runtime.datasource_provider_type, }, ) - ) case _: raise DatasourceNodeError( f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}" ) except PluginDaemonClientSideError as e: - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - error=f"Failed to transform datasource message: {str(e)}", - error_type=type(e).__name__, - ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to transform datasource message: {str(e)}", + error_type=type(e).__name__, ) except DatasourceNodeError as e: - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - error=f"Failed 
to invoke datasource: {str(e)}", - error_type=type(e).__name__, - ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to invoke datasource: {str(e)}", + error_type=type(e).__name__, ) - return def _generate_parameters( self, @@ -225,18 +211,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): :return: """ result = {} - for parameter_name in node_data.datasource_parameters: - input = node_data.datasource_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value - elif input.type == "constant": - pass + if node_data.datasource_parameters: + for parameter_name in node_data.datasource_parameters: + input = node_data.datasource_parameters[parameter_name] + if input.type == "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + elif input.type == "variable": + result[parameter_name] = input.value + elif input.type == "constant": + pass - result = {node_id + "." + key: value for key, value in result.items()} + result = {node_id + "." + key: value for key, value in result.items()} return result diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 212184bb81..1f414ad0e2 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Union, Optional +from typing import Any, Literal, Optional, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 803ecc765f..d883200c94 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -3,6 +3,7 @@ import logging from collections.abc import Mapping from typing import Any, cast +from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables.segments import ObjectSegment @@ -10,16 +11,15 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.llm.node import LLMNode from extensions.ext_database import db from models.dataset import Dataset, Document from models.workflow import WorkflowNodeExecutionStatus +from ..base import BaseNode from .entities import KnowledgeIndexNodeData from .exc import ( KnowledgeIndexNodeError, ) -from ..base import BaseNode logger = logging.getLogger(__name__) @@ -41,6 +41,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): variable_pool = self.graph_runtime_state.variable_pool # extract variables variable = variable_pool.get(node_data.index_chunk_variable_selector) + 
is_preview = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) == InvokeFrom.DEBUGGER if not isinstance(variable, ObjectSegment): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -55,6 +56,13 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): ) # retrieve knowledge try: + if is_preview: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + process_data=None, + outputs={"result": "success"}, + ) results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool) outputs = {"result": results} return NodeRunResult( @@ -90,15 +98,15 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) if not batch: raise KnowledgeIndexNodeError("Batch is required.") - dataset = Dataset.query.filter_by(id=dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") - document = Document.query.filter_by(id=document_id).first() + document = db.session.query(Document).filter_by(id=document_id).first() if not document: raise KnowledgeIndexNodeError(f"Document {document_id} not found.") - index_processor = IndexProcessorFactory(node_data.chunk_structure).init_index_processor() + index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor() index_processor.index(dataset, document, chunks) # update document status diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 62a16c56ce..e14a10680f 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -270,7 +270,7 @@ class DatasetService: permission=rag_pipeline_dataset_create_entity.permission, provider="vendor", runtime_mode="rag_pipeline", - icon_info=rag_pipeline_dataset_create_entity.icon_info, + icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(), created_by=current_user.id, pipeline_id=pipeline.id, ) @@ -299,7 +299,7 @@ class DatasetService: permission=rag_pipeline_dataset_create_entity.permission, provider="vendor", runtime_mode="rag-pipeline", - icon_info=rag_pipeline_dataset_create_entity.icon_info, + icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(), ) with Session(db.engine) as session: rag_pipeline_dsl_service = RagPipelineDslService(session) diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index 5f581f1360..800bd24021 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -21,8 +21,8 @@ class RagPipelineDatasetCreateEntity(BaseModel): description: str icon_info: IconInfo permission: str - partial_member_list: list[str] - yaml_content: str + partial_member_list: Optional[list[str]] = None + yaml_content: Optional[str] = None class RerankingModelConfig(BaseModel): From 38cce3f62a1f47365bdb9df1cdfa65c297eb6b8a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 26 May 2025 14:52:09 +0800 Subject: [PATCH 041/155] r2 --- .../versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py index 89fcc6aa29..4d726cecb1 100644 --- 
a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py
+++ b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py
@@ -12,7 +12,7 @@ from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'b35c3db83d09'
-down_revision = 'd28f2004b072'
+down_revision = '2adcbe1f5dfb'
branch_labels = None
depends_on = None

From 1b07e612d2deaf638f504e94da12eb0d6f166a7a Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Mon, 26 May 2025 15:49:37 +0800
Subject: [PATCH 042/155] r2

---
 ...6_1659-abb18a379e62_add_pipeline_info_2.py | 80 -------------------
 1 file changed, 80 deletions(-)

diff --git a/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py b/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py
index 18e90e49dc..ae8e832d26 100644
--- a/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py
+++ b/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py
@@ -19,14 +19,6 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
-    op.drop_table('component_failure_stats')
-    op.drop_table('reliability_data')
-    op.drop_table('maintenance')
-    op.drop_table('operational_data')
-    op.drop_table('component_failure')
-    op.drop_table('tool_providers')
-    op.drop_table('safety_data')
-    op.drop_table('incident_data')
    with op.batch_alter_table('pipelines', schema=None) as batch_op:
        batch_op.drop_column('mode')

@@ -38,76 +30,4 @@ def downgrade():
    with op.batch_alter_table('pipelines', schema=None) as batch_op:
        batch_op.add_column(sa.Column('mode', sa.VARCHAR(length=255), autoincrement=False, nullable=False))

-    op.create_table('incident_data',
-    sa.Column('IncidentID', sa.INTEGER(), autoincrement=True, nullable=False),
-    sa.Column('IncidentDescription', sa.TEXT(), autoincrement=False, nullable=False),
-    sa.Column('IncidentDate', sa.DATE(), autoincrement=False, nullable=False),
-    sa.Column('Consequences', sa.TEXT(), autoincrement=False, nullable=True),
-    sa.Column('ResponseActions', sa.TEXT(), autoincrement=False, nullable=True),
-    sa.PrimaryKeyConstraint('IncidentID', name='incident_data_pkey')
-    )
-    op.create_table('safety_data',
-    sa.Column('SafetyID', sa.INTEGER(), autoincrement=True, nullable=False),
-    sa.Column('SafetyInspectionDate', sa.DATE(), autoincrement=False, nullable=False),
-    sa.Column('SafetyFindings', sa.TEXT(), autoincrement=False, nullable=True),
-    sa.Column('SafetyIncidentDescription', sa.TEXT(), autoincrement=False, nullable=True),
-    sa.Column('ComplianceStatus', sa.VARCHAR(length=50), autoincrement=False, nullable=False),
-    sa.PrimaryKeyConstraint('SafetyID', name='safety_data_pkey')
-    )
-    op.create_table('tool_providers',
-    sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
-    sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
-    sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False),
-    sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True),
-    sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
-    sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
-    sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
-    sa.PrimaryKeyConstraint('id',
name='tool_provider_pkey'), - sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') - ) - op.create_table('component_failure', - sa.Column('FailureID', sa.INTEGER(), autoincrement=True, nullable=False), - sa.Column('Date', sa.DATE(), autoincrement=False, nullable=False), - sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('RepairAction', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.PrimaryKeyConstraint('FailureID', name='component_failure_pkey'), - sa.UniqueConstraint('Date', 'Component', 'FailureMode', 'Cause', 'Technician', name='unique_failure_entry') - ) - op.create_table('operational_data', - sa.Column('OperationID', sa.INTEGER(), autoincrement=True, nullable=False), - sa.Column('CraneUsage', sa.INTEGER(), autoincrement=False, nullable=False), - sa.Column('LoadWeight', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), - sa.Column('LoadFrequency', sa.INTEGER(), autoincrement=False, nullable=False), - sa.Column('EnvironmentalConditions', sa.TEXT(), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('OperationID', name='operational_data_pkey') - ) - op.create_table('maintenance', - sa.Column('MaintenanceID', sa.INTEGER(), autoincrement=True, nullable=False), - sa.Column('MaintenanceType', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('MaintenanceDate', sa.DATE(), autoincrement=False, nullable=False), - sa.Column('ServiceDescription', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('PartsReplaced', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.PrimaryKeyConstraint('MaintenanceID', name='maintenance_pkey') - ) - op.create_table('reliability_data', - sa.Column('ComponentID', sa.INTEGER(), autoincrement=True, nullable=False), - sa.Column('ComponentName', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), - sa.Column('FailureRate', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), - sa.PrimaryKeyConstraint('ComponentID', name='reliability_data_pkey') - ) - op.create_table('component_failure_stats', - sa.Column('StatID', sa.INTEGER(), autoincrement=True, nullable=False), - sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('PossibleAction', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('Probability', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), - sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), - sa.PrimaryKeyConstraint('StatID', name='component_failure_stats_pkey') - ) # ### end Alembic commands ### From ef0e41de07a7e7d3ae5de5fed3796f9bfbc2e502 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 26 May 2025 16:02:11 +0800 Subject: [PATCH 043/155] r2 --- .../console/auth/data_source_oauth.py | 23 +++++++++---------- 
.../app/apps/pipeline/pipeline_generator.py | 13 ++++++----- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 8da29093fd..0f3e2582c4 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -8,7 +8,7 @@ from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api -from core.plugin.impl.datasource import PluginDatasourceManager +from core.plugin.impl.oauth import OAuthHandler from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -118,18 +118,17 @@ class DatasourcePluginOauthApi(Resource): # Check user role first if not current_user.is_editor: raise Forbidden() - # get all builtin providers - manager = PluginDatasourceManager() - # Fix: use correct method name or implement the missing method - try: - providers = manager.get_providers() # or whatever the correct method is - # Filter by datasource_type and datasource_name if needed - oauth_config = {} # Build appropriate OAuth URL response - return oauth_config - except AttributeError: - # Method doesn't exist, return empty response or implement - return {"oauth_url": None, "supported": False} + oauth_handler = OAuthHandler() + providers = oauth_handler.get_authorization_url( + current_user.current_tenant.id, + current_user.id, + datasource_type, + datasource_name, + system_credentials={} + ) + return providers + api.add_resource(OAuthDataSource, "/oauth/data-source/") diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 55a8b96d06..120927516b 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -22,13 +22,13 @@ from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter +from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.index_processor.constant.built_in_field import BuiltInField from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline from extensions.ext_database import db from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, Pipeline @@ -108,16 +108,17 @@ class PipelineGenerator(BaseAppGenerator): for datasource_info in datasource_info_list: workflow_run_id = str(uuid.uuid4()) document_id = None - + # Add null check for dataset - if not pipeline.dataset: + dataset = pipeline.dataset + if not dataset: raise ValueError("Pipeline dataset is required") - + if invoke_from == InvokeFrom.PUBLISHED: - position = DocumentService.get_documents_position(pipeline.dataset_id) + position = 
DocumentService.get_documents_position(dataset.id) document = self._build_document( tenant_id=pipeline.tenant_id, - dataset_id=pipeline.dataset_id, + dataset_id=dataset.id, built_in_field_enabled=pipeline.dataset.built_in_field_enabled, datasource_type=datasource_type, datasource_info=datasource_info, From 678d6ffe2bb76f953fa38900239ede56fbf301c3 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 26 May 2025 17:00:16 +0800 Subject: [PATCH 044/155] r2 --- .../console/auth/data_source_oauth.py | 21 ------ api/controllers/console/auth/oauth.py | 67 ++++++++++++++++++- .../datasets/rag_pipeline/datasource_oauth.py | 56 ++++++++++++++++ .../rag_pipeline/rag_pipeline_workflow.py | 2 +- api/core/app/apps/pipeline/pipeline_runner.py | 4 +- api/models/oauth.py | 47 +++++++++++++ 6 files changed, 171 insertions(+), 26 deletions(-) create mode 100644 api/controllers/console/datasets/rag_pipeline/datasource_oauth.py create mode 100644 api/models/oauth.py diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 0f3e2582c4..5299064e17 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -8,7 +8,6 @@ from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api -from core.plugin.impl.oauth import OAuthHandler from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -110,29 +109,9 @@ class OAuthDataSourceSync(Resource): return {"result": "success"}, 200 -class DatasourcePluginOauthApi(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self, datasource_type, datasource_name): - # Check user role first - if not current_user.is_editor: - raise Forbidden() - # get all builtin providers - oauth_handler = OAuthHandler() - providers = oauth_handler.get_authorization_url( - current_user.current_tenant.id, - current_user.id, - datasource_type, - datasource_name, - system_credentials={} - ) - return providers - api.add_resource(OAuthDataSource, "/oauth/data-source/") api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/") api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/") api.add_resource(OAuthDataSourceSync, "/oauth/data-source///sync") -api.add_resource(DatasourcePluginOauthApi, "/oauth/plugin/datasource//") \ No newline at end of file diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 395367c9e2..d8576c3879 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -2,26 +2,30 @@ import logging from datetime import UTC, datetime from typing import Optional +from flask_login import current_user import requests from flask import current_app, redirect, request from flask_restful import Resource from sqlalchemy import select from sqlalchemy.orm import Session -from werkzeug.exceptions import Unauthorized +from werkzeug.exceptions import Unauthorized, Forbidden, NotFound from configs import dify_config from constants.languages import languages +from controllers.console.wraps import account_initialization_required, setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import extract_remote_ip +from libs.login import login_required from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models import Account from models.account import AccountStatus +from 
models.oauth import DatasourceOauthParamConfig, DatasourceProvider from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.feature_service import FeatureService - +from core.plugin.impl.oauth import OAuthHandler from .. import api @@ -181,5 +185,64 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): return account +class PluginOauthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider, plugin_id): + # Check user role first + if not current_user.is_editor: + raise Forbidden() + # get all plugin oauth configs + plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( + provider=provider, + plugin_id=plugin_id + ).first() + if not plugin_oauth_config: + raise NotFound() + oauth_handler = OAuthHandler() + response = oauth_handler.get_authorization_url( + current_user.current_tenant.id, + current_user.id, + plugin_id, + provider, + system_credentials=plugin_oauth_config.system_credentials + ) + return response.model_dump() + +class PluginOauthCallback(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider, plugin_id): + oauth_handler = OAuthHandler() + plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( + provider=provider, + plugin_id=plugin_id + ).first() + if not plugin_oauth_config: + raise NotFound() + credentials = oauth_handler.get_credentials( + current_user.current_tenant.id, + current_user.id, + plugin_id, + provider, + system_credentials=plugin_oauth_config.system_credentials, + request=request + ) + datasource_provider = DatasourceProvider( + datasource_name=plugin_oauth_config.datasource_name, + plugin_id=plugin_id, + provider=provider, + auth_type="oauth", + encrypted_credentials=credentials + ) + db.session.add(datasource_provider) + db.session.commit() + return redirect(f"{dify_config.CONSOLE_WEB_URL}") + + api.add_resource(OAuthLogin, "/oauth/login/") api.add_resource(OAuthCallback, "/oauth/authorize/") +api.add_resource(PluginOauthApi, "/oauth/plugin/provider//plugin/") +api.add_resource(PluginOauthCallback, "/oauth/plugin/callback//plugin/") diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py b/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py new file mode 100644 index 0000000000..b105606c04 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py @@ -0,0 +1,56 @@ +from typing import cast + +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.datasets.wraps import get_rag_pipeline +from controllers.console.wraps import ( + account_initialization_required, + setup_required, +) +from core.plugin.impl.datasource import PluginDatasourceManager +from extensions.ext_database import db +from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields +from libs.login import login_required +from models import Account +from models.dataset import Pipeline +from services.app_dsl_service import ImportStatus +from services.rag_pipeline.rag_pipeline_dsl_service import 
RagPipelineDslService + + +class DatasourcePluginOauthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, datasource_type, datasource_name): + # Check user role first + if not current_user.is_editor: + raise Forbidden() + # get all builtin providers + manager = PluginDatasourceManager() + providers = manager.get_provider_oauth_url() + return providers + + + + +# Import Rag Pipeline +api.add_resource( + DatasourcePluginOauthApi, + "/datasource///oauth", +) +api.add_resource( + RagPipelineImportConfirmApi, + "/rag/pipelines/imports//confirm", +) +api.add_resource( + RagPipelineImportCheckDependenciesApi, + "/rag/pipelines/imports//check-dependencies", +) +api.add_resource( + RagPipelineExportApi, + "/rag/pipelines//exports", +) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 63fe4b7f87..bbeaa33341 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -279,7 +279,7 @@ class PublishedRagPipelineRunApi(Resource): parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("datasource_type", type=str, required=True, location="json") - parser.add_argument("datasource_info", type=list, required=True, location="json") + parser.add_argument("datasource_info_list", type=list, required=True, location="json") parser.add_argument("start_node_id", type=str, required=True, location="json") parser.add_argument("is_preview", type=bool, required=True, location="json", default=False) args = parser.parse_args() diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index dd9eade0a5..23dbfef70d 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -104,7 +104,7 @@ class PipelineRunner(WorkflowBasedAppRunner): SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id, SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type, SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info, - SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from, + SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value, } variable_pool = VariablePool( @@ -199,4 +199,4 @@ class PipelineRunner(WorkflowBasedAppRunner): if not graph: raise ValueError("graph not found in workflow") - return graph \ No newline at end of file + return graph diff --git a/api/models/oauth.py b/api/models/oauth.py new file mode 100644 index 0000000000..f24c3c6723 --- /dev/null +++ b/api/models/oauth.py @@ -0,0 +1,47 @@ + +from datetime import datetime +from json import JSONDecodeError +from typing import Any, cast + +from sqlalchemy import func +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped + +from configs import dify_config +from extensions.ext_storage import storage +from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule + +from .account import Account +from .base import Base +from .engine import db +from .model import App, Tag, TagBinding, UploadFile +from .types import StringUUID + + +class DatasourceOauthParamConfig(Base): # type: ignore[name-defined] + __tablename__ = "datasource_oauth_params" + __table_args__ = ( 
+ db.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"), + db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + datasource_name: Mapped[str] = db.Column(db.String(255), nullable=False) + plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) + provider: Mapped[str] = db.Column(db.String(255), nullable=False) + system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) + +class DatasourceProvider(Base): + __tablename__ = "datasource_providers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"), + db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"), + ) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + datasource_name: Mapped[str] = db.Column(db.String(255), nullable=False) + plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) + provider: Mapped[str] = db.Column(db.String(255), nullable=False) + auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) + encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) + updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) From 83ca7f8deba25188bc6ef9d6365044e2aff8b1dd Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 26 May 2025 17:32:25 +0800 Subject: [PATCH 045/155] feat: add datasource support to PluginDeclaration and PluginCategory --- api/controllers/console/auth/oauth.py | 7 ++++--- .../console/datasets/rag_pipeline/datasource_oauth.py | 11 +---------- api/core/app/apps/pipeline/pipeline_generator.py | 2 +- .../website_crawl/website_crawl_provider.py | 6 +++++- api/core/plugin/entities/plugin.py | 6 ++++++ api/core/plugin/impl/datasource.py | 2 +- api/core/workflow/nodes/datasource/datasource_node.py | 2 +- api/models/oauth.py | 9 --------- 8 files changed, 19 insertions(+), 26 deletions(-) diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index d8576c3879..d5e13525d6 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -2,17 +2,18 @@ import logging from datetime import UTC, datetime from typing import Optional -from flask_login import current_user import requests from flask import current_app, redirect, request +from flask_login import current_user from flask_restful import Resource from sqlalchemy import select from sqlalchemy.orm import Session -from werkzeug.exceptions import Unauthorized, Forbidden, NotFound +from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from configs import dify_config from constants.languages import languages from controllers.console.wraps import account_initialization_required, setup_required +from core.plugin.impl.oauth import OAuthHandler from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import extract_remote_ip @@ -25,7 +26,7 @@ from services.account_service import AccountService, RegisterService, TenantServ from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.feature_service import FeatureService -from core.plugin.impl.oauth import OAuthHandler + from .. 
import api diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py b/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py index b105606c04..f4164dea7b 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py @@ -1,24 +1,15 @@ -from typing import cast from flask_login import current_user # type: ignore -from flask_restful import Resource, marshal_with, reqparse # type: ignore -from sqlalchemy.orm import Session +from flask_restful import Resource # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( account_initialization_required, setup_required, ) from core.plugin.impl.datasource import PluginDatasourceManager -from extensions.ext_database import db -from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields from libs.login import login_required -from models import Account -from models.dataset import Pipeline -from services.app_dsl_service import ImportStatus -from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService class DatasourcePluginOauthApi(Resource): diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 120927516b..9c25f8f4e6 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -30,7 +30,7 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db -from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom +from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, Pipeline from models.model import AppMode from services.dataset_service import DocumentService diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index d9043702d2..11168b4c26 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -10,7 +10,11 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon plugin_unique_identifier: str def __init__( - self, entity: DatasourceProviderEntityWithPlugin | None, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + self, + entity: DatasourceProviderEntityWithPlugin | None, + plugin_id: str, + plugin_unique_identifier: str, + tenant_id: str, ) -> None: super().__init__(entity, tenant_id) self.plugin_id = plugin_id diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 260d4f12db..e2ea9669fa 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, model_validator from werkzeug.exceptions import NotFound from core.agent.plugin_entities import AgentStrategyProviderEntity +from core.datasource.entities.datasource_entities import DatasourceProviderEntity from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.base import BasePluginEntity 
from core.plugin.entities.endpoint import EndpointProviderDeclaration @@ -62,6 +63,7 @@ class PluginCategory(enum.StrEnum): Model = "model" Extension = "extension" AgentStrategy = "agent-strategy" + Datasource = "datasource" class PluginDeclaration(BaseModel): @@ -69,6 +71,7 @@ class PluginDeclaration(BaseModel): tools: Optional[list[str]] = Field(default_factory=list[str]) models: Optional[list[str]] = Field(default_factory=list[str]) endpoints: Optional[list[str]] = Field(default_factory=list[str]) + datasources: Optional[list[str]] = Field(default_factory=list[str]) class Meta(BaseModel): minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$") @@ -90,6 +93,7 @@ class PluginDeclaration(BaseModel): model: Optional[ProviderEntity] = None endpoint: Optional[EndpointProviderDeclaration] = None agent_strategy: Optional[AgentStrategyProviderEntity] = None + datasource: Optional[DatasourceProviderEntity] = None meta: Meta @model_validator(mode="before") @@ -100,6 +104,8 @@ class PluginDeclaration(BaseModel): values["category"] = PluginCategory.Tool elif values.get("model"): values["category"] = PluginCategory.Model + elif values.get("datasource"): + values["category"] = PluginCategory.Datasource elif values.get("agent_strategy"): values["category"] = PluginCategory.AgentStrategy else: diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 2ad5bdcfef..004cf7f9c3 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -236,7 +236,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "GET", - f"plugin/datasource/oauth", + "plugin/datasource/oauth", PluginBasicBooleanResponse, params={"page": 1, "page_size": 256}, headers={ diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 55d7ee5ccb..2f15fba3af 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,4 +1,4 @@ -from collections.abc import Generator, Mapping, Sequence +from collections.abc import Mapping, Sequence from typing import Any, cast from core.datasource.entities.datasource_entities import ( diff --git a/api/models/oauth.py b/api/models/oauth.py index f24c3c6723..aee45d7c41 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -1,20 +1,11 @@ from datetime import datetime -from json import JSONDecodeError -from typing import Any, cast -from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped -from configs import dify_config -from extensions.ext_storage import storage -from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule - -from .account import Account from .base import Base from .engine import db -from .model import App, Tag, TagBinding, UploadFile from .types import StringUUID From 5fc2bc58a9fba17c812837e208b948f8bd322e2d Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 27 May 2025 00:01:23 +0800 Subject: [PATCH 046/155] r2 --- api/controllers/console/auth/oauth.py | 58 ------- .../datasets/rag_pipeline/datasource_auth.py | 140 ++++++++++++++++ .../datasets/rag_pipeline/datasource_oauth.py | 47 ------ api/core/plugin/impl/datasource.py | 33 +--- .../index_processor/index_processor_base.py | 4 +- .../processor/paragraph_index_processor.py | 4 +- .../processor/parent_child_index_processor.py | 4 +- 
.../knowledge_index/knowledge_index_node.py | 6 - api/models/oauth.py | 2 - api/services/datasource_provider_service.py | 150 ++++++++++++++++++ 10 files changed, 301 insertions(+), 147 deletions(-) create mode 100644 api/controllers/console/datasets/rag_pipeline/datasource_auth.py delete mode 100644 api/controllers/console/datasets/rag_pipeline/datasource_oauth.py create mode 100644 api/services/datasource_provider_service.py diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index d5e13525d6..ed595f5d3d 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -186,64 +186,6 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): return account -class PluginOauthApi(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self, provider, plugin_id): - # Check user role first - if not current_user.is_editor: - raise Forbidden() - # get all plugin oauth configs - plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( - provider=provider, - plugin_id=plugin_id - ).first() - if not plugin_oauth_config: - raise NotFound() - oauth_handler = OAuthHandler() - response = oauth_handler.get_authorization_url( - current_user.current_tenant.id, - current_user.id, - plugin_id, - provider, - system_credentials=plugin_oauth_config.system_credentials - ) - return response.model_dump() - -class PluginOauthCallback(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self, provider, plugin_id): - oauth_handler = OAuthHandler() - plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( - provider=provider, - plugin_id=plugin_id - ).first() - if not plugin_oauth_config: - raise NotFound() - credentials = oauth_handler.get_credentials( - current_user.current_tenant.id, - current_user.id, - plugin_id, - provider, - system_credentials=plugin_oauth_config.system_credentials, - request=request - ) - datasource_provider = DatasourceProvider( - datasource_name=plugin_oauth_config.datasource_name, - plugin_id=plugin_id, - provider=provider, - auth_type="oauth", - encrypted_credentials=credentials - ) - db.session.add(datasource_provider) - db.session.commit() - return redirect(f"{dify_config.CONSOLE_WEB_URL}") - api.add_resource(OAuthLogin, "/oauth/login/") api.add_resource(OAuthCallback, "/oauth/authorize/") -api.add_resource(PluginOauthApi, "/oauth/plugin/provider//plugin/") -api.add_resource(PluginOauthCallback, "/oauth/plugin/callback//plugin/") diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py new file mode 100644 index 0000000000..8894babcf7 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -0,0 +1,140 @@ + +from flask import redirect, request +from flask_login import current_user # type: ignore +from flask_restful import ( # type: ignore + Resource, # type: ignore + marshal_with, + reqparse, +) +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden, NotFound + +from configs import dify_config +from controllers.console import api +from controllers.console.wraps import ( + account_initialization_required, + setup_required, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.plugin.impl.datasource import PluginDatasourceManager +from core.plugin.impl.oauth import OAuthHandler +from extensions.ext_database import 
db +from libs.login import login_required +from models.oauth import DatasourceOauthParamConfig, DatasourceProvider +from services.datasource_provider_service import DatasourceProviderService + + +class DatasourcePluginOauthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider, plugin_id): + # Check user role first + if not current_user.is_editor: + raise Forbidden() + # get all plugin oauth configs + plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( + provider=provider, + plugin_id=plugin_id + ).first() + if not plugin_oauth_config: + raise NotFound() + oauth_handler = OAuthHandler() + redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/provider/{provider}/plugin/{plugin_id}/callback" + system_credentials = plugin_oauth_config.system_credentials + if system_credentials: + system_credentials["redirect_url"] = redirect_url + response = oauth_handler.get_authorization_url( + current_user.current_tenant.id, + current_user.id, + plugin_id, + provider, + system_credentials=system_credentials + ) + return response.model_dump() + +class DatasourceOauthCallback(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider, plugin_id): + oauth_handler = OAuthHandler() + plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( + provider=provider, + plugin_id=plugin_id + ).first() + if not plugin_oauth_config: + raise NotFound() + credentials = oauth_handler.get_credentials( + current_user.current_tenant.id, + current_user.id, + plugin_id, + provider, + system_credentials=plugin_oauth_config.system_credentials, + request=request + ) + datasource_provider = DatasourceProvider( + plugin_id=plugin_id, + provider=provider, + auth_type="oauth", + encrypted_credentials=credentials + ) + db.session.add(datasource_provider) + db.session.commit() + return redirect(f"{dify_config.CONSOLE_WEB_URL}") + +class DatasourceAuth(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider, plugin_id): + if not current_user.is_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + datasource_provider_service = DatasourceProviderService() + + try: + datasource_provider_service.datasource_provider_credentials_validate( + tenant_id=current_user.current_tenant_id, + provider=provider, + plugin_id=plugin_id, + credentials=args["credentials"] + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"}, 201 + +class DatasourceAuthDeleteApi(Resource): + @setup_required + @login_required + @account_initialization_required + def delete(self, provider, plugin_id): + if not current_user.is_editor: + raise Forbidden() + datasource_provider_service = DatasourceProviderService() + datasource_provider_service.remove_datasource_credentials( + tenant_id=current_user.current_tenant_id, + provider=provider, + plugin_id=plugin_id + ) + return {"result": "success"}, 200 + +# Import Rag Pipeline +api.add_resource( + DatasourcePluginOauthApi, + "/oauth/datasource/provider//plugin/", +) +api.add_resource( + DatasourceOauthCallback, + "/oauth/datasource/provider//plugin//callback", +) +api.add_resource( + DatasourceAuth, + "/auth/datasource/provider//plugin/", +) + diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py 
b/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py deleted file mode 100644 index f4164dea7b..0000000000 --- a/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py +++ /dev/null @@ -1,47 +0,0 @@ - -from flask_login import current_user # type: ignore -from flask_restful import Resource # type: ignore -from werkzeug.exceptions import Forbidden - -from controllers.console import api -from controllers.console.wraps import ( - account_initialization_required, - setup_required, -) -from core.plugin.impl.datasource import PluginDatasourceManager -from libs.login import login_required - - -class DatasourcePluginOauthApi(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self, datasource_type, datasource_name): - # Check user role first - if not current_user.is_editor: - raise Forbidden() - # get all builtin providers - manager = PluginDatasourceManager() - providers = manager.get_provider_oauth_url() - return providers - - - - -# Import Rag Pipeline -api.add_resource( - DatasourcePluginOauthApi, - "/datasource///oauth", -) -api.add_resource( - RagPipelineImportConfirmApi, - "/rag/pipelines/imports//confirm", -) -api.add_resource( - RagPipelineImportCheckDependenciesApi, - "/rag/pipelines/imports//check-dependencies", -) -api.add_resource( - RagPipelineExportApi, - "/rag/pipelines//exports", -) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 004cf7f9c3..b5212eb719 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -203,7 +203,7 @@ class PluginDatasourceManager(BasePluginClient): """ validate the credentials of the provider """ - tool_provider_id = GenericProviderID(provider) + datasource_provider_id = GenericProviderID(provider) response = self._request_with_plugin_daemon_response_stream( "POST", @@ -212,12 +212,12 @@ class PluginDatasourceManager(BasePluginClient): data={ "user_id": user_id, "data": { - "provider": tool_provider_id.provider_name, + "provider": datasource_provider_id.provider_name, "credentials": credentials, }, }, headers={ - "X-Plugin-ID": tool_provider_id.plugin_id, + "X-Plugin-ID": datasource_provider_id.plugin_id, "Content-Type": "application/json", }, @@ -227,34 +227,11 @@ class PluginDatasourceManager(BasePluginClient): return resp.result return False - - def get_provider_oauth_url(self, datasource_type: str, datasource_name: str, provider: str) -> str: - """ - get the oauth url of the provider - """ - tool_provider_id = GenericProviderID(provider) - response = self._request_with_plugin_daemon_response_stream( - "GET", - "plugin/datasource/oauth", - PluginBasicBooleanResponse, - params={"page": 1, "page_size": 256}, - headers={ - "X-Plugin-ID": tool_provider_id.plugin_id, - "Content-Type": "application/json", - }, - - ) - - for resp in response: - return resp.result - - return False - def _get_local_file_datasource_provider(self) -> dict[str, Any]: return { "id": "langgenius/file/file", - "plugin_id": "langgenius/file", + "plugin_id": "langgenius/file/file", "provider": "langgenius", "plugin_unique_identifier": "langgenius/file:0.0.1@dify", "declaration": { @@ -280,7 +257,7 @@ class PluginDatasourceManager(BasePluginClient): "datasources": [{ "identity": { "author": "langgenius", - "name": "local_file", + "name": "upload-file", "provider": "langgenius", "label": { "zh_Hans": "File", diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 
d796c9fd24..50511de16f 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -13,7 +13,7 @@ from core.rag.splitter.fixed_text_splitter import ( FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter -from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule class BaseIndexProcessor(ABC): @@ -35,7 +35,7 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): raise NotImplementedError @abstractmethod diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 155aae61d4..5eab77d4f8 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -15,7 +15,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import Document, GeneralStructureChunk from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper -from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule from services.entities.knowledge_entities.knowledge_entities import Rule @@ -128,7 +128,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): docs.append(doc) return docs - def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): paragraph = GeneralStructureChunk(**chunks) documents = [] for content in paragraph.general_chunk: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 5279864441..6300d05707 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -16,7 +16,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk from extensions.ext_database import db from libs import helper -from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, Document as DatasetDocument, DocumentSegment from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -205,7 +205,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_nodes.append(child_document) return child_nodes - def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): parent_childs = ParentChildStructureChunk(**chunks) documents = [] for parent_child in parent_childs.parent_child_chunks: diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index d883200c94..25a4112998 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ 
b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -42,12 +42,6 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): # extract variables variable = variable_pool.get(node_data.index_chunk_variable_selector) is_preview = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) == InvokeFrom.DEBUGGER - if not isinstance(variable, ObjectSegment): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - error="Index chunk variable is not object type.", - ) chunks = variable.value variables = {"chunks": chunks} if not chunks: diff --git a/api/models/oauth.py b/api/models/oauth.py index aee45d7c41..fefe743195 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -17,7 +17,6 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined] ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - datasource_name: Mapped[str] = db.Column(db.String(255), nullable=False) plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) @@ -29,7 +28,6 @@ class DatasourceProvider(Base): db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - datasource_name: Mapped[str] = db.Column(db.String(255), nullable=False) plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py new file mode 100644 index 0000000000..fbb9b25a75 --- /dev/null +++ b/api/services/datasource_provider_service.py @@ -0,0 +1,150 @@ +import logging +from typing import Optional + +from flask_login import current_user + +from constants import HIDDEN_VALUE +from core import datasource +from core.datasource.__base import datasource_provider +from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity +from core.helper import encrypter +from core.model_runtime.entities.model_entities import ModelType, ParameterRule +from core.model_runtime.entities.provider_entities import FormType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.datasource import PluginDatasourceManager +from core.provider_manager import ProviderManager +from models.oauth import DatasourceProvider +from models.provider import ProviderType +from services.entities.model_provider_entities import ( + CustomConfigurationResponse, + CustomConfigurationStatus, + DefaultModelResponse, + ModelWithProviderEntityResponse, + ProviderResponse, + ProviderWithModelsResponse, + SimpleProviderEntityResponse, + SystemConfigurationResponse, +) +from extensions.database import db + +logger = logging.getLogger(__name__) + + +class DatasourceProviderService: + """ + Model Provider Service + """ + + def __init__(self) -> None: + self.provider_manager = PluginDatasourceManager() + + def datasource_provider_credentials_validate(self, + tenant_id: str, + provider: str, + plugin_id: str, + credentials: dict) -> None: + """ + validate datasource provider credentials. 
+ + :param tenant_id: + :param provider: + :param credentials: + """ + credential_valid = self.provider_manager.validate_provider_credentials(tenant_id=tenant_id, + user_id=current_user.id, + provider=provider, + credentials=credentials) + if credential_valid: + # Get all provider configurations of the current workspace + datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id).first() + + provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, + provider=provider + ) + if not datasource_provider: + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + credentials[key] = encrypter.encrypt_token(tenant_id, value) + datasource_provider = DatasourceProvider(tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id, + auth_type="api_key", + encrypted_credentials=credentials) + db.session.add(datasource_provider) + db.session.commit() + else: + original_credentials = datasource_provider.encrypted_credentials + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + if value == HIDDEN_VALUE and key in original_credentials: + original_value = encrypter.encrypt_token(tenant_id, original_credentials[key]) + credentials[key] = encrypter.encrypt_token(tenant_id, original_value) + else: + credentials[key] = encrypter.encrypt_token(tenant_id, value) + + datasource_provider.encrypted_credentials = credentials + db.session.commit() + else: + raise CredentialsValidateFailedError() + + def extract_secret_variables(self, tenant_id: str, provider: str) -> list[str]: + """ + Extract secret input form variables. + + :param credential_form_schemas: + :return: + """ + datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id, provider=provider) + credential_form_schemas = datasource_provider.declaration.credentials_schema + secret_input_form_variables = [] + for credential_form_schema in credential_form_schemas: + if credential_form_schema.type == FormType.SECRET_INPUT: + secret_input_form_variables.append(credential_form_schema.name) + + return secret_input_form_variables + + + def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> Optional[dict]: + """ + get datasource credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param datasource_name: datasource name + :param plugin_id: plugin id + :return: + """ + # Get all provider configurations of the current workspace + datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id).first() + + + + + def remove_datasource_credentials(self, + tenant_id: str, + provider: str, + plugin_id: str) -> None: + """ + remove datasource credentials. 
+ + :param tenant_id: workspace id + :param provider: provider name + :param datasource_name: datasource name + :param plugin_id: plugin id + :return: + """ + # Get all provider configurations of the current workspace + datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id).first() + if datasource_provider: + db.session.delete(datasource_provider) + db.session.commit() + From 7f59ffe7afdbcc9490c2b5481d909ad31d6f253a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 28 May 2025 17:56:04 +0800 Subject: [PATCH 047/155] r2 --- .../rag_pipeline/rag_pipeline_datasets.py | 3 +- .../rag_pipeline/rag_pipeline_workflow.py | 24 ++- .../app/apps/pipeline/pipeline_generator.py | 84 +++++++-- api/core/app/apps/pipeline/pipeline_runner.py | 8 + .../datasource/__base/datasource_provider.py | 4 +- .../datasource/__base/datasource_runtime.py | 4 +- .../website_crawl/website_crawl_provider.py | 2 +- api/core/plugin/impl/datasource.py | 24 +-- .../index_processor/index_processor_base.py | 7 +- .../processor/paragraph_index_processor.py | 13 +- .../processor/parent_child_index_processor.py | 16 ++ .../processor/qa_index_processor.py | 10 +- api/core/rag/models/document.py | 2 +- api/core/variables/variables.py | 19 +- api/core/workflow/constants.py | 2 +- api/core/workflow/entities/variable_pool.py | 17 +- .../entities/workflow_execution_entities.py | 1 + .../workflow/graph_engine/graph_engine.py | 4 +- .../nodes/datasource/datasource_node.py | 62 ++++++- .../workflow/nodes/datasource/entities.py | 2 +- .../knowledge_index/knowledge_index_node.py | 48 +++-- api/fields/workflow_fields.py | 3 +- api/models/enums.py | 2 + api/models/workflow.py | 5 +- api/services/dataset_service.py | 172 +++++++++++++----- api/services/datasource_provider_service.py | 35 ++-- .../rag_pipeline_entities.py | 9 + .../rag_pipeline/pipeline_generate_service.py | 6 +- .../database/database_retrieval.py | 16 +- api/services/rag_pipeline/rag_pipeline.py | 27 ++- .../rag_pipeline/rag_pipeline_dsl_service.py | 80 ++++---- api/tasks/deal_dataset_index_update_task.py | 171 +++++++++++++++++ 32 files changed, 680 insertions(+), 202 deletions(-) create mode 100644 api/tasks/deal_dataset_index_update_task.py diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index 1a4e9240b6..f502157eda 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -15,6 +15,7 @@ from libs.login import login_required from models.dataset import DatasetPermissionEnum from services.dataset_service import DatasetPermissionService, DatasetService from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity +from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService def _validate_name(name): @@ -91,7 +92,7 @@ class CreateRagPipelineDatasetApi(Resource): raise Forbidden() rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args) try: - import_info = DatasetService.create_rag_pipeline_dataset( + import_info = RagPipelineDslService.create_rag_pipeline_dataset( tenant_id=current_user.current_tenant_id, rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, ) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py 
b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index bbeaa33341..fc6eab529a 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -40,6 +40,7 @@ from libs.login import current_user, login_required from models.account import Account from models.dataset import Pipeline from models.model import EndUser +from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService @@ -282,15 +283,18 @@ class PublishedRagPipelineRunApi(Resource): parser.add_argument("datasource_info_list", type=list, required=True, location="json") parser.add_argument("start_node_id", type=str, required=True, location="json") parser.add_argument("is_preview", type=bool, required=True, location="json", default=False) + parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming") args = parser.parse_args() + streaming = args["response_mode"] == "streaming" + try: response = PipelineGenerateService.generate( pipeline=pipeline, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED, - streaming=True, + streaming=streaming, ) return helper.compact_generate_response(response) @@ -459,16 +463,17 @@ class PublishedRagPipelineApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument("marked_name", type=str, required=False, default="", location="json") - parser.add_argument("marked_comment", type=str, required=False, default="", location="json") + parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.") args = parser.parse_args() - # Validate name and comment length - if args.marked_name and len(args.marked_name) > 20: - raise ValueError("Marked name cannot exceed 20 characters") - if args.marked_comment and len(args.marked_comment) > 100: - raise ValueError("Marked comment cannot exceed 100 characters") + if not args.get("knowledge_base_setting"): + raise ValueError("Missing knowledge base setting.") + knowledge_base_setting_data = args.get("knowledge_base_setting") + if not knowledge_base_setting_data: + raise ValueError("Missing knowledge base setting.") + + knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data) rag_pipeline_service = RagPipelineService() with Session(db.engine) as session: pipeline = session.merge(pipeline) @@ -476,8 +481,7 @@ class PublishedRagPipelineApi(Resource): session=session, pipeline=pipeline, account=current_user, - marked_name=args.marked_name or "", - marked_comment=args.marked_comment or "", + knowledge_base_setting=knowledge_base_setting, ) pipeline.is_published = True pipeline.workflow_id = workflow.id diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 9c25f8f4e6..e4c96775c8 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -28,10 +28,13 @@ from core.app.entities.task_entities import WorkflowAppBlockingResponse, Workflo from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.index_processor.constant.built_in_field import BuiltInField from core.repositories 
import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, Pipeline +from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode from services.dataset_service import DocumentService @@ -51,7 +54,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: Literal[True], call_depth: int, workflow_thread_pool_id: Optional[str], - ) -> Generator[Mapping | str, None, None]: ... + ) -> Generator[Mapping | str, None, None] | None: ... @overload def generate( @@ -92,7 +95,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: bool = True, call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, - ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: + ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: # convert to app config pipeline_config = PipelineConfigManager.get_pipeline_config( pipeline=pipeline, @@ -119,14 +122,14 @@ class PipelineGenerator(BaseAppGenerator): document = self._build_document( tenant_id=pipeline.tenant_id, dataset_id=dataset.id, - built_in_field_enabled=pipeline.dataset.built_in_field_enabled, + built_in_field_enabled=dataset.built_in_field_enabled, datasource_type=datasource_type, datasource_info=datasource_info, created_from="rag-pipeline", position=position, account=user, batch=batch, - document_form=pipeline.dataset.chunk_structure, + document_form=dataset.chunk_structure, ) db.session.add(document) db.session.commit() @@ -138,7 +141,7 @@ class PipelineGenerator(BaseAppGenerator): pipeline_config=pipeline_config, datasource_type=datasource_type, datasource_info=datasource_info, - dataset_id=pipeline.dataset.id, + dataset_id=dataset.id, start_node_id=start_node_id, batch=batch, document_id=document_id, @@ -159,15 +162,24 @@ class PipelineGenerator(BaseAppGenerator): contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - + if invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, ) if invoke_from == InvokeFrom.DEBUGGER: return self._generate( @@ -176,6 +188,7 @@ class PipelineGenerator(BaseAppGenerator): user=user, application_generate_entity=application_generate_entity, invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, 
workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, workflow_thread_pool_id=workflow_thread_pool_id, @@ -187,6 +200,7 @@ class PipelineGenerator(BaseAppGenerator): user=user, application_generate_entity=application_generate_entity, invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, workflow_thread_pool_id=workflow_thread_pool_id, @@ -200,6 +214,7 @@ class PipelineGenerator(BaseAppGenerator): user: Union[Account, EndUser], application_generate_entity: RagPipelineGenerateEntity, invoke_from: InvokeFrom, + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, workflow_thread_pool_id: Optional[str] = None, @@ -207,11 +222,12 @@ class PipelineGenerator(BaseAppGenerator): """ Generate App response. - :param app_model: App + :param pipeline: Pipeline :param workflow: Workflow :param user: account or end user :param application_generate_entity: application generate entity :param invoke_from: invoke from source + :param workflow_execution_repository: repository for workflow execution :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream :param workflow_thread_pool_id: workflow thread pool id @@ -244,6 +260,7 @@ class PipelineGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, stream=streaming, ) @@ -276,16 +293,20 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("inputs is required") # convert to app config - app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) + pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) + + dataset = pipeline.dataset + if not dataset: + raise ValueError("Pipeline dataset is required") # init application generate entity - use RagPipelineGenerateEntity instead application_generate_entity = RagPipelineGenerateEntity( task_id=str(uuid.uuid4()), - app_config=app_config, - pipeline_config=app_config, + app_config=pipeline_config, + pipeline_config=pipeline_config, datasource_type=args.get("datasource_type", ""), datasource_info=args.get("datasource_info", {}), - dataset_id=pipeline.dataset_id, + dataset_id=dataset.id, batch=args.get("batch", ""), document_id=args.get("document_id"), inputs={}, @@ -299,10 +320,16 @@ class PipelineGenerator(BaseAppGenerator): contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING, + ) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, user=user, @@ -316,6 +343,7 @@ class PipelineGenerator(BaseAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + 
workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, ) @@ -345,20 +373,30 @@ class PipelineGenerator(BaseAppGenerator): if args.get("inputs") is None: raise ValueError("inputs is required") + dataset = pipeline.dataset + if not dataset: + raise ValueError("Pipeline dataset is required") + # convert to app config - app_config = WorkflowAppConfigManager.get_app_config(pipeline=pipeline, workflow=workflow) + pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) # init application generate entity - application_generate_entity = WorkflowAppGenerateEntity( + application_generate_entity = RagPipelineGenerateEntity( task_id=str(uuid.uuid4()), - app_config=app_config, + app_config=pipeline_config, + pipeline_config=pipeline_config, + datasource_type=args.get("datasource_type", ""), + datasource_info=args.get("datasource_info", {}), + batch=args.get("batch", ""), + document_id=args.get("document_id"), + dataset_id=dataset.id, inputs={}, files=[], user_id=user.id, stream=streaming, invoke_from=InvokeFrom.DEBUGGER, extras={"auto_generate_conversation_name": False}, - single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), + single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), workflow_run_id=str(uuid.uuid4()), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -368,6 +406,13 @@ class PipelineGenerator(BaseAppGenerator): # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING, + ) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, user=user, @@ -381,6 +426,7 @@ class PipelineGenerator(BaseAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, ) @@ -438,6 +484,7 @@ class PipelineGenerator(BaseAppGenerator): workflow: Workflow, queue_manager: AppQueueManager, user: Union[Account, EndUser], + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, stream: bool = False, ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -459,6 +506,7 @@ class PipelineGenerator(BaseAppGenerator): user=user, stream=stream, workflow_node_execution_repository=workflow_node_execution_repository, + workflow_execution_repository=workflow_execution_repository, ) try: @@ -481,7 +529,7 @@ class PipelineGenerator(BaseAppGenerator): datasource_info: Mapping[str, Any], created_from: str, position: int, - account: Account, + account: Union[Account, EndUser], batch: str, document_form: str, ): diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 23dbfef70d..8d90e7ee3e 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -10,6 +10,7 @@ from 
core.app.entities.app_invoke_entities import ( InvokeFrom, RagPipelineGenerateEntity, ) +from core.variables.variables import RAGPipelineVariable from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey @@ -106,12 +107,19 @@ class PipelineRunner(WorkflowBasedAppRunner): SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info, SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value, } + rag_pipeline_variables = {} + if workflow.rag_pipeline_variables: + for v in workflow.rag_pipeline_variables: + rag_pipeline_variable = RAGPipelineVariable(**v) + if rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id and rag_pipeline_variable.variable in inputs: + rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable] variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, environment_variables=workflow.environment_variables, conversation_variables=[], + rag_pipeline_variables=rag_pipeline_variables, ) # init graph diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index 045ca64872..bae39dc8c7 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -9,10 +9,10 @@ from core.tools.errors import ToolProviderCredentialValidationError class DatasourcePluginProviderController(ABC): - entity: DatasourceProviderEntityWithPlugin | None + entity: DatasourceProviderEntityWithPlugin tenant_id: str - def __init__(self, entity: DatasourceProviderEntityWithPlugin | None, tenant_id: str) -> None: + def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None: self.entity = entity self.tenant_id = tenant_id diff --git a/api/core/datasource/__base/datasource_runtime.py b/api/core/datasource/__base/datasource_runtime.py index 51ff1fc6c1..9ddc25a637 100644 --- a/api/core/datasource/__base/datasource_runtime.py +++ b/api/core/datasource/__base/datasource_runtime.py @@ -14,9 +14,9 @@ class DatasourceRuntime(BaseModel): """ tenant_id: str - tool_id: Optional[str] = None + datasource_id: Optional[str] = None invoke_from: Optional[InvokeFrom] = None - tool_invoke_from: Optional[ToolInvokeFrom] = None + datasource_invoke_from: Optional[DatasourceInvokeFrom] = None credentials: dict[str, Any] = Field(default_factory=dict) runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index 11168b4c26..8c0f20ce2d 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -11,7 +11,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon def __init__( self, - entity: DatasourceProviderEntityWithPlugin | None, + entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str, diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index b5212eb719..7847218bb9 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -30,22 +30,16 @@ class PluginDatasourceManager(BasePluginClient): return json_response - # response = 
self._request_with_plugin_daemon_response( - # "GET", - # f"plugin/{tenant_id}/management/datasources", - # list[PluginDatasourceProviderEntity], - # params={"page": 1, "page_size": 256}, - # transformer=transformer, - # ) + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/datasources", + list[PluginDatasourceProviderEntity], + params={"page": 1, "page_size": 256}, + transformer=transformer, + ) + local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) - # for provider in response: - # provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" - - # # override the provider name for each tool to plugin_id/provider_name - # for datasource in provider.declaration.datasources: - # datasource.identity.provider = provider.declaration.identity.name - - return [PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())] + return [local_file_datasource_provider] + response def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: """ diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 50511de16f..72e4923b58 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -13,7 +13,8 @@ from core.rag.splitter.fixed_text_splitter import ( FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter -from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule +from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Document as DatasetDocument class BaseIndexProcessor(ABC): @@ -37,6 +38,10 @@ class BaseIndexProcessor(ABC): @abstractmethod def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): raise NotImplementedError + + @abstractmethod + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + raise NotImplementedError @abstractmethod def retrieve( diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 5eab77d4f8..559bc5d59b 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -131,7 +131,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): paragraph = GeneralStructureChunk(**chunks) documents = [] - for content in paragraph.general_chunk: + for content in paragraph.general_chunks: metadata = { "dataset_id": dataset.id, "document_id": document.id, @@ -151,3 +151,14 @@ class ParagraphIndexProcessor(BaseIndexProcessor): elif dataset.indexing_technique == "economy": keyword = Keyword(dataset) keyword.add_texts(documents) + + + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + paragraph = GeneralStructureChunk(**chunks) + preview = [] + for content in paragraph.general_chunks: + preview.append({"content": content}) + return { + "preview": preview, + "total_segments": len(paragraph.general_chunks) + } \ No newline at end of file diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 6300d05707..7a3f8f1c63 
100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -234,3 +234,19 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + parent_childs = ParentChildStructureChunk(**chunks) + preview = [] + for parent_child in parent_childs.parent_child_chunks: + preview.append( + { + "content": parent_child.parent_content, + "child_chunks": parent_child.child_contents + + } + ) + return { + "preview": preview, + "total_segments": len(parent_childs.parent_child_chunks) + } \ No newline at end of file diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 0055625e13..b415596254 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -4,7 +4,7 @@ import logging import re import threading import uuid -from typing import Optional +from typing import Any, Mapping, Optional import pandas as pd from flask import Flask, current_app @@ -20,7 +20,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import Document from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper -from models.dataset import Dataset +from models.dataset import Dataset, Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule @@ -160,6 +160,12 @@ class QAIndexProcessor(BaseIndexProcessor): doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) return docs + + def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): + pass + + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + return {"preview": chunks} def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): format_documents = [] diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 52795bbadf..9f0054a165 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -40,7 +40,7 @@ class GeneralStructureChunk(BaseModel): General Structure Chunk. 
""" - general_chunk: list[str] + general_chunks: list[str] class ParentChildChunk(BaseModel): diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index b650b1682e..c0952383a9 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from typing import cast from uuid import uuid4 -from pydantic import Field +from pydantic import BaseModel, Field from core.helper import encrypter @@ -93,3 +93,20 @@ class FileVariable(FileSegment, Variable): class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass + +class RAGPipelineVariable(BaseModel): + belong_to_node_id: str = Field(description="belong to which node id, shared means public") + type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list") + label: str = Field(description="label") + description: str | None = Field(description="description", default="") + variable: str = Field(description="variable key", default="") + max_length: int | None = Field(description="max length, applicable to text-input, paragraph, and file-list", default=0) + default_value: str | None = Field(description="default value", default="") + placeholder: str | None = Field(description="placeholder", default="") + unit: str | None = Field(description="unit, applicable to Number", default="") + tooltips: str | None = Field(description="helpful text", default="") + allowed_file_types: list[str] | None = Field(description="image, document, audio, video, custom.", default_factory=list) + allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list) + allowed_file_upload_methods: list[str] | None = Field(description="remote_url, local_file, tool_file.", default_factory=list) + required: bool = Field(description="optional, default false", default=False) + options: list[str] | None = Field(default_factory=list) diff --git a/api/core/workflow/constants.py b/api/core/workflow/constants.py index 59edcee456..7664be0983 100644 --- a/api/core/workflow/constants.py +++ b/api/core/workflow/constants.py @@ -1,4 +1,4 @@ SYSTEM_VARIABLE_NODE_ID = "sys" ENVIRONMENT_VARIABLE_NODE_ID = "env" CONVERSATION_VARIABLE_NODE_ID = "conversation" -PIPELINE_VARIABLE_NODE_ID = "pipeline" +RAG_PIPELINE_VARIABLE_NODE_ID = "rag" diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index af26864c01..319833145e 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -10,7 +10,12 @@ from core.variables import Segment, SegmentGroup, Variable from core.variables.segments import FileSegment, NoneSegment from factories import variable_factory -from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from ..constants import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) from ..enums import SystemVariableKey VariableValue = Union[str, int, float, dict, list, File] @@ -42,6 +47,10 @@ class VariablePool(BaseModel): description="Conversation variables.", default_factory=list, ) + rag_pipeline_variables: Mapping[str, Any] = Field( + description="RAG pipeline variables.", + default_factory=dict, + ) def __init__( self, @@ -50,18 +59,21 @@ class VariablePool(BaseModel): user_inputs: Mapping[str, Any] | None = None, environment_variables: Sequence[Variable] | None = None, 
conversation_variables: Sequence[Variable] | None = None, + rag_pipeline_variables: Mapping[str, Any] | None = None, **kwargs, ): environment_variables = environment_variables or [] conversation_variables = conversation_variables or [] user_inputs = user_inputs or {} system_variables = system_variables or {} + rag_pipeline_variables = rag_pipeline_variables or {} super().__init__( system_variables=system_variables, user_inputs=user_inputs, environment_variables=environment_variables, conversation_variables=conversation_variables, + rag_pipeline_variables=rag_pipeline_variables, **kwargs, ) @@ -73,6 +85,9 @@ class VariablePool(BaseModel): # Add conversation variables to the variable pool for var in self.conversation_variables: self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) + # Add rag pipeline variables to the variable pool + for var, value in self.rag_pipeline_variables.items(): + self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var), value) def add(self, selector: Sequence[str], value: Any, /) -> None: """ diff --git a/api/core/workflow/entities/workflow_execution_entities.py b/api/core/workflow/entities/workflow_execution_entities.py index 200d4697b5..28fae53ced 100644 --- a/api/core/workflow/entities/workflow_execution_entities.py +++ b/api/core/workflow/entities/workflow_execution_entities.py @@ -20,6 +20,7 @@ class WorkflowType(StrEnum): WORKFLOW = "workflow" CHAT = "chat" + RAG_PIPELINE = "rag-pipeline" class WorkflowExecutionStatus(StrEnum): diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 36273d8ec1..c17f1eeb2b 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -173,7 +173,7 @@ class GraphEngine: ) return elif isinstance(item, NodeRunSucceededEvent): - if item.node_type == NodeType.END: + if item.node_type in (NodeType.END, NodeType.KNOWLEDGE_INDEX): self.graph_runtime_state.outputs = ( dict(item.route_node_state.node_run_result.outputs) if item.route_node_state.node_run_result @@ -319,7 +319,7 @@ class GraphEngine: # It may not be necessary, but it is necessary. 
:) if ( self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() - == NodeType.END.value + in [NodeType.END.value, NodeType.KNOWLEDGE_INDEX.value] ): break diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 2f15fba3af..8f841f9564 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -10,14 +10,16 @@ from core.datasource.entities.datasource_entities import ( from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.file import File from core.plugin.impl.exc import PluginDaemonClientSideError -from core.variables.segments import ArrayAnySegment +from core.variables.segments import ArrayAnySegment, FileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser +from extensions.ext_database import db +from models.model import UploadFile from models.workflow import WorkflowNodeExecutionStatus from .entities import DatasourceNodeData @@ -59,7 +61,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): provider_id=node_data.provider_id, datasource_name=node_data.datasource_name or "", tenant_id=self.tenant_id, - datasource_type=DatasourceProviderType(datasource_type), + datasource_type=DatasourceProviderType.value_of(datasource_type), ) except DatasourceNodeError as e: return NodeRunResult( @@ -69,7 +71,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): error=f"Failed to get datasource runtime: {str(e)}", error_type=type(e).__name__, ) - + # get parameters datasource_parameters = datasource_runtime.entity.parameters @@ -105,7 +107,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): "datasource_type": datasource_type, }, ) - case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE: + case DatasourceProviderType.WEBSITE_CRAWL: return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, @@ -116,18 +118,42 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): }, ) case DatasourceProviderType.LOCAL_FILE: + upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first() + if not upload_file: + raise ValueError("Invalid upload file Info") + + file_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="." 
+ upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=self.tenant_id, + type=datasource_info.get("type", ""), + transfer_method=datasource_info.get("transfer_method", ""), + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + ) + variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)]) + for key, value in datasource_info.items(): + # construct new key list + new_key_list = ["file", key] + self._append_variables_recursively( + variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, outputs={ - "file": datasource_info, - "datasource_type": datasource_runtime.datasource_provider_type, + "file_info": file_info, + "datasource_type": datasource_type, }, ) case _: raise DatasourceNodeError( - f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}" + f"Unsupported datasource provider: {datasource_type}" ) except PluginDaemonClientSideError as e: return NodeRunResult( @@ -194,6 +220,26 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] + + + def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue): + """ + Append variables recursively + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + variable_pool.add([node_id] + variable_key_list, variable_value) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, dict): + for key, value in variable_value.items(): + # construct new key list + new_key_list = variable_key_list + [key] + self._append_variables_recursively( + variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value + ) @classmethod def _extract_variable_selector_to_variable_mapping( diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 1f414ad0e2..dee3c1d2fb 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -18,7 +18,7 @@ class DatasourceEntity(BaseModel): class DatasourceNodeData(BaseNodeData, DatasourceEntity): class DatasourceInput(BaseModel): # TODO: check this type - value: Optional[Union[Any, list[str]]] = None + value: Union[Any, list[str]] type: Optional[Literal["mixed", "variable", "constant"]] = None @field_validator("type", mode="before") diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 25a4112998..fef434e3ec 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -39,15 +39,30 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): def _run(self) -> NodeRunResult: # type: ignore node_data = cast(KnowledgeIndexNodeData, self.node_data) variable_pool = self.graph_runtime_state.variable_pool + dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) + if not dataset_id: + raise 
KnowledgeIndexNodeError("Dataset ID is required.") + dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first() + if not dataset: + raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.") + # extract variables variable = variable_pool.get(node_data.index_chunk_variable_selector) - is_preview = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) == InvokeFrom.DEBUGGER + if not variable: + raise KnowledgeIndexNodeError("Index chunk variable is required.") + invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + if invoke_from: + is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value + else: + is_preview = False chunks = variable.value variables = {"chunks": chunks} if not chunks: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required." ) + outputs = self._get_preview_output(dataset.chunk_structure, chunks) + # retrieve knowledge try: if is_preview: @@ -55,12 +70,12 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, - outputs={"result": "success"}, + outputs=outputs, ) - results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool) - outputs = {"result": results} + results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks, + variable_pool=variable_pool) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results ) except KnowledgeIndexNodeError as e: @@ -81,24 +96,18 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): ) def _invoke_knowledge_index( - self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], variable_pool: VariablePool + self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], + variable_pool: VariablePool ) -> Any: - dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) - if not dataset_id: - raise KnowledgeIndexNodeError("Dataset ID is required.") document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) if not document_id: raise KnowledgeIndexNodeError("Document ID is required.") batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) if not batch: raise KnowledgeIndexNodeError("Batch is required.") - dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() - if not dataset: - raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") - - document = db.session.query(Document).filter_by(id=document_id).first() + document = db.session.query(Document).filter_by(id=document_id.value).first() if not document: - raise KnowledgeIndexNodeError(f"Document {document_id} not found.") + raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.") index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor() index_processor.index(dataset, document, chunks) @@ -106,14 +115,19 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): # update document status document.indexing_status = "completed" document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.add(document) db.session.commit() return { "dataset_id": dataset.id, "dataset_name": dataset.name, - "batch": batch, + "batch": batch.value, "document_id": document.id, "document_name": document.name, - "created_at": 
document.created_at, + "created_at": document.created_at.timestamp(), "display_status": document.indexing_status, } + + def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() + return index_processor.format_preview(chunks) diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 0733192c4f..c138266b14 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -41,10 +41,9 @@ conversation_variable_fields = { } pipeline_variable_fields = { - "id": fields.String, "label": fields.String, "variable": fields.String, - "type": fields.String(attribute="type.value"), + "type": fields.String, "belong_to_node_id": fields.String, "max_length": fields.Integer, "required": fields.Boolean, diff --git a/api/models/enums.py b/api/models/enums.py index 4434c3fec8..0afa204b1f 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -14,6 +14,8 @@ class UserFrom(StrEnum): class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" APP_RUN = "app-run" + RAG_PIPELINE_RUN = "rag-pipeline-run" + RAG_PIPELINE_DEBUGGING = "rag-pipeline-debugging" class DraftVariableType(StrEnum): diff --git a/api/models/workflow.py b/api/models/workflow.py index b37b0febe8..f0aba3572a 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -152,6 +152,7 @@ class Workflow(Base): created_by: str, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], + rag_pipeline_variables: list[dict], marked_name: str = "", marked_comment: str = "", ) -> "Workflow": @@ -166,6 +167,7 @@ class Workflow(Base): workflow.created_by = created_by workflow.environment_variables = environment_variables or [] workflow.conversation_variables = conversation_variables or [] + workflow.rag_pipeline_variables = rag_pipeline_variables or [] workflow.marked_name = marked_name workflow.marked_comment = marked_comment workflow.created_at = datetime.now(UTC).replace(tzinfo=None) @@ -340,7 +342,7 @@ class Workflow(Base): "features": self.features_dict, "environment_variables": [var.model_dump(mode="json") for var in environment_variables], "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], - "rag_pipeline_variables": [var.model_dump(mode="json") for var in self.rag_pipeline_variables], + "rag_pipeline_variables": self.rag_pipeline_variables, } return result @@ -553,6 +555,7 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum): SINGLE_STEP = "single-step" WORKFLOW_RUN = "workflow-run" + RAG_PIPELINE_RUN = "rag-pipeline-run" class WorkflowNodeExecutionStatus(StrEnum): diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index f1280375e0..6d3891799c 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -51,7 +51,10 @@ from services.entities.knowledge_entities.knowledge_entities import ( RetrievalModel, SegmentUpdateArgs, ) -from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity +from services.entities.knowledge_entities.rag_pipeline_entities import ( + KnowledgeBaseUpdateConfiguration, + RagPipelineDatasetCreateEntity, +) from services.errors.account import InvalidActionError, NoPermissionError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError @@ -59,11 +62,11 @@ from services.errors.document import 
DocumentIndexingError from services.errors.file import FileNotExistsError from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureModel, FeatureService -from services.rag_pipeline.rag_pipeline_dsl_service import ImportMode, RagPipelineDslService, RagPipelineImportInfo from services.tag_service import TagService from services.vector_service import VectorService from tasks.batch_clean_document_task import batch_clean_document_task from tasks.clean_notion_document_task import clean_notion_document_task +from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task @@ -278,47 +281,6 @@ class DatasetService: db.session.commit() return dataset - @staticmethod - def create_rag_pipeline_dataset( - tenant_id: str, - rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, - ): - # check if dataset name already exists - if ( - db.session.query(Dataset) - .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) - .first() - ): - raise DatasetNameDuplicateError( - f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." - ) - - dataset = Dataset( - name=rag_pipeline_dataset_create_entity.name, - description=rag_pipeline_dataset_create_entity.description, - permission=rag_pipeline_dataset_create_entity.permission, - provider="vendor", - runtime_mode="rag-pipeline", - icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(), - ) - with Session(db.engine) as session: - rag_pipeline_dsl_service = RagPipelineDslService(session) - account = cast(Account, current_user) - rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline( - account=account, - import_mode=ImportMode.YAML_CONTENT.value, - yaml_content=rag_pipeline_dataset_create_entity.yaml_content, - dataset=dataset, - ) - return { - "id": rag_pipeline_import_info.id, - "dataset_id": dataset.id, - "pipeline_id": rag_pipeline_import_info.pipeline_id, - "status": rag_pipeline_import_info.status, - "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version, - "current_dsl_version": rag_pipeline_import_info.current_dsl_version, - "error": rag_pipeline_import_info.error, - } @staticmethod def get_dataset(dataset_id) -> Optional[Dataset]: @@ -529,6 +491,130 @@ class DatasetService: if action: deal_dataset_vector_index_task.delay(dataset_id, action) return dataset + + @staticmethod + def update_rag_pipeline_dataset_settings(session: Session, + dataset: Dataset, + knowledge_base_setting: KnowledgeBaseUpdateConfiguration, + has_published: bool = False): + if not has_published: + dataset.chunk_structure = knowledge_base_setting.chunk_structure + index_method = knowledge_base_setting.index_method + dataset.indexing_technique = index_method.indexing_technique + if index_method == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=index_method.embedding_setting.embedding_provider_name, + model_type=ModelType.TEXT_EMBEDDING, + model=index_method.embedding_setting.embedding_model_name, + ) + dataset.embedding_model = embedding_model.model + dataset.embedding_model_provider = embedding_model.provider + dataset_collection_binding = 
DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + dataset.collection_binding_id = dataset_collection_binding.id + elif index_method == "economy": + dataset.keyword_number = index_method.economy_setting.keyword_number + else: + raise ValueError("Invalid index method") + dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump() + session.add(dataset) + else: + if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure: + raise ValueError("Chunk structure is not allowed to be updated.") + action = None + if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique: + # if update indexing_technique + if knowledge_base_setting.index_method.indexing_technique == "economy": + raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") + elif knowledge_base_setting.index_method.indexing_technique == "high_quality": + action = "add" + # get embedding model setting + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name, + model_type=ModelType.TEXT_EMBEDDING, + model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name, + ) + dataset.embedding_model = embedding_model.model + dataset.embedding_model_provider = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + dataset.collection_binding_id = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." 
+ ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + else: + # add default plugin id to both setting sets, to make sure the plugin model provider is consistent + # Skip embedding model checks if not provided in the update request + if dataset.indexing_technique == "high_quality": + skip_embedding_update = False + try: + # Handle existing model provider + plugin_model_provider = dataset.embedding_model_provider + plugin_model_provider_str = None + if plugin_model_provider: + plugin_model_provider_str = str(ModelProviderID(plugin_model_provider)) + + # Handle new model provider from request + new_plugin_model_provider = knowledge_base_setting.index_method.embedding_setting.embedding_provider_name + new_plugin_model_provider_str = None + if new_plugin_model_provider: + new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider)) + + # Only update embedding model if both values are provided and different from current + if ( + plugin_model_provider_str != new_plugin_model_provider_str + or knowledge_base_setting.index_method.embedding_setting.embedding_model_name != dataset.embedding_model + ): + action = "update" + model_manager = ModelManager() + try: + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name, + model_type=ModelType.TEXT_EMBEDDING, + model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name, + ) + except ProviderTokenNotInitError: + # If we can't get the embedding model, skip updating it + # and keep the existing settings if available + # Skip the rest of the embedding model update + skip_embedding_update = True + if not skip_embedding_update: + dataset.embedding_model = embedding_model.model + dataset.embedding_model_provider = embedding_model.provider + dataset_collection_binding = ( + DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + ) + dataset.collection_binding_id = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." 
+ ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + elif dataset.indexing_technique == "economy": + if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number: + dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number + dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump() + session.add(dataset) + session.commit() + if action: + deal_dataset_index_update_task.delay(dataset.id, action) + @staticmethod def delete_dataset(dataset_id, user): diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index fbb9b25a75..54abc64547 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -4,29 +4,12 @@ from typing import Optional from flask_login import current_user from constants import HIDDEN_VALUE -from core import datasource -from core.datasource.__base import datasource_provider -from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity from core.helper import encrypter -from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.entities.provider_entities import FormType from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.plugin.impl.datasource import PluginDatasourceManager -from core.provider_manager import ProviderManager +from extensions.ext_database import db from models.oauth import DatasourceProvider -from models.provider import ProviderType -from services.entities.model_provider_entities import ( - CustomConfigurationResponse, - CustomConfigurationStatus, - DefaultModelResponse, - ModelWithProviderEntityResponse, - ProviderResponse, - ProviderWithModelsResponse, - SimpleProviderEntityResponse, - SystemConfigurationResponse, -) -from extensions.database import db logger = logging.getLogger(__name__) @@ -115,16 +98,26 @@ class DatasourceProviderService: :param tenant_id: workspace id :param provider: provider name - :param datasource_name: datasource name :param plugin_id: plugin id :return: """ # Get all provider configurations of the current workspace - datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, + datasource_provider: DatasourceProvider | None = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id).first() + if not datasource_provider: + return None + encrypted_credentials = datasource_provider.encrypted_credentials + # Get provider credential secret variables + credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider) + # Obfuscate provider credentials + copy_credentials = encrypted_credentials.copy() + for key, value in copy_credentials.items(): + if key in credential_secret_variables: + copy_credentials[key] = encrypter.obfuscated_token(value) + return copy_credentials def remove_datasource_credentials(self, @@ -136,11 +129,9 @@ class DatasourceProviderService: :param tenant_id: workspace id :param provider: provider name - :param datasource_name: datasource name :param plugin_id: plugin id :return: """ - # Get all provider configurations of the current workspace datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id).first() 
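Note on the hunk above: the reworked get_datasource_credentials no longer returns raw stored credentials. It looks up the DatasourceProvider row, collects the secret field names from the provider's credential schema via extract_secret_variables, and masks only those fields with encrypter.obfuscated_token before returning a copy. The snippet below is a minimal, self-contained sketch of that masking pattern only; the local obfuscated_token helper and its exact masking format are illustrative stand-ins, not the implementation in core.helper.encrypter.

def obfuscated_token(token: str) -> str:
    # Stand-in for core.helper.encrypter.obfuscated_token (format assumed):
    # keep a short prefix and suffix visible and hide the rest of the secret.
    if len(token) <= 8:
        return "*" * len(token)
    return token[:3] + "*" * (len(token) - 6) + token[-3:]


def mask_credentials(credentials: dict[str, str], secret_fields: set[str]) -> dict[str, str]:
    # Work on a copy so the stored credentials are never mutated, and
    # obfuscate only the fields the credential schema marks as secret.
    masked = dict(credentials)
    for key, value in masked.items():
        if key in secret_fields:
            masked[key] = obfuscated_token(value)
    return masked


if __name__ == "__main__":
    credentials = {"api_key": "sk-1234567890abcdef", "base_url": "https://example.com"}
    print(mask_credentials(credentials, {"api_key"}))
    # Only api_key is masked; non-secret fields such as base_url come back unchanged,
    # mirroring how the service obfuscates secrets before exposing provider credentials.
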
diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index 800bd24021..17416d51fd 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -111,3 +111,12 @@ class KnowledgeConfiguration(BaseModel): chunk_structure: str index_method: IndexMethod retrieval_setting: RetrievalSetting + + +class KnowledgeBaseUpdateConfiguration(BaseModel): + """ + Knowledge Base Update Configuration. + """ + index_method: IndexMethod + chunk_structure: str + retrieval_setting: RetrievalSetting \ No newline at end of file diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 14594be351..911086066a 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -69,9 +69,9 @@ class PipelineGenerateService: @classmethod def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True): workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER) - return WorkflowAppGenerator.convert_to_event_stream( - WorkflowAppGenerator().single_loop_generate( - app_model=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + return PipelineGenerator.convert_to_event_stream( + PipelineGenerator().single_loop_generate( + pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming ) ) diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 11071d82e7..9ea3cc678b 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -36,7 +36,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): recommended_pipelines_results = [] for pipeline_built_in_template in pipeline_built_in_templates: - pipeline_model: Pipeline = pipeline_built_in_template.pipeline + pipeline_model: Pipeline | None = pipeline_built_in_template.pipeline + if not pipeline_model: + continue recommended_pipeline_result = { "id": pipeline_built_in_template.id, @@ -48,7 +50,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): "privacy_policy": pipeline_built_in_template.privacy_policy, "position": pipeline_built_in_template.position, } - dataset: Dataset = pipeline_model.dataset + dataset: Dataset | None = pipeline_model.dataset if dataset: recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure recommended_pipelines_results.append(recommended_pipeline_result) @@ -72,15 +74,19 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): if not pipeline_template: return None - # get app detail + # get pipeline detail pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first() if not pipeline or not pipeline.is_public: return None + dataset: Dataset | None = pipeline.dataset + if not dataset: + return None + return { "id": pipeline.id, "name": pipeline.name, - "icon": pipeline.icon, - "mode": pipeline.mode, + "icon": pipeline_template.icon, + "chunk_structure": dataset.chunk_structure, "export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline), } diff --git 
a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 5a69e69a16..9e7a1d7fe2 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -46,7 +46,8 @@ from models.workflow import ( WorkflowRun, WorkflowType, ) -from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity +from services.dataset_service import DatasetService +from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory @@ -261,8 +262,7 @@ class RagPipelineService: session: Session, pipeline: Pipeline, account: Account, - marked_name: str = "", - marked_comment: str = "", + knowledge_base_setting: KnowledgeBaseUpdateConfiguration, ) -> Workflow: draft_workflow_stmt = select(Workflow).where( Workflow.tenant_id == pipeline.tenant_id, @@ -282,18 +282,25 @@ class RagPipelineService: graph=draft_workflow.graph, features=draft_workflow.features, created_by=account.id, - environment_variables=draft_workflow.environment_variables, + environment_variables=draft_workflow.environment_variables, conversation_variables=draft_workflow.conversation_variables, - marked_name=marked_name, - marked_comment=marked_comment, + rag_pipeline_variables=draft_workflow.rag_pipeline_variables, + marked_name="", + marked_comment="", ) - # commit db session changes session.add(workflow) - # trigger app workflow events TODO - # app_published_workflow_was_updated.send(pipeline, published_workflow=workflow) - + # update dataset + dataset = pipeline.dataset + if not dataset: + raise ValueError("Dataset not found") + DatasetService.update_rag_pipeline_dataset_settings( + session=session, + dataset=dataset, + knowledge_base_setting=knowledge_base_setting, + has_published=pipeline.is_published + ) # return new workflow return workflow diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index acd364f6cd..c6751825cc 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -4,13 +4,14 @@ import logging import uuid from collections.abc import Mapping from enum import StrEnum -from typing import Optional +from typing import Optional, cast from urllib.parse import urlparse from uuid import uuid4 import yaml # type: ignore from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad +from flask_login import current_user from packaging import version from pydantic import BaseModel, Field from sqlalchemy import select @@ -31,7 +32,10 @@ from factories import variable_factory from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.workflow import Workflow -from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration +from services.entities.knowledge_entities.rag_pipeline_entities import ( + KnowledgeConfiguration, + RagPipelineDatasetCreateEntity, +) from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -540,9 +544,6 @@ class RagPipelineDslService: # Update existing pipeline pipeline.name = pipeline_data.get("name", pipeline.name) pipeline.description = pipeline_data.get("description", 
pipeline.description) - pipeline.icon_type = icon_type - pipeline.icon = icon - pipeline.icon_background = pipeline_data.get("icon_background", pipeline.icon_background) pipeline.updated_by = account.id else: if account.current_tenant_id is None: @@ -554,12 +555,6 @@ class RagPipelineDslService: pipeline.tenant_id = account.current_tenant_id pipeline.name = pipeline_data.get("name", "") pipeline.description = pipeline_data.get("description", "") - pipeline.icon_type = icon_type - pipeline.icon = icon - pipeline.icon_background = pipeline_data.get("icon_background", "#FFFFFF") - pipeline.enable_site = True - pipeline.enable_api = True - pipeline.use_icon_as_answer_icon = pipeline_data.get("use_icon_as_answer_icon", False) pipeline.created_by = account.id pipeline.updated_by = account.id @@ -674,26 +669,6 @@ class RagPipelineDslService: ) ] - @classmethod - def _append_model_config_export_data(cls, export_data: dict, pipeline: Pipeline) -> None: - """ - Append model config export data - :param export_data: export data - :param pipeline: Pipeline instance - """ - app_model_config = pipeline.app_model_config - if not app_model_config: - raise ValueError("Missing app configuration, please check.") - - export_data["model_config"] = app_model_config.to_dict() - dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict()) - export_data["dependencies"] = [ - jsonable_encoder(d.model_dump()) - for d in DependenciesAnalysisService.generate_dependencies( - tenant_id=pipeline.tenant_id, dependencies=dependencies - ) - ] - @classmethod def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]: """ @@ -863,3 +838,46 @@ class RagPipelineDslService: return pt.decode() except Exception: return None + + + @staticmethod + def create_rag_pipeline_dataset( + tenant_id: str, + rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, + ): + # check if dataset name already exists + if ( + db.session.query(Dataset) + .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) + .first() + ): + raise ValueError( + f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." 
+ ) + + dataset = Dataset( + name=rag_pipeline_dataset_create_entity.name, + description=rag_pipeline_dataset_create_entity.description, + permission=rag_pipeline_dataset_create_entity.permission, + provider="vendor", + runtime_mode="rag-pipeline", + icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(), + ) + with Session(db.engine) as session: + rag_pipeline_dsl_service = RagPipelineDslService(session) + account = cast(Account, current_user) + rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline( + account=account, + import_mode=ImportMode.YAML_CONTENT.value, + yaml_content=rag_pipeline_dataset_create_entity.yaml_content, + dataset=dataset, + ) + return { + "id": rag_pipeline_import_info.id, + "dataset_id": dataset.id, + "pipeline_id": rag_pipeline_import_info.pipeline_id, + "status": rag_pipeline_import_info.status, + "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version, + "current_dsl_version": rag_pipeline_import_info.current_dsl_version, + "error": rag_pipeline_import_info.error, + } diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py new file mode 100644 index 0000000000..dc266aef65 --- /dev/null +++ b/api/tasks/deal_dataset_index_update_task.py @@ -0,0 +1,171 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument + + +@shared_task(queue="dataset") +def deal_dataset_index_update_task(dataset_id: str, action: str): + """ + Async deal dataset from index + :param dataset_id: dataset_id + :param action: action + Usage: deal_dataset_index_update_task.delay(dataset_id, action) + """ + logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green")) + start_at = time.perf_counter() + + try: + dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() + + if not dataset: + raise Exception("Dataset not found") + index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX + index_processor = IndexProcessorFactory(index_type).init_index_processor() + if action == "upgrade": + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) + + if dataset_documents: + dataset_documents_ids = [doc.id for doc in dataset_documents] + db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) + db.session.commit() + + for dataset_document in dataset_documents: + try: + # add from vector index + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) + if segments: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": 
segment.dataset_id, + }, + ) + + documents.append(document) + # save vector index + # clean keywords + index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False) + index_processor.load(dataset, documents, with_keywords=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) + db.session.commit() + except Exception as e: + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) + db.session.commit() + elif action == "update": + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) + # add new index + if dataset_documents: + # update document status + dataset_documents_ids = [doc.id for doc in dataset_documents] + db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) + db.session.commit() + + # clean index + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) + + for dataset_document in dataset_documents: + # update from vector index + try: + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) + if segments: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.get_child_chunks() + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + documents.append(document) + # save vector index + index_processor.load(dataset, documents, with_keywords=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) + db.session.commit() + except Exception as e: + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) + db.session.commit() + else: + # clean collection + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) + + end_at = time.perf_counter() + logging.info( + click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Deal dataset vector index failed") + finally: + db.session.close() From 797d044714f5e8da4afe9c7bf99cc56e6495067a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 29 May 2025 09:53:42 +0800 Subject: [PATCH 048/155] r2 --- 
.../rag_pipeline/rag_pipeline_workflow.py | 13 -------- api/core/plugin/impl/datasource.py | 11 +++---- .../nodes/datasource/datasource_node.py | 18 +++++++---- api/services/dataset_service.py | 27 +++++++++-------- api/services/rag_pipeline/rag_pipeline.py | 30 +++++++++++-------- 5 files changed, 50 insertions(+), 49 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index fc6eab529a..09ff07646f 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -462,18 +462,6 @@ class PublishedRagPipelineApi(Resource): if not isinstance(current_user, Account): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.") - args = parser.parse_args() - - if not args.get("knowledge_base_setting"): - raise ValueError("Missing knowledge base setting.") - - knowledge_base_setting_data = args.get("knowledge_base_setting") - if not knowledge_base_setting_data: - raise ValueError("Missing knowledge base setting.") - - knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data) rag_pipeline_service = RagPipelineService() with Session(db.engine) as session: pipeline = session.merge(pipeline) @@ -481,7 +469,6 @@ class PublishedRagPipelineApi(Resource): session=session, pipeline=pipeline, account=current_user, - knowledge_base_setting=knowledge_base_setting, ) pipeline.is_published = True pipeline.workflow_id = workflow.id diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 7847218bb9..51d5489c4c 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -22,11 +22,12 @@ class PluginDatasourceManager(BasePluginClient): """ def transformer(json_response: dict[str, Any]) -> dict: - for provider in json_response.get("data", []): - declaration = provider.get("declaration", {}) or {} - provider_name = declaration.get("identity", {}).get("name") - for datasource in declaration.get("datasources", []): - datasource["identity"]["provider"] = provider_name + if json_response.get("data"): + for provider in json_response.get("data", []): + declaration = provider.get("declaration", {}) or {} + provider_name = declaration.get("identity", {}).get("name") + for datasource in declaration.get("datasources", []): + datasource["identity"]["provider"] = provider_name return json_response diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 8f841f9564..b44039298c 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -9,6 +9,7 @@ from core.datasource.entities.datasource_entities import ( ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.file import File +from core.file.enums import FileTransferMethod, FileType from core.plugin.impl.exc import PluginDaemonClientSideError from core.variables.segments import ArrayAnySegment, FileSegment from core.variables.variables import ArrayAnyVariable @@ -118,7 +119,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): }, ) case DatasourceProviderType.LOCAL_FILE: - upload_file = db.session.query(UploadFile).filter(UploadFile.id == 
datasource_info["related_id"]).first() + related_id = datasource_info.get("related_id") + if not related_id: + raise DatasourceNodeError( + "File is not exist" + ) + upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first() if not upload_file: raise ValueError("Invalid upload file Info") @@ -128,14 +134,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): extension="." + upload_file.extension, mime_type=upload_file.mime_type, tenant_id=self.tenant_id, - type=datasource_info.get("type", ""), - transfer_method=datasource_info.get("transfer_method", ""), + type=FileType.CUSTOM, + transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, related_id=upload_file.id, size=upload_file.size, storage_key=upload_file.key, ) - variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)]) + variable_pool.add([self.node_id, "file"], [file_info]) for key, value in datasource_info.items(): # construct new key list new_key_list = ["file", key] @@ -147,7 +153,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): inputs=parameters_for_log, metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, outputs={ - "file_info": file_info, + "file_info": datasource_info, "datasource_type": datasource_type, }, ) @@ -220,7 +226,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] - + def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue): """ diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 6d3891799c..7621784d37 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -53,6 +53,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( ) from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeBaseUpdateConfiguration, + KnowledgeConfiguration, RagPipelineDatasetCreateEntity, ) from services.errors.account import InvalidActionError, NoPermissionError @@ -495,11 +496,11 @@ class DatasetService: @staticmethod def update_rag_pipeline_dataset_settings(session: Session, dataset: Dataset, - knowledge_base_setting: KnowledgeBaseUpdateConfiguration, + knowledge_configuration: KnowledgeConfiguration, has_published: bool = False): if not has_published: - dataset.chunk_structure = knowledge_base_setting.chunk_structure - index_method = knowledge_base_setting.index_method + dataset.chunk_structure = knowledge_configuration.chunk_structure + index_method = knowledge_configuration.index_method dataset.indexing_technique = index_method.indexing_technique if index_method == "high_quality": model_manager = ModelManager() @@ -519,26 +520,26 @@ class DatasetService: dataset.keyword_number = index_method.economy_setting.keyword_number else: raise ValueError("Invalid index method") - dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump() + dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() session.add(dataset) else: - if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure: + if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure: raise ValueError("Chunk structure is not allowed to be updated.") action = None - if dataset.indexing_technique != 
knowledge_base_setting.index_method.indexing_technique: + if dataset.indexing_technique != knowledge_configuration.index_method.indexing_technique: # if update indexing_technique - if knowledge_base_setting.index_method.indexing_technique == "economy": + if knowledge_configuration.index_method.indexing_technique == "economy": raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") - elif knowledge_base_setting.index_method.indexing_technique == "high_quality": + elif knowledge_configuration.index_method.indexing_technique == "high_quality": action = "add" # get embedding model setting try: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name, + provider=knowledge_configuration.index_method.embedding_setting.embedding_provider_name, model_type=ModelType.TEXT_EMBEDDING, - model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name, + model=knowledge_configuration.index_method.embedding_setting.embedding_model_name, ) dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider @@ -607,9 +608,9 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) elif dataset.indexing_technique == "economy": - if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number: - dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number - dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump() + if dataset.keyword_number != knowledge_configuration.index_method.economy_setting.keyword_number: + dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number + dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() session.add(dataset) session.commit() if action: diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 9e7a1d7fe2..79e793118a 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -47,7 +47,7 @@ from models.workflow import ( WorkflowType, ) from services.dataset_service import DatasetService -from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity +from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, PipelineTemplateInfoEntity from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory @@ -262,7 +262,6 @@ class RagPipelineService: session: Session, pipeline: Pipeline, account: Account, - knowledge_base_setting: KnowledgeBaseUpdateConfiguration, ) -> Workflow: draft_workflow_stmt = select(Workflow).where( Workflow.tenant_id == pipeline.tenant_id, @@ -291,16 +290,23 @@ class RagPipelineService: # commit db session changes session.add(workflow) - # update dataset - dataset = pipeline.dataset - if not dataset: - raise ValueError("Dataset not found") - DatasetService.update_rag_pipeline_dataset_settings( - session=session, - dataset=dataset, - knowledge_base_setting=knowledge_base_setting, - has_published=pipeline.is_published - ) + graph = workflow.graph_dict + nodes = graph.get("nodes", []) + for node 
in nodes: + if node.get("data", {}).get("type") == "knowledge_index": + knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) + knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + + # update dataset + dataset = pipeline.dataset + if not dataset: + raise ValueError("Dataset not found") + DatasetService.update_rag_pipeline_dataset_settings( + session=session, + dataset=dataset, + knowledge_configuration=knowledge_configuration, + has_published=pipeline.is_published + ) # return new workflow return workflow From e7c48c0b6963cc190e19f21a92ae558c4da305ed Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 29 May 2025 23:04:04 +0800 Subject: [PATCH 049/155] r2 --- .../datasets/rag_pipeline/rag_pipeline.py | 44 ++++- .../app/apps/pipeline/pipeline_generator.py | 154 +++++++++------- api/core/entities/knowledge_entities.py | 23 +++ ...hemy_workflow_node_execution_repository.py | 6 +- api/fields/dataset_fields.py | 1 + api/models/dataset.py | 18 +- .../built_in/built_in_retrieval.py | 10 +- .../customized/customized_retrieval.py | 36 ++-- .../database/database_retrieval.py | 33 +--- .../pipeline_template_factory.py | 3 +- api/services/rag_pipeline/rag_pipeline.py | 44 +++-- .../rag_pipeline/rag_pipeline_dsl_service.py | 169 ++++++++++-------- 12 files changed, 339 insertions(+), 202 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 8c5f91cb7f..471ecbf070 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -1,5 +1,6 @@ import logging +import yaml from flask import request from flask_restful import Resource, reqparse from sqlalchemy.orm import Session @@ -12,10 +13,9 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from libs.login import login_required -from models.dataset import Pipeline, PipelineCustomizedTemplate +from models.dataset import PipelineCustomizedTemplate from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.rag_pipeline.rag_pipeline import RagPipelineService -from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService logger = logging.getLogger(__name__) @@ -84,8 +84,8 @@ class CustomizedPipelineTemplateApi(Resource): ) args = parser.parse_args() pipeline_template_info = PipelineTemplateInfoEntity(**args) - pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) - return pipeline_template, 200 + RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) + return 200 @setup_required @login_required @@ -106,13 +106,41 @@ class CustomizedPipelineTemplateApi(Resource): ) if not template: raise ValueError("Customized pipeline template not found.") - pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() - if not pipeline: - raise ValueError("Pipeline not found.") - dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True) + dsl = yaml.safe_load(template.yaml_content) return {"data": dsl}, 200 +class CustomizedPipelineTemplateApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def post(self, pipeline_id: str): + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + 
required=True, + help="Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) + parser.add_argument( + "icon_info", + type=dict, + location="json", + nullable=True, + ) + args = parser.parse_args() + rag_pipeline_service = RagPipelineService() + RagPipelineService.publish_customized_pipeline_template(pipeline_id, args) + return 200 api.add_resource( PipelineTemplateListApi, diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index e4c96775c8..b7e20cfd10 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -20,11 +20,11 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager from core.app.apps.pipeline.pipeline_runner import PipelineRunner -from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline -from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity, WorkflowAppGenerateEntity +from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.entities.knowledge_entities import PipelineDataset, PipelineDocument from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.index_processor.constant.built_in_field import BuiltInField from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository @@ -32,6 +32,7 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db +from fields.document_fields import dataset_and_document_fields from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, Pipeline from models.enums import WorkflowRunTriggeredFrom @@ -54,7 +55,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: Literal[True], call_depth: int, workflow_thread_pool_id: Optional[str], - ) -> Generator[Mapping | str, None, None] | None: ... + ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ... 
@overload def generate( @@ -101,23 +102,18 @@ class PipelineGenerator(BaseAppGenerator): pipeline=pipeline, workflow=workflow, ) - + # Add null check for dataset + dataset = pipeline.dataset + if not dataset: + raise ValueError("Pipeline dataset is required") inputs: Mapping[str, Any] = args["inputs"] start_node_id: str = args["start_node_id"] datasource_type: str = args["datasource_type"] datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) - - for datasource_info in datasource_info_list: - workflow_run_id = str(uuid.uuid4()) - document_id = None - - # Add null check for dataset - dataset = pipeline.dataset - if not dataset: - raise ValueError("Pipeline dataset is required") - - if invoke_from == InvokeFrom.PUBLISHED: + documents = [] + if invoke_from == InvokeFrom.PUBLISHED: + for datasource_info in datasource_info_list: position = DocumentService.get_documents_position(dataset.id) document = self._build_document( tenant_id=pipeline.tenant_id, @@ -132,9 +128,15 @@ class PipelineGenerator(BaseAppGenerator): document_form=dataset.chunk_structure, ) db.session.add(document) - db.session.commit() - document_id = document.id - # init application generate entity + documents.append(document) + db.session.commit() + + # run in child thread + for i, datasource_info in enumerate(datasource_info_list): + workflow_run_id = str(uuid.uuid4()) + document_id = None + if invoke_from == InvokeFrom.PUBLISHED: + document_id = documents[i].id application_generate_entity = RagPipelineGenerateEntity( task_id=str(uuid.uuid4()), app_config=pipeline_config, @@ -159,7 +161,6 @@ class PipelineGenerator(BaseAppGenerator): workflow_run_id=workflow_run_id, ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) if invoke_from == InvokeFrom.DEBUGGER: @@ -183,6 +184,7 @@ class PipelineGenerator(BaseAppGenerator): ) if invoke_from == InvokeFrom.DEBUGGER: return self._generate( + flask_app=current_app._get_current_object(),# type: ignore pipeline=pipeline, workflow=workflow, user=user, @@ -194,21 +196,47 @@ class PipelineGenerator(BaseAppGenerator): workflow_thread_pool_id=workflow_thread_pool_id, ) else: - self._generate( - pipeline=pipeline, - workflow=workflow, - user=user, - application_generate_entity=application_generate_entity, - invoke_from=invoke_from, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - streaming=streaming, - workflow_thread_pool_id=workflow_thread_pool_id, + # run in child thread + thread = threading.Thread( + target=self._generate, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "pipeline": pipeline, + "workflow": workflow, + "user": user, + "application_generate_entity": application_generate_entity, + "invoke_from": invoke_from, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "streaming": streaming, + "workflow_thread_pool_id": workflow_thread_pool_id, + }, ) - + thread.start() + # return batch, dataset, documents + return { + "batch": batch, + "dataset": PipelineDataset( + id=dataset.id, + name=dataset.name, + description=dataset.description, + chunk_structure=dataset.chunk_structure, + ).model_dump(), + "documents": [PipelineDocument( + id=document.id, + 
position=document.position, + data_source_info=document.data_source_info, + name=document.name, + indexing_status=document.indexing_status, + error=document.error, + enabled=document.enabled, + ).model_dump() for document in documents + ] + } def _generate( self, *, + flask_app: Flask, pipeline: Pipeline, workflow: Workflow, user: Union[Account, EndUser], @@ -232,40 +260,42 @@ class PipelineGenerator(BaseAppGenerator): :param streaming: is stream :param workflow_thread_pool_id: workflow thread pool id """ - # init queue manager - queue_manager = PipelineQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - app_mode=AppMode.RAG_PIPELINE, - ) + print(user.id) + with flask_app.app_context(): + # init queue manager + queue_manager = PipelineQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + app_mode=AppMode.RAG_PIPELINE, + ) - # new thread - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "context": contextvars.copy_context(), - "workflow_thread_pool_id": workflow_thread_pool_id, - }, - ) + # new thread + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": contextvars.copy_context(), + "workflow_thread_pool_id": workflow_thread_pool_id, + }, + ) - worker_thread.start() + worker_thread.start() - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - user=user, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - stream=streaming, - ) + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + stream=streaming, + ) - return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def single_iteration_generate( self, @@ -317,7 +347,6 @@ class PipelineGenerator(BaseAppGenerator): call_depth=0, workflow_run_id=str(uuid.uuid4()), ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) # Create workflow node execution repository @@ -338,6 +367,7 @@ class PipelineGenerator(BaseAppGenerator): ) return self._generate( + flask_app=current_app._get_current_object(),# type: ignore pipeline=pipeline, workflow=workflow, user=user, @@ -399,7 +429,6 @@ class PipelineGenerator(BaseAppGenerator): single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), workflow_run_id=str(uuid.uuid4()), ) - 
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -421,6 +450,7 @@ class PipelineGenerator(BaseAppGenerator): ) return self._generate( + flask_app=current_app._get_current_object(),# type: ignore pipeline=pipeline, workflow=workflow, user=user, diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index 90c9879733..f876c06b06 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -17,3 +17,26 @@ class IndexingEstimate(BaseModel): total_segments: int preview: list[PreviewDetail] qa_preview: Optional[list[QAPreviewDetail]] = None + + +class PipelineDataset(BaseModel): + id: str + name: str + description: str + chunk_structure: str + +class PipelineDocument(BaseModel): + id: str + position: int + data_source_info: dict + name: str + indexing_status: str + error: str + enabled: bool + + + +class PipelineGenerateResponse(BaseModel): + batch: str + dataset: PipelineDataset + documents: list[PipelineDocument] diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 8d916a19db..cbc55474c6 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -253,6 +253,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) self, workflow_run_id: str, order_config: Optional[OrderConfig] = None, + triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all WorkflowNodeExecution database models for a specific workflow run. @@ -274,7 +275,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) stmt = select(WorkflowNodeExecution).where( WorkflowNodeExecution.workflow_run_id == workflow_run_id, WorkflowNodeExecution.tenant_id == self._tenant_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + WorkflowNodeExecution.triggered_from == triggered_from, ) if self._app_id: @@ -308,6 +309,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) self, workflow_run_id: str, order_config: Optional[OrderConfig] = None, + triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[NodeExecution]: """ Retrieve all NodeExecution instances for a specific workflow run. 
@@ -325,7 +327,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) A list of NodeExecution instances """ # Get the database models using the new method - db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config) + db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from) # Convert database models to domain models domain_models = [] diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 8d675e56fa..2871b3ec16 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -87,6 +87,7 @@ dataset_detail_fields = { "runtime_mode": fields.String, "chunk_structure": fields.String, "icon_info": fields.Nested(icon_info_fields), + "is_published": fields.Boolean, } dataset_query_detail_fields = { diff --git a/api/models/dataset.py b/api/models/dataset.py index 22703771d5..86216ffe98 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -152,6 +152,8 @@ class Dataset(Base): @property def doc_form(self): + if self.chunk_structure: + return self.chunk_structure document = db.session.query(Document).filter(Document.dataset_id == self.id).first() if document: return document.doc_form @@ -206,6 +208,13 @@ class Dataset(Base): "external_knowledge_api_name": external_knowledge_api.name, "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), } + @property + def is_published(self): + if self.pipeline_id: + pipeline = db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first() + if pipeline: + return pipeline.is_published + return False @property def doc_metadata(self): @@ -1154,10 +1163,11 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - pipeline_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False) + chunk_structure = db.Column(db.String(255), nullable=False) icon = db.Column(db.JSON, nullable=False) + yaml_content = db.Column(db.Text, nullable=False) copyright = db.Column(db.String(255), nullable=False) privacy_policy = db.Column(db.String(255), nullable=False) position = db.Column(db.Integer, nullable=False) @@ -1166,9 +1176,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - @property - def pipeline(self): - return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first() class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] @@ -1180,11 +1187,12 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) - pipeline_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False) + chunk_structure = db.Column(db.String(255), nullable=False) icon = db.Column(db.JSON, nullable=False) position = db.Column(db.Integer, nullable=False) + yaml_content = db.Column(db.Text, nullable=False) install_count = db.Column(db.Integer, nullable=False, default=0) language = db.Column(db.String(255), nullable=False) 
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py index 70c72014f2..b0fa54115c 100644 --- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -23,8 +23,8 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): result = self.fetch_pipeline_templates_from_builtin(language) return result - def get_pipeline_template_detail(self, pipeline_id: str): - result = self.fetch_pipeline_template_detail_from_builtin(pipeline_id) + def get_pipeline_template_detail(self, template_id: str): + result = self.fetch_pipeline_template_detail_from_builtin(template_id) return result @classmethod @@ -54,11 +54,11 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return builtin_data.get("pipeline_templates", {}).get(language, {}) @classmethod - def fetch_pipeline_template_detail_from_builtin(cls, pipeline_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> Optional[dict]: """ Fetch pipeline template detail from builtin. - :param pipeline_id: Pipeline ID + :param template_id: Template ID :return: """ builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() - return builtin_data.get("pipeline_templates", {}).get(pipeline_id) + return builtin_data.get("pipeline_templates", {}).get(template_id) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index de69373ba4..b6670b70cd 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,12 +1,13 @@ from typing import Optional from flask_login import current_user +import yaml from extensions.ext_database import db -from models.dataset import Pipeline, PipelineCustomizedTemplate -from services.app_dsl_service import AppDslService +from models.dataset import PipelineCustomizedTemplate from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): @@ -35,13 +36,26 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - pipeline_templates = ( + pipeline_customized_templates = ( db.session.query(PipelineCustomizedTemplate) .filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .all() ) + recommended_pipelines_results = [] + for pipeline_customized_template in pipeline_customized_templates: + + recommended_pipeline_result = { + "id": pipeline_customized_template.id, + "name": pipeline_customized_template.name, + "description": pipeline_customized_template.description, + "icon": pipeline_customized_template.icon, + "position": pipeline_customized_template.position, + "chunk_structure": pipeline_customized_template.chunk_structure, + } + recommended_pipelines_results.append(recommended_pipeline_result) + + return 
{"pipeline_templates": recommended_pipelines_results} - return {"pipeline_templates": pipeline_templates} @classmethod def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: @@ -57,15 +71,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): if not pipeline_template: return None - # get pipeline detail - pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first() - if not pipeline or not pipeline.is_public: - return None - return { - "id": pipeline.id, - "name": pipeline.name, - "icon": pipeline.icon, - "mode": pipeline.mode, - "export_data": AppDslService.export_dsl(app_model=pipeline), + "id": pipeline_template.id, + "name": pipeline_template.name, + "icon": pipeline_template.icon, + "export_data": yaml.safe_load(pipeline_template.yaml_content), } diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 9ea3cc678b..8019dac0a8 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,7 +1,9 @@ from typing import Optional +import yaml + from extensions.ext_database import db -from models.dataset import Dataset, Pipeline, PipelineBuiltInTemplate +from models.dataset import PipelineBuiltInTemplate from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType @@ -36,24 +38,18 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): recommended_pipelines_results = [] for pipeline_built_in_template in pipeline_built_in_templates: - pipeline_model: Pipeline | None = pipeline_built_in_template.pipeline - if not pipeline_model: - continue recommended_pipeline_result = { "id": pipeline_built_in_template.id, "name": pipeline_built_in_template.name, - "pipeline_id": pipeline_model.id, "description": pipeline_built_in_template.description, "icon": pipeline_built_in_template.icon, "copyright": pipeline_built_in_template.copyright, "privacy_policy": pipeline_built_in_template.privacy_policy, "position": pipeline_built_in_template.position, + "chunk_structure": pipeline_built_in_template.chunk_structure, } - dataset: Dataset | None = pipeline_model.dataset - if dataset: - recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure - recommended_pipelines_results.append(recommended_pipeline_result) + recommended_pipelines_results.append(recommended_pipeline_result) return {"pipeline_templates": recommended_pipelines_results} @@ -64,8 +60,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param pipeline_id: Pipeline ID :return: """ - from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService - # is in public recommended list pipeline_template = ( db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first() @@ -74,19 +68,10 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): if not pipeline_template: return None - # get pipeline detail - pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first() - if not pipeline or not pipeline.is_public: - return None - - dataset: Dataset | None = pipeline.dataset - if not dataset: - return None - return { - "id": pipeline.id, - "name": 
pipeline.name, + "id": pipeline_template.id, + "name": pipeline_template.name, "icon": pipeline_template.icon, - "chunk_structure": dataset.chunk_structure, - "export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline), + "chunk_structure": pipeline_template.chunk_structure, + "export_data": yaml.safe_load(pipeline_template.yaml_content), } diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py index aa8a6298d7..7b87ffe75b 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py @@ -1,4 +1,5 @@ from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval +from services.rag_pipeline.pipeline_template.customized.customized_retrieval import CustomizedPipelineTemplateRetrieval from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType @@ -12,7 +13,7 @@ class PipelineTemplateRetrievalFactory: case PipelineTemplateType.REMOTE: return RemotePipelineTemplateRetrieval case PipelineTemplateType.CUSTOMIZED: - return DatabasePipelineTemplateRetrieval + return CustomizedPipelineTemplateRetrieval case PipelineTemplateType.DATABASE: return DatabasePipelineTemplateRetrieval case PipelineTemplateType.BUILTIN: diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 79e793118a..b3c32a7c78 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -7,7 +7,7 @@ from typing import Any, Optional, cast from uuid import uuid4 from flask_login import current_user -from sqlalchemy import select +from sqlalchemy import or_, select from sqlalchemy.orm import Session import contexts @@ -47,16 +47,19 @@ from models.workflow import ( WorkflowType, ) from services.dataset_service import DatasetService -from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, PipelineTemplateInfoEntity +from services.entities.knowledge_entities.rag_pipeline_entities import ( + KnowledgeConfiguration, + PipelineTemplateInfoEntity, +) from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory class RagPipelineService: - @staticmethod + @classmethod def get_pipeline_templates( - type: str = "built-in", language: str = "en-US" - ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]: + cls, type: str = "built-in", language: str = "en-US" + ) -> dict: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() @@ -64,14 +67,14 @@ class RagPipelineService: if not result.get("pipeline_templates") and language != "en-US": template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval() result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US") - return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])] + return result else: mode = 
"customized" retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() result = retrieval_instance.get_pipeline_templates(language) - return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])] + return result - @classmethod + @classmethod def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]: """ Get pipeline template detail. @@ -684,7 +687,10 @@ class RagPipelineService: base_query = db.session.query(WorkflowRun).filter( WorkflowRun.tenant_id == pipeline.tenant_id, WorkflowRun.app_id == pipeline.id, - WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, + or_( + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value, + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value + ) ) if args.get("last_id"): @@ -765,8 +771,26 @@ class RagPipelineService: # Use the repository to get the node executions with ordering order_config = OrderConfig(order_by=["index"], order_direction="desc") - node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) + node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, + order_config=order_config, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN) # Convert domain models to database models workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions] return workflow_node_executions + + @classmethod + def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict): + """ + Publish customized pipeline template + """ + pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + if not pipeline: + raise ValueError("Pipeline not found") + if not pipeline.workflow_id: + raise ValueError("Pipeline workflow not found") + workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + if not workflow: + raise ValueError("Workflow not found") + + db.session.commit() \ No newline at end of file diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index c6751825cc..57e81e6f75 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -1,5 +1,7 @@ import base64 +from datetime import UTC, datetime import hashlib +import json import logging import uuid from collections.abc import Mapping @@ -31,13 +33,12 @@ from extensions.ext_redis import redis_client from factories import variable_factory from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline -from models.workflow import Workflow +from models.workflow import Workflow, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, RagPipelineDatasetCreateEntity, ) from services.plugin.dependencies_analysis import DependenciesAnalysisService -from services.rag_pipeline.rag_pipeline import RagPipelineService logger = logging.getLogger(__name__) @@ -206,12 +207,12 @@ class RagPipelineDslService: status = _check_version_compatibility(imported_version) # Extract app data - pipeline_data = data.get("pipeline") + pipeline_data = data.get("rag_pipeline") if not pipeline_data: return RagPipelineImportInfo( id=import_id, status=ImportStatus.FAILED, - error="Missing pipeline data in YAML content", + error="Missing rag_pipeline data in YAML content", ) # If app_id is 
provided, check if it exists @@ -256,7 +257,7 @@ class RagPipelineDslService: if dependencies: check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies] - # Create or update app + # Create or update pipeline pipeline = self._create_or_update_pipeline( pipeline=pipeline, data=data, @@ -278,7 +279,9 @@ class RagPipelineDslService: if node.get("data", {}).get("type") == "knowledge_index": knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) - if not dataset: + if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure: + raise ValueError("Chunk structure is not compatible with the published pipeline") + else: dataset = Dataset( tenant_id=account.current_tenant_id, name=name, @@ -295,11 +298,6 @@ class RagPipelineDslService: runtime_mode="rag_pipeline", chunk_structure=knowledge_configuration.chunk_structure, ) - else: - dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique - dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() - dataset.runtime_mode = "rag_pipeline" - dataset.chunk_structure = knowledge_configuration.chunk_structure if knowledge_configuration.index_method.indexing_technique == "high_quality": dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) @@ -540,11 +538,45 @@ class RagPipelineDslService: icon_type = "emoji" icon = str(pipeline_data.get("icon", "")) + + # Initialize pipeline based on mode + workflow_data = data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise ValueError("Missing workflow data for rag pipeline") + + environment_variables_list = workflow_data.get("environment_variables", []) + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables", []) + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) + + + graph = workflow_data.get("graph", {}) + for node in graph.get("nodes", []): + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + decrypted_id + for dataset_id in dataset_ids + if ( + decrypted_id := self.decrypt_dataset_id( + encrypted_data=dataset_id, + tenant_id=account.current_tenant_id, + ) + ) + ] + if pipeline: # Update existing pipeline pipeline.name = pipeline_data.get("name", pipeline.name) pipeline.description = pipeline_data.get("description", pipeline.description) pipeline.updated_by = account.id + + else: if account.current_tenant_id is None: raise ValueError("Current tenant is not set") @@ -567,52 +599,44 @@ class RagPipelineDslService: IMPORT_INFO_REDIS_EXPIRY, CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(), ) - - # Initialize pipeline based on mode - workflow_data = data.get("workflow") - if not workflow_data or not isinstance(workflow_data, dict): - raise ValueError("Missing workflow data for rag pipeline") - - environment_variables_list = workflow_data.get("environment_variables", []) - environment_variables = [ - 
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list - ] - conversation_variables_list = workflow_data.get("conversation_variables", []) - conversation_variables = [ - variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list - ] - rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) - - rag_pipeline_service = RagPipelineService() - current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) - if current_draft_workflow: - unique_hash = current_draft_workflow.unique_hash - else: - unique_hash = None - graph = workflow_data.get("graph", {}) - for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: - dataset_ids = node["data"].get("dataset_ids", []) - node["data"]["dataset_ids"] = [ - decrypted_id - for dataset_id in dataset_ids - if ( - decrypted_id := self.decrypt_dataset_id( - encrypted_data=dataset_id, - tenant_id=pipeline.tenant_id, - ) - ) - ] - rag_pipeline_service.sync_draft_workflow( - pipeline=pipeline, - graph=workflow_data.get("graph", {}), - unique_hash=unique_hash, - account=account, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - rag_pipeline_variables=rag_pipeline_variables_list, + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == "draft", + ) + .first() ) + # create draft workflow if not found + if not workflow: + workflow = Workflow( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + features="{}", + type=WorkflowType.RAG_PIPELINE.value, + version="draft", + graph=json.dumps(graph), + created_by=account.id, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + rag_pipeline_variables=rag_pipeline_variables_list, + ) + db.session.add(workflow) + db.session.flush() + pipeline.workflow_id = workflow.id + else: + workflow.graph = json.dumps(graph) + workflow.updated_by = account.id + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.environment_variables = environment_variables + workflow.conversation_variables = conversation_variables + workflow.rag_pipeline_variables = rag_pipeline_variables_list + # commit db session changes + db.session.commit() + + return pipeline @classmethod @@ -623,16 +647,19 @@ class RagPipelineDslService: :param include_secret: Whether include secret variable :return: """ + dataset = pipeline.dataset + if not dataset: + raise ValueError("Missing dataset for rag pipeline") + icon_info = dataset.icon_info export_data = { "version": CURRENT_DSL_VERSION, "kind": "rag_pipeline", "pipeline": { "name": pipeline.name, - "mode": pipeline.mode, - "icon": "🤖" if pipeline.icon_type == "image" else pipeline.icon, - "icon_background": "#FFEAD5" if pipeline.icon_type == "image" else pipeline.icon_background, + "icon": icon_info.get("icon", "📙") if icon_info else "📙", + "icon_type": icon_info.get("icon_type", "emoji") if icon_info else "emoji", + "icon_background": icon_info.get("icon_background", "#FFEAD5") if icon_info else "#FFEAD5", "description": pipeline.description, - "use_icon_as_answer_icon": pipeline.use_icon_as_answer_icon, }, } @@ -647,8 +674,16 @@ class RagPipelineDslService: :param export_data: export data :param pipeline: Pipeline instance """ - rag_pipeline_service = RagPipelineService() - workflow = 
rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == "draft", + ) + .first() + ) if not workflow: raise ValueError("Missing draft workflow configuration, please check.") @@ -855,14 +890,6 @@ class RagPipelineDslService: f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." ) - dataset = Dataset( - name=rag_pipeline_dataset_create_entity.name, - description=rag_pipeline_dataset_create_entity.description, - permission=rag_pipeline_dataset_create_entity.permission, - provider="vendor", - runtime_mode="rag-pipeline", - icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(), - ) with Session(db.engine) as session: rag_pipeline_dsl_service = RagPipelineDslService(session) account = cast(Account, current_user) @@ -870,11 +897,11 @@ class RagPipelineDslService: account=account, import_mode=ImportMode.YAML_CONTENT.value, yaml_content=rag_pipeline_dataset_create_entity.yaml_content, - dataset=dataset, + dataset=None, ) return { "id": rag_pipeline_import_info.id, - "dataset_id": dataset.id, + "dataset_id": rag_pipeline_import_info.dataset_id, "pipeline_id": rag_pipeline_import_info.pipeline_id, "status": rag_pipeline_import_info.status, "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version, From cb5cfb2daefb0dddced71eefe17d40f1108704dd Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 30 May 2025 00:03:43 +0800 Subject: [PATCH 050/155] r2 --- .../app/apps/pipeline/pipeline_generator.py | 96 +++++++++++++------ api/core/entities/knowledge_entities.py | 6 +- 2 files changed, 70 insertions(+), 32 deletions(-) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index b7e20cfd10..19ded1696a 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -9,7 +9,7 @@ import uuid from collections.abc import Generator, Mapping from typing import Any, Literal, Optional, Union, overload -from flask import Flask, current_app +from flask import Flask, copy_current_request_context, current_app, has_request_context from pydantic import ValidationError from sqlalchemy.orm import sessionmaker @@ -185,8 +185,9 @@ class PipelineGenerator(BaseAppGenerator): if invoke_from == InvokeFrom.DEBUGGER: return self._generate( flask_app=current_app._get_current_object(),# type: ignore + context=contextvars.copy_context(), pipeline=pipeline, - workflow=workflow, + workflow_id=workflow.id, user=user, application_generate_entity=application_generate_entity, invoke_from=invoke_from, @@ -197,22 +198,28 @@ class PipelineGenerator(BaseAppGenerator): ) else: # run in child thread - thread = threading.Thread( - target=self._generate, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "pipeline": pipeline, - "workflow": workflow, - "user": user, - "application_generate_entity": application_generate_entity, - "invoke_from": invoke_from, - "workflow_execution_repository": workflow_execution_repository, - "workflow_node_execution_repository": workflow_node_execution_repository, - "streaming": streaming, - "workflow_thread_pool_id": workflow_thread_pool_id, - }, - ) - thread.start() + context = contextvars.copy_context() + @copy_current_request_context + def worker_with_context(): + # Run the worker within the copied context + return context.run( + self._generate, + 
flask_app=current_app._get_current_object(), # type: ignore + context=context, + pipeline=pipeline, + workflow_id=workflow.id, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + + worker_thread = threading.Thread(target=worker_with_context) + + worker_thread.start() # return batch, dataset, documents return { "batch": batch, @@ -225,7 +232,7 @@ class PipelineGenerator(BaseAppGenerator): "documents": [PipelineDocument( id=document.id, position=document.position, - data_source_info=document.data_source_info, + data_source_info=json.loads(document.data_source_info) if document.data_source_info else None, name=document.name, indexing_status=document.indexing_status, error=document.error, @@ -237,8 +244,9 @@ class PipelineGenerator(BaseAppGenerator): self, *, flask_app: Flask, + context: contextvars.Context, pipeline: Pipeline, - workflow: Workflow, + workflow_id: str, user: Union[Account, EndUser], application_generate_entity: RagPipelineGenerateEntity, invoke_from: InvokeFrom, @@ -260,26 +268,47 @@ class PipelineGenerator(BaseAppGenerator): :param streaming: is stream :param workflow_thread_pool_id: workflow thread pool id """ - print(user.id) + for var, val in context.items(): + var.set(val) + + # FIXME(-LAN-): Save current user before entering new app context + from flask import g + + saved_user = None + if has_request_context() and hasattr(g, "_login_user"): + saved_user = g._login_user with flask_app.app_context(): + # Restore user in new app context + if saved_user is not None: + from flask import g + + g._login_user = saved_user # init queue manager + workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first() + if not workflow: + raise ValueError(f"Workflow not found: {workflow_id}") queue_manager = PipelineQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, app_mode=AppMode.RAG_PIPELINE, ) + context = contextvars.copy_context() + @copy_current_request_context + def worker_with_context(): + # Run the worker within the copied context + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + context=context, + queue_manager=queue_manager, + application_generate_entity=application_generate_entity, + workflow_thread_pool_id=workflow_thread_pool_id, + ) # new thread worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "context": contextvars.copy_context(), - "workflow_thread_pool_id": workflow_thread_pool_id, - }, + target=worker_with_context ) worker_thread.start() @@ -479,8 +508,17 @@ class PipelineGenerator(BaseAppGenerator): """ for var, val in context.items(): var.set(val) + from flask import g + + saved_user = None + if has_request_context() and hasattr(g, "_login_user"): + saved_user = g._login_user with flask_app.app_context(): try: + if saved_user is not None: + from flask import g + + g._login_user = saved_user # workflow app runner = PipelineRunner( application_generate_entity=application_generate_entity, diff --git a/api/core/entities/knowledge_entities.py 
b/api/core/entities/knowledge_entities.py index f876c06b06..3beea56e15 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -25,17 +25,17 @@ class PipelineDataset(BaseModel): description: str chunk_structure: str + class PipelineDocument(BaseModel): id: str position: int - data_source_info: dict + data_source_info: Optional[dict] = None name: str indexing_status: str - error: str + error: Optional[str] = None enabled: bool - class PipelineGenerateResponse(BaseModel): batch: str dataset: PipelineDataset From 69529fb16dcf68a6f0e55967bf732e27cfe2af52 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 30 May 2025 00:37:27 +0800 Subject: [PATCH 051/155] r2 --- .../datasets/rag_pipeline/datasource_auth.py | 12 +++++++ api/core/plugin/entities/plugin_daemon.py | 1 + api/models/__init__.py | 3 ++ api/models/oauth.py | 1 + api/services/datasource_provider_service.py | 36 +++++++++++-------- .../rag_pipeline_manage_service.py | 11 +++++- 6 files changed, 49 insertions(+), 15 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 8894babcf7..ceb7a277e4 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -108,6 +108,18 @@ class DatasourceAuth(Resource): raise ValueError(str(ex)) return {"result": "success"}, 201 + + @setup_required + @login_required + @account_initialization_required + def get(self, provider, plugin_id): + datasource_provider_service = DatasourceProviderService() + datasources = datasource_provider_service.get_datasource_credentials( + tenant_id=current_user.current_tenant_id, + provider=provider, + plugin_id=plugin_id + ) + return {"result": datasources}, 200 class DatasourceAuthDeleteApi(Resource): @setup_required diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 6644706757..cc7dfb58ab 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -52,6 +52,7 @@ class PluginDatasourceProviderEntity(BaseModel): provider: str plugin_unique_identifier: str plugin_id: str + is_authorized: bool = False declaration: DatasourceProviderEntityWithPlugin diff --git a/api/models/__init__.py b/api/models/__init__.py index f652449e98..63fe2747ef 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -56,6 +56,7 @@ from .model import ( TraceAppConfig, UploadFile, ) +from .oauth import DatasourceOauthParamConfig, DatasourceProvider from .provider import ( LoadBalancingModelConfig, Provider, @@ -123,6 +124,8 @@ __all__ = [ "DatasetProcessRule", "DatasetQuery", "DatasetRetrieverResource", + "DatasourceOauthParamConfig", + "DatasourceProvider", "DifySetup", "Document", "DocumentSegment", diff --git a/api/models/oauth.py b/api/models/oauth.py index fefe743195..d662a4b50c 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -28,6 +28,7 @@ class DatasourceProvider(Base): db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) diff --git 
a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 54abc64547..ef9a56a66e 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -92,7 +92,7 @@ class DatasourceProviderService: return secret_input_form_variables - def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> Optional[dict]: + def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]: """ get datasource credentials. @@ -102,22 +102,30 @@ class DatasourceProviderService: :return: """ # Get all provider configurations of the current workspace - datasource_provider: DatasourceProvider | None = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, + datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, provider=provider, - plugin_id=plugin_id).first() - if not datasource_provider: - return None - encrypted_credentials = datasource_provider.encrypted_credentials - # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider) + plugin_id=plugin_id).all() + if not datasource_providers: + return [] + copy_credentials_list = [] + for datasource_provider in datasource_providers: + encrypted_credentials = datasource_provider.encrypted_credentials + # Get provider credential secret variables + credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider) - # Obfuscate provider credentials - copy_credentials = encrypted_credentials.copy() - for key, value in copy_credentials.items(): - if key in credential_secret_variables: - copy_credentials[key] = encrypter.obfuscated_token(value) + # Obfuscate provider credentials + copy_credentials = encrypted_credentials.copy() + for key, value in copy_credentials.items(): + if key in credential_secret_variables: + copy_credentials[key] = encrypter.obfuscated_token(value) + copy_credentials_list.append( + { + "credentials": copy_credentials, + "type": datasource_provider.auth_type, + } + ) - return copy_credentials + return copy_credentials_list def remove_datasource_credentials(self, diff --git a/api/services/rag_pipeline/rag_pipeline_manage_service.py b/api/services/rag_pipeline/rag_pipeline_manage_service.py index 4d8d69f913..df6085fafa 100644 --- a/api/services/rag_pipeline/rag_pipeline_manage_service.py +++ b/api/services/rag_pipeline/rag_pipeline_manage_service.py @@ -1,5 +1,6 @@ from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity from core.plugin.impl.datasource import PluginDatasourceManager +from services.datasource_provider_service import DatasourceProviderService class RagPipelineManageService: @@ -11,4 +12,12 @@ class RagPipelineManageService: # get all builtin providers manager = PluginDatasourceManager() - return manager.fetch_datasource_providers(tenant_id) + datasources = manager.fetch_datasource_providers(tenant_id) + for datasource in datasources: + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id, + provider=datasource.provider, + plugin_id=datasource.plugin_id) + if credentials: + datasource.is_authorized = True + return datasources From 804e55824d46324d12e87af1efbf8295a70bc5f4 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 30 May 2025 00:37:36 +0800 Subject: [PATCH 052/155] r2 
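The provider listing just above only flags a datasource as authorized when stored credentials exist, and every credential set is obfuscated before it is returned to the console. A minimal, self-contained sketch of that masking step — the helper and field names here are illustrative, not the project's API:

def obfuscate_credentials(credentials: dict[str, str], secret_fields: set[str]) -> dict[str, str]:
    """Return a copy of the credentials with secret values masked for display."""
    masked = dict(credentials)
    for key, value in masked.items():
        if key in secret_fields and value:
            # Keep a short prefix and suffix so the stored token stays recognisable.
            masked[key] = value[:2] + "*" * max(len(value) - 4, 0) + value[-2:]
    return masked


# Example: obfuscate_credentials({"api_key": "sk-1234567890"}, {"api_key"})
# -> {"api_key": "sk*********90"}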
--- ...0_0033-c459994abfa8_add_pipeline_info_3.py | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 api/migrations/versions/2025_05_30_0033-c459994abfa8_add_pipeline_info_3.py diff --git a/api/migrations/versions/2025_05_30_0033-c459994abfa8_add_pipeline_info_3.py b/api/migrations/versions/2025_05_30_0033-c459994abfa8_add_pipeline_info_3.py new file mode 100644 index 0000000000..0b010d535d --- /dev/null +++ b/api/migrations/versions/2025_05_30_0033-c459994abfa8_add_pipeline_info_3.py @@ -0,0 +1,70 @@ +"""add_pipeline_info_3 + +Revision ID: c459994abfa8 +Revises: abb18a379e62 +Create Date: 2025-05-30 00:33:14.068312 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'c459994abfa8' +down_revision = 'abb18a379e62' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('datasource_oauth_params', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('plugin_id', models.types.StringUUID(), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('system_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'), + sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx') + ) + op.create_table('datasource_providers', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('plugin_id', models.types.StringUUID(), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('auth_type', sa.String(length=255), nullable=False), + sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'), + sa.UniqueConstraint('plugin_id', 'provider', name='datasource_provider_plugin_id_provider_idx') + ) + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=False)) + batch_op.add_column(sa.Column('yaml_content', sa.Text(), nullable=False)) + batch_op.drop_column('pipeline_id') + + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=False)) + batch_op.add_column(sa.Column('yaml_content', sa.Text(), nullable=False)) + batch_op.drop_column('pipeline_id') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('pipeline_id', sa.UUID(), autoincrement=False, nullable=False)) + batch_op.drop_column('yaml_content') + batch_op.drop_column('chunk_structure') + + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('pipeline_id', sa.UUID(), autoincrement=False, nullable=False)) + batch_op.drop_column('yaml_content') + batch_op.drop_column('chunk_structure') + + op.drop_table('datasource_providers') + op.drop_table('datasource_oauth_params') + # ### end Alembic commands ### From 976b465e76ada7a18e44fd96793880b990e3f620 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 30 May 2025 00:55:06 +0800 Subject: [PATCH 053/155] r2 --- api/models/oauth.py | 2 +- api/services/datasource_provider_service.py | 37 ++++++++++----------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/api/models/oauth.py b/api/models/oauth.py index d662a4b50c..9a070c2fbe 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -29,7 +29,7 @@ class DatasourceProvider(Base): ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) - plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) + plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index ef9a56a66e..09c4cca706 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -35,28 +35,28 @@ class DatasourceProviderService: :param credentials: """ credential_valid = self.provider_manager.validate_provider_credentials(tenant_id=tenant_id, - user_id=current_user.id, - provider=provider, - credentials=credentials) + user_id=current_user.id, + provider=provider, + credentials=credentials) if credential_valid: # Get all provider configurations of the current workspace datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id).first() + provider=provider, + plugin_id=plugin_id).first() provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, - provider=provider - ) + provider=provider + ) if not datasource_provider: for key, value in credentials.items(): if key in provider_credential_secret_variables: # if send [__HIDDEN__] in secret input, it will be same as original value credentials[key] = encrypter.encrypt_token(tenant_id, value) datasource_provider = DatasourceProvider(tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id, - auth_type="api_key", - encrypted_credentials=credentials) + provider=provider, + plugin_id=plugin_id, + auth_type="api_key", + encrypted_credentials=credentials) db.session.add(datasource_provider) db.session.commit() else: @@ -91,7 +91,6 @@ class DatasourceProviderService: return secret_input_form_variables - def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]: """ get datasource credentials. 
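The credential save path above re-encrypts only the secret fields and notes that a client may send `[__HIDDEN__]` back in a secret input. A short sketch of how such a placeholder can be merged with the stored credentials on update — the constant is taken from the comment in the hunk above, while the helper name is illustrative:

HIDDEN_VALUE = "[__HIDDEN__]"


def merge_submitted_credentials(
    stored: dict[str, str], submitted: dict[str, str], secret_fields: set[str]
) -> dict[str, str]:
    """Keep the stored secret whenever the client echoes back the obfuscated placeholder."""
    merged = dict(submitted)
    for key in secret_fields:
        if merged.get(key) == HIDDEN_VALUE and key in stored:
            merged[key] = stored[key]
    return merged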
@@ -102,9 +101,11 @@ class DatasourceProviderService: :return: """ # Get all provider configurations of the current workspace - datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id).all() + datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id + ).all() if not datasource_providers: return [] copy_credentials_list = [] @@ -127,7 +128,6 @@ class DatasourceProviderService: return copy_credentials_list - def remove_datasource_credentials(self, tenant_id: str, provider: str, @@ -141,9 +141,8 @@ class DatasourceProviderService: :return: """ datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id).first() + provider=provider, + plugin_id=plugin_id).first() if datasource_provider: db.session.delete(datasource_provider) db.session.commit() - From 7284569c5f59e46ca948f0aa1c6dc679891c875e Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 30 May 2025 01:02:33 +0800 Subject: [PATCH 054/155] Update build-push.yml --- .github/workflows/build-push.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index cc735ae67c..f0f6cd66e6 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -6,6 +6,7 @@ on: - "main" - "deploy/dev" - "deploy/enterprise" + - "feat/r2" tags: - "*" From 631768ea1dccb5f24610f98dc4d2b07cab379ad4 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 30 May 2025 15:42:36 +0800 Subject: [PATCH 055/155] r2 --- .../nodes/knowledge_index/entities.py | 1 + .../knowledge_index/knowledge_index_node.py | 2 +- ...0_0052-e4fb49a4fe86_add_pipeline_info_4.py | 37 +++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 api/migrations/versions/2025_05_30_0052-e4fb49a4fe86_add_pipeline_info_4.py diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 0d0da757d5..f342dbfb3d 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -155,4 +155,5 @@ class KnowledgeIndexNodeData(BaseNodeData): """ type: str = "knowledge-index" + chunk_structure: str index_chunk_variable_selector: list[str] diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index fef434e3ec..c0db13418f 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -61,7 +61,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required." 
) - outputs = self._get_preview_output(dataset.chunk_structure, chunks) + outputs = self._get_preview_output(node_data.chunk_structure, chunks) # retrieve knowledge try: diff --git a/api/migrations/versions/2025_05_30_0052-e4fb49a4fe86_add_pipeline_info_4.py b/api/migrations/versions/2025_05_30_0052-e4fb49a4fe86_add_pipeline_info_4.py new file mode 100644 index 0000000000..5c10608c1b --- /dev/null +++ b/api/migrations/versions/2025_05_30_0052-e4fb49a4fe86_add_pipeline_info_4.py @@ -0,0 +1,37 @@ +"""add_pipeline_info_4 + +Revision ID: e4fb49a4fe86 +Revises: c459994abfa8 +Create Date: 2025-05-30 00:52:49.222558 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'e4fb49a4fe86' +down_revision = 'c459994abfa8' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.UUID(), + type_=sa.TEXT(), + existing_nullable=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.TEXT(), + type_=sa.UUID(), + existing_nullable=False) + # ### end Alembic commands ### From 3fb02a79330eb58727e10392123ef1bf07038fc4 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 30 May 2025 17:28:09 +0800 Subject: [PATCH 056/155] r2 --- .../console/datasets/rag_pipeline/rag_pipeline_workflow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 09ff07646f..a8d2becb4c 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -542,6 +542,7 @@ class RagPipelineConfigApi(Resource): @login_required @account_initialization_required def get(self, pipeline_id): + return { "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, } From 0486aa3445fdd494a8e077651a91130c954e0e76 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 3 Jun 2025 13:30:51 +0800 Subject: [PATCH 057/155] r2 --- .../console/datasets/datasets_document.py | 2 +- .../rag_pipeline/rag_pipeline_workflow.py | 4 +- .../knowledge_index/knowledge_index_node.py | 18 +++++-- api/services/dataset_service.py | 52 +++++++++---------- .../rag_pipeline_entities.py | 19 +++---- api/services/rag_pipeline/rag_pipeline.py | 22 ++++---- .../rag_pipeline/rag_pipeline_dsl_service.py | 50 +++++++++--------- 7 files changed, 85 insertions(+), 82 deletions(-) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index f7c04102a9..60fa1731ca 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -664,7 +664,7 @@ class DocumentDetailApi(DocumentResource): response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} elif metadata == "without": dataset_process_rules = DatasetService.get_process_rules(dataset_id) - document_process_rules = document.dataset_process_rule.to_dict() + 
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} data_source_info = document.data_source_detail_dict response = { "id": document.id, diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index a8d2becb4c..fe91f01af6 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -39,8 +39,6 @@ from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.account import Account from models.dataset import Pipeline -from models.model import EndUser -from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService @@ -542,7 +540,7 @@ class RagPipelineConfigApi(Resource): @login_required @account_initialization_required def get(self, pipeline_id): - + return { "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, } diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index c0db13418f..41a6c6141e 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -12,7 +12,7 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.dataset import Dataset, Document +from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus from ..base import BaseNode @@ -61,11 +61,11 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required." 
) - outputs = self._get_preview_output(node_data.chunk_structure, chunks) - # retrieve knowledge + # index knowledge try: if is_preview: + outputs = self._get_preview_output(node_data.chunk_structure, chunks) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, @@ -116,6 +116,18 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): document.indexing_status = "completed" document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(document) + #update document segment status + db.session.query(DocumentSegment).filter( + DocumentSegment.document_id == document.id, + DocumentSegment.dataset_id == dataset.id, + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + } + ) + db.session.commit() return { diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 49b6e208e4..6d0f8ec6a9 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1,3 +1,4 @@ +from calendar import day_abbr import copy import datetime import json @@ -52,7 +53,6 @@ from services.entities.knowledge_entities.knowledge_entities import ( SegmentUpdateArgs, ) from services.entities.knowledge_entities.rag_pipeline_entities import ( - KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, RagPipelineDatasetCreateEntity, ) @@ -492,23 +492,23 @@ class DatasetService: if action: deal_dataset_vector_index_task.delay(dataset_id, action) return dataset - + @staticmethod def update_rag_pipeline_dataset_settings(session: Session, - dataset: Dataset, - knowledge_configuration: KnowledgeConfiguration, + dataset: Dataset, + knowledge_configuration: KnowledgeConfiguration, has_published: bool = False): + dataset = session.merge(dataset) if not has_published: dataset.chunk_structure = knowledge_configuration.chunk_structure - index_method = knowledge_configuration.index_method - dataset.indexing_technique = index_method.indexing_technique - if index_method == "high_quality": + dataset.indexing_technique = knowledge_configuration.indexing_technique + if knowledge_configuration.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=index_method.embedding_setting.embedding_provider_name, + provider=knowledge_configuration.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=index_method.embedding_setting.embedding_model_name, + model=knowledge_configuration.embedding_model, ) dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider @@ -516,30 +516,30 @@ class DatasetService: embedding_model.provider, embedding_model.model ) dataset.collection_binding_id = dataset_collection_binding.id - elif index_method == "economy": - dataset.keyword_number = index_method.economy_setting.keyword_number + elif knowledge_configuration.indexing_technique == "economy": + dataset.keyword_number = knowledge_configuration.keyword_number else: raise ValueError("Invalid index method") - dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() + dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() session.add(dataset) else: if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure: raise ValueError("Chunk structure is not allowed to be 
updated.") action = None - if dataset.indexing_technique != knowledge_configuration.index_method.indexing_technique: + if dataset.indexing_technique != knowledge_configuration.indexing_technique: # if update indexing_technique - if knowledge_configuration.index_method.indexing_technique == "economy": + if knowledge_configuration.indexing_technique == "economy": raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") - elif knowledge_configuration.index_method.indexing_technique == "high_quality": + elif knowledge_configuration.indexing_technique == "high_quality": action = "add" # get embedding model setting try: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=knowledge_configuration.index_method.embedding_setting.embedding_provider_name, + provider=knowledge_configuration.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=knowledge_configuration.index_method.embedding_setting.embedding_model_name, + model=knowledge_configuration.embedding_model, ) dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider @@ -567,7 +567,7 @@ class DatasetService: plugin_model_provider_str = str(ModelProviderID(plugin_model_provider)) # Handle new model provider from request - new_plugin_model_provider = knowledge_base_setting.index_method.embedding_setting.embedding_provider_name + new_plugin_model_provider = knowledge_configuration.embedding_model_provider new_plugin_model_provider_str = None if new_plugin_model_provider: new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider)) @@ -575,16 +575,16 @@ class DatasetService: # Only update embedding model if both values are provided and different from current if ( plugin_model_provider_str != new_plugin_model_provider_str - or knowledge_base_setting.index_method.embedding_setting.embedding_model_name != dataset.embedding_model + or knowledge_configuration.embedding_model != dataset.embedding_model ): action = "update" model_manager = ModelManager() try: embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name, + provider=knowledge_configuration.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name, + model=knowledge_configuration.embedding_model, ) except ProviderTokenNotInitError: # If we can't get the embedding model, skip updating it @@ -608,14 +608,14 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) elif dataset.indexing_technique == "economy": - if dataset.keyword_number != knowledge_configuration.index_method.economy_setting.keyword_number: - dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number - dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() - session.add(dataset) + if dataset.keyword_number != knowledge_configuration.keyword_number: + dataset.keyword_number = knowledge_configuration.keyword_number + dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + session.add(dataset) session.commit() if action: deal_dataset_index_update_task.delay(dataset.id, action) - + @staticmethod def delete_dataset(dataset_id, user): diff --git 
a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index 17416d51fd..778c394d5b 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -105,18 +105,11 @@ class IndexMethod(BaseModel): class KnowledgeConfiguration(BaseModel): """ - Knowledge Configuration. + Knowledge Base Configuration. """ - chunk_structure: str - index_method: IndexMethod - retrieval_setting: RetrievalSetting - - -class KnowledgeBaseUpdateConfiguration(BaseModel): - """ - Knowledge Base Update Configuration. - """ - index_method: IndexMethod - chunk_structure: str - retrieval_setting: RetrievalSetting \ No newline at end of file + indexing_technique: Literal["high_quality", "economy"] + embedding_model_provider: Optional[str] = "" + embedding_model: Optional[str] = "" + keyword_number: Optional[int] = 10 + retrieval_model: RetrievalSetting diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index b3c32a7c78..43451528db 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -74,7 +74,7 @@ class RagPipelineService: result = retrieval_instance.get_pipeline_templates(language) return result - @classmethod + @classmethod def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]: """ Get pipeline template detail. @@ -284,7 +284,7 @@ class RagPipelineService: graph=draft_workflow.graph, features=draft_workflow.features, created_by=account.id, - environment_variables=draft_workflow.environment_variables, + environment_variables=draft_workflow.environment_variables, conversation_variables=draft_workflow.conversation_variables, rag_pipeline_variables=draft_workflow.rag_pipeline_variables, marked_name="", @@ -296,8 +296,8 @@ class RagPipelineService: graph = workflow.graph_dict nodes = graph.get("nodes", []) for node in nodes: - if node.get("data", {}).get("type") == "knowledge_index": - knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) + if node.get("data", {}).get("type") == "knowledge-index": + knowledge_configuration = node.get("data", {}) knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) # update dataset @@ -306,8 +306,8 @@ class RagPipelineService: raise ValueError("Dataset not found") DatasetService.update_rag_pipeline_dataset_settings( session=session, - dataset=dataset, - knowledge_configuration=knowledge_configuration, + dataset=dataset, + knowledge_configuration=knowledge_configuration, has_published=pipeline.is_published ) # return new workflow @@ -771,14 +771,14 @@ class RagPipelineService: # Use the repository to get the node executions with ordering order_config = OrderConfig(order_by=["index"], order_direction="desc") - node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, - order_config=order_config, + node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, + order_config=order_config, triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN) # Convert domain models to database models workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions] return workflow_node_executions - + @classmethod def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict): """ @@ -792,5 +792,5 @@ class RagPipelineService: workflow = db.session.query(Workflow).filter(Workflow.id == 
pipeline.workflow_id).first() if not workflow: raise ValueError("Workflow not found") - - db.session.commit() \ No newline at end of file + + db.session.commit() diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 57e81e6f75..189ba0973f 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -1,10 +1,10 @@ import base64 -from datetime import UTC, datetime import hashlib import json import logging import uuid from collections.abc import Mapping +from datetime import UTC, datetime from enum import StrEnum from typing import Optional, cast from urllib.parse import urlparse @@ -292,20 +292,20 @@ class RagPipelineDslService: "background": icon_background, "url": icon_url, }, - indexing_technique=knowledge_configuration.index_method.indexing_technique, + indexing_technique=knowledge_configuration.indexing_technique, created_by=account.id, - retrieval_model=knowledge_configuration.retrieval_setting.model_dump(), + retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode="rag_pipeline", chunk_structure=knowledge_configuration.chunk_structure, ) - if knowledge_configuration.index_method.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == "high_quality": dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) .filter( DatasetCollectionBinding.provider_name - == knowledge_configuration.index_method.embedding_setting.embedding_provider_name, + == knowledge_configuration.embedding_model_provider, DatasetCollectionBinding.model_name - == knowledge_configuration.index_method.embedding_setting.embedding_model_name, + == knowledge_configuration.embedding_model, DatasetCollectionBinding.type == "dataset", ) .order_by(DatasetCollectionBinding.created_at) @@ -314,8 +314,8 @@ class RagPipelineDslService: if not dataset_collection_binding: dataset_collection_binding = DatasetCollectionBinding( - provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name, - model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name, + provider_name=knowledge_configuration.embedding_model_provider, + model_name=knowledge_configuration.embedding_model, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), type="dataset", ) @@ -324,13 +324,13 @@ class RagPipelineDslService: dataset_collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = ( - knowledge_configuration.index_method.embedding_setting.embedding_model_name + knowledge_configuration.embedding_model ) dataset.embedding_model_provider = ( - knowledge_configuration.index_method.embedding_setting.embedding_provider_name + knowledge_configuration.embedding_model_provider ) - elif knowledge_configuration.index_method.indexing_technique == "economy": - dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number + elif knowledge_configuration.indexing_technique == "economy": + dataset.keyword_number = knowledge_configuration.keyword_number dataset.pipeline_id = pipeline.id self._session.add(dataset) self._session.commit() @@ -426,25 +426,25 @@ class RagPipelineDslService: "background": icon_background, "url": icon_url, }, - indexing_technique=knowledge_configuration.index_method.indexing_technique, + 
indexing_technique=knowledge_configuration.indexing_technique, created_by=account.id, - retrieval_model=knowledge_configuration.retrieval_setting.model_dump(), + retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode="rag_pipeline", chunk_structure=knowledge_configuration.chunk_structure, ) else: - dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique - dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() + dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() dataset.runtime_mode = "rag_pipeline" dataset.chunk_structure = knowledge_configuration.chunk_structure - if knowledge_configuration.index_method.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == "high_quality": dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) .filter( DatasetCollectionBinding.provider_name - == knowledge_configuration.index_method.embedding_setting.embedding_provider_name, + == knowledge_configuration.embedding_model_provider, DatasetCollectionBinding.model_name - == knowledge_configuration.index_method.embedding_setting.embedding_model_name, + == knowledge_configuration.embedding_model, DatasetCollectionBinding.type == "dataset", ) .order_by(DatasetCollectionBinding.created_at) @@ -453,8 +453,8 @@ class RagPipelineDslService: if not dataset_collection_binding: dataset_collection_binding = DatasetCollectionBinding( - provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name, - model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name, + provider_name=knowledge_configuration.embedding_model_provider, + model_name=knowledge_configuration.embedding_model, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), type="dataset", ) @@ -463,13 +463,13 @@ class RagPipelineDslService: dataset_collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = ( - knowledge_configuration.index_method.embedding_setting.embedding_model_name + knowledge_configuration.embedding_model ) dataset.embedding_model_provider = ( - knowledge_configuration.index_method.embedding_setting.embedding_provider_name + knowledge_configuration.embedding_model_provider ) - elif knowledge_configuration.index_method.indexing_technique == "economy": - dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number + elif knowledge_configuration.indexing_technique == "economy": + dataset.keyword_number = knowledge_configuration.keyword_number dataset.pipeline_id = pipeline.id self._session.add(dataset) self._session.commit() From b8f3b23b1a409ccb99bcf268a62b3c7edb191cac Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 3 Jun 2025 15:51:31 +0800 Subject: [PATCH 058/155] r2 --- api/configs/feature/__init__.py | 1 - .../app/apps/pipeline/pipeline_generator.py | 57 +++++++++++-------- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index a3da5c1b49..9294b479e9 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -667,7 +667,6 @@ class MailConfig(BaseSettings): class RagEtlConfig(BaseSettings): """ Configuration for RAG ETL processes - """ # TODO: This config is not only for rag etl, it is also for file 
upload, we should move it to file upload config ETL_TYPE: str = Field( diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 19ded1696a..d56e5243b3 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -55,7 +55,8 @@ class PipelineGenerator(BaseAppGenerator): streaming: Literal[True], call_depth: int, workflow_thread_pool_id: Optional[str], - ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ... + ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: + ... @overload def generate( @@ -69,7 +70,8 @@ class PipelineGenerator(BaseAppGenerator): streaming: Literal[False], call_depth: int, workflow_thread_pool_id: Optional[str], - ) -> Mapping[str, Any]: ... + ) -> Mapping[str, Any]: + ... @overload def generate( @@ -83,7 +85,8 @@ class PipelineGenerator(BaseAppGenerator): streaming: bool, call_depth: int, workflow_thread_pool_id: Optional[str], - ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... + ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: + ... def generate( self, @@ -184,7 +187,7 @@ class PipelineGenerator(BaseAppGenerator): ) if invoke_from == InvokeFrom.DEBUGGER: return self._generate( - flask_app=current_app._get_current_object(),# type: ignore + flask_app=current_app._get_current_object(), # type: ignore context=contextvars.copy_context(), pipeline=pipeline, workflow_id=workflow.id, @@ -199,6 +202,7 @@ class PipelineGenerator(BaseAppGenerator): else: # run in child thread context = contextvars.copy_context() + @copy_current_request_context def worker_with_context(): # Run the worker within the copied context @@ -222,24 +226,25 @@ class PipelineGenerator(BaseAppGenerator): worker_thread.start() # return batch, dataset, documents return { - "batch": batch, - "dataset": PipelineDataset( - id=dataset.id, - name=dataset.name, - description=dataset.description, - chunk_structure=dataset.chunk_structure, - ).model_dump(), - "documents": [PipelineDocument( - id=document.id, - position=document.position, - data_source_info=json.loads(document.data_source_info) if document.data_source_info else None, - name=document.name, - indexing_status=document.indexing_status, - error=document.error, - enabled=document.enabled, - ).model_dump() for document in documents - ] - } + "batch": batch, + "dataset": PipelineDataset( + id=dataset.id, + name=dataset.name, + description=dataset.description, + chunk_structure=dataset.chunk_structure, + ).model_dump(), + "documents": [PipelineDocument( + id=document.id, + position=document.position, + data_source_info=json.loads(document.data_source_info) if document.data_source_info else None, + name=document.name, + indexing_status=document.indexing_status, + error=document.error, + enabled=document.enabled, + ).model_dump() for document in documents + ] + } + def _generate( self, *, @@ -268,6 +273,7 @@ class PipelineGenerator(BaseAppGenerator): :param streaming: is stream :param workflow_thread_pool_id: workflow thread pool id """ + print("jin ru la 1") for var, val in context.items(): var.set(val) @@ -279,6 +285,7 @@ class PipelineGenerator(BaseAppGenerator): saved_user = g._login_user with flask_app.app_context(): # Restore user in new app context + print("jin ru la 2") if saved_user is not None: from flask import g @@ -306,6 +313,7 @@ class PipelineGenerator(BaseAppGenerator): application_generate_entity=application_generate_entity, 
workflow_thread_pool_id=workflow_thread_pool_id, ) + # new thread worker_thread = threading.Thread( target=worker_with_context @@ -396,7 +404,7 @@ class PipelineGenerator(BaseAppGenerator): ) return self._generate( - flask_app=current_app._get_current_object(),# type: ignore + flask_app=current_app._get_current_object(), # type: ignore pipeline=pipeline, workflow=workflow, user=user, @@ -479,7 +487,7 @@ class PipelineGenerator(BaseAppGenerator): ) return self._generate( - flask_app=current_app._get_current_object(),# type: ignore + flask_app=current_app._get_current_object(), # type: ignore pipeline=pipeline, workflow=workflow, user=user, @@ -506,6 +514,7 @@ class PipelineGenerator(BaseAppGenerator): :param workflow_thread_pool_id: workflow thread pool id :return: """ + print("jin ru la 3") for var, val in context.items(): var.set(val) from flask import g From 270edd43ab1eb97adfde275f514b54080fa38699 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 3 Jun 2025 15:53:17 +0800 Subject: [PATCH 059/155] r2 --- api/configs/feature/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 9294b479e9..a3da5c1b49 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -667,6 +667,7 @@ class MailConfig(BaseSettings): class RagEtlConfig(BaseSettings): """ Configuration for RAG ETL processes + """ # TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config ETL_TYPE: str = Field( From ab1730bbaacb2248dec3d30af0933c1b52b3bb67 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 3 Jun 2025 16:51:21 +0800 Subject: [PATCH 060/155] r2 --- api/app.py | 22 +++++++++---------- .../app/apps/pipeline/pipeline_generator.py | 1 + 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/api/app.py b/api/app.py index 4f393f6c20..34889cff67 100644 --- a/api/app.py +++ b/api/app.py @@ -19,18 +19,18 @@ else: # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent. 
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}: from gevent import monkey + # + # # gevent + # monkey.patch_all() + # + # from grpc.experimental import gevent as grpc_gevent # type: ignore + # + # # grpc gevent + # grpc_gevent.init_gevent() - # gevent - monkey.patch_all() - - from grpc.experimental import gevent as grpc_gevent # type: ignore - - # grpc gevent - grpc_gevent.init_gevent() - - import psycogreen.gevent # type: ignore - - psycogreen.gevent.patch_psycopg() + # import psycogreen.gevent # type: ignore + # + # psycogreen.gevent.patch_psycopg() from app_factory import create_app diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index d56e5243b3..61d4b723e1 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -236,6 +236,7 @@ class PipelineGenerator(BaseAppGenerator): "documents": [PipelineDocument( id=document.id, position=document.position, + data_source_type=document.data_source_type, data_source_info=json.loads(document.data_source_info) if document.data_source_info else None, name=document.name, indexing_status=document.indexing_status, From 573cd15e77e88f7d0efc986a63d8062176c7aa88 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 3 Jun 2025 16:52:21 +0800 Subject: [PATCH 061/155] r2 --- api/app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/app.py b/api/app.py index 34889cff67..14807f5031 100644 --- a/api/app.py +++ b/api/app.py @@ -17,8 +17,8 @@ else: # It seems that JetBrains Python debugger does not work well with gevent, # so we need to disable gevent in debug mode. # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent. - if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}: - from gevent import monkey + # if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}: + #from gevent import monkey # # # gevent # monkey.patch_all() From 7b7f8ef51d4eb1d900cc0865d6e79c8c1acf10c2 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 3 Jun 2025 18:12:24 +0800 Subject: [PATCH 062/155] r2 --- api/core/rag/retrieval/retrieval_methods.py | 1 + api/core/workflow/nodes/knowledge_index/entities.py | 2 +- api/services/entities/knowledge_entities/knowledge_entities.py | 2 +- .../entities/knowledge_entities/rag_pipeline_entities.py | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/api/core/rag/retrieval/retrieval_methods.py b/api/core/rag/retrieval/retrieval_methods.py index eaa00bca88..c7c6e60c8d 100644 --- a/api/core/rag/retrieval/retrieval_methods.py +++ b/api/core/rag/retrieval/retrieval_methods.py @@ -5,6 +5,7 @@ class RetrievalMethod(Enum): SEMANTIC_SEARCH = "semantic_search" FULL_TEXT_SEARCH = "full_text_search" HYBRID_SEARCH = "hybrid_search" + KEYWORD_SEARCH = "keyword_search" @staticmethod def is_support_semantic_search(retrieval_method: str) -> bool: diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index f342dbfb3d..18a4f93970 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -63,7 +63,7 @@ class RetrievalSetting(BaseModel): Retrieval Setting. 
""" - search_method: Literal["semantic_search", "keyword_search", "hybrid_search"] + search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"] top_k: int score_threshold: Optional[float] = 0.5 score_threshold_enabled: bool = False diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index bb3be61f85..603064ca07 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -101,7 +101,7 @@ class WeightModel(BaseModel): class RetrievalModel(BaseModel): - search_method: Literal["hybrid_search", "semantic_search", "full_text_search"] + search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"] reranking_enable: bool reranking_model: Optional[RerankingModel] = None reranking_mode: Optional[str] = None diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index 778c394d5b..8da2e4aade 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -83,7 +83,7 @@ class RetrievalSetting(BaseModel): Retrieval Setting. """ - search_method: Literal["semantic_search", "keyword_search", "hybrid_search"] + search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"] top_k: int score_threshold: Optional[float] = 0.5 score_threshold_enabled: bool = False From 4130c506435503cbfc2007662b811d4c9f01f190 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 3 Jun 2025 18:32:39 +0800 Subject: [PATCH 063/155] r2 --- api/core/rag/datasource/keyword/jieba/jieba.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index d6d0bd88b2..ca54290796 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -28,9 +28,11 @@ class Jieba(BaseKeyword): with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() + keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + for text in texts: keywords = keyword_table_handler.extract_keywords( - text.page_content, self._config.max_keywords_per_chunk + text.page_content, keyword_number ) if text.metadata is not None: self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) @@ -49,17 +51,18 @@ class Jieba(BaseKeyword): keyword_table = self._get_dataset_keyword_table() keywords_list = kwargs.get("keywords_list") + keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk for i in range(len(texts)): text = texts[i] if keywords_list: keywords = keywords_list[i] if not keywords: keywords = keyword_table_handler.extract_keywords( - text.page_content, self._config.max_keywords_per_chunk + text.page_content, keyword_number ) else: keywords = keyword_table_handler.extract_keywords( - text.page_content, self._config.max_keywords_per_chunk + text.page_content, keyword_number ) if text.metadata is not None: self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) @@ -239,7 
+242,9 @@ class Jieba(BaseKeyword): keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"] ) else: - keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk) + keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + + keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number) segment.keywords = list(keywords) keyword_table = self._add_text_to_keyword_table( keyword_table or {}, segment.index_node_id, list(keywords) From 9cdd2cbb27839986e7a520484df9f2cc70b69e6f Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 3 Jun 2025 19:02:57 +0800 Subject: [PATCH 064/155] r2 --- api/app.py | 25 ++++--- .../console/auth/data_source_oauth.py | 2 - api/controllers/console/auth/oauth.py | 8 +-- .../datasets/rag_pipeline/datasource_auth.py | 56 ++++++---------- .../datasets/rag_pipeline/rag_pipeline.py | 2 + .../rag_pipeline/rag_pipeline_workflow.py | 1 - .../app/apps/pipeline/pipeline_generator.py | 38 +++++------ api/core/app/apps/pipeline/pipeline_runner.py | 7 +- api/core/app/entities/app_invoke_entities.py | 1 + .../datasource/__base/datasource_runtime.py | 1 - api/core/datasource/datasource_manager.py | 4 +- api/core/plugin/impl/datasource.py | 47 ++++--------- .../rag/datasource/keyword/jieba/jieba.py | 24 +++---- .../index_processor/index_processor_base.py | 2 +- .../processor/paragraph_index_processor.py | 9 +-- .../processor/parent_child_index_processor.py | 16 ++--- .../processor/qa_index_processor.py | 9 +-- api/core/variables/variables.py | 13 +++- api/core/workflow/entities/variable_pool.py | 2 +- .../entities/workflow_node_execution.py | 1 + .../workflow/graph_engine/entities/graph.py | 3 +- .../workflow/graph_engine/graph_engine.py | 8 +-- .../nodes/datasource/datasource_node.py | 61 +++++++++-------- .../knowledge_index/knowledge_index_node.py | 15 +++-- api/models/dataset.py | 2 +- api/models/oauth.py | 2 +- api/services/dataset_service.py | 12 ++-- api/services/datasource_provider_service.py | 67 ++++++++++--------- .../rag_pipeline_entities.py | 1 + .../rag_pipeline/pipeline_generate_service.py | 1 - .../customized/customized_retrieval.py | 5 +- .../database/database_retrieval.py | 1 - api/services/rag_pipeline/rag_pipeline.py | 38 +++++------ .../rag_pipeline/rag_pipeline_dsl_service.py | 37 ++++------ .../rag_pipeline_manage_service.py | 8 +-- 35 files changed, 229 insertions(+), 300 deletions(-) diff --git a/api/app.py b/api/app.py index 14807f5031..e0a903b10d 100644 --- a/api/app.py +++ b/api/app.py @@ -1,4 +1,3 @@ -import os import sys @@ -18,19 +17,19 @@ else: # so we need to disable gevent in debug mode. # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent. 
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}: - #from gevent import monkey - # - # # gevent - # monkey.patch_all() - # - # from grpc.experimental import gevent as grpc_gevent # type: ignore - # - # # grpc gevent - # grpc_gevent.init_gevent() + # from gevent import monkey + # + # # gevent + # monkey.patch_all() + # + # from grpc.experimental import gevent as grpc_gevent # type: ignore + # + # # grpc gevent + # grpc_gevent.init_gevent() - # import psycogreen.gevent # type: ignore - # - # psycogreen.gevent.patch_psycopg() + # import psycogreen.gevent # type: ignore + # + # psycogreen.gevent.patch_psycopg() from app_factory import create_app diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 5299064e17..1049f864c3 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -109,8 +109,6 @@ class OAuthDataSourceSync(Resource): return {"result": "success"}, 200 - - api.add_resource(OAuthDataSource, "/oauth/data-source/") api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/") api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index ed595f5d3d..395367c9e2 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -4,24 +4,19 @@ from typing import Optional import requests from flask import current_app, redirect, request -from flask_login import current_user from flask_restful import Resource from sqlalchemy import select from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden, NotFound, Unauthorized +from werkzeug.exceptions import Unauthorized from configs import dify_config from constants.languages import languages -from controllers.console.wraps import account_initialization_required, setup_required -from core.plugin.impl.oauth import OAuthHandler from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import extract_remote_ip -from libs.login import login_required from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models import Account from models.account import AccountStatus -from models.oauth import DatasourceOauthParamConfig, DatasourceProvider from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError @@ -186,6 +181,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): return account - api.add_resource(OAuthLogin, "/oauth/login/") api.add_resource(OAuthCallback, "/oauth/authorize/") diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index ceb7a277e4..96cb3f5602 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -1,12 +1,9 @@ - from flask import redirect, request from flask_login import current_user # type: ignore from flask_restful import ( # type: ignore Resource, # type: ignore - marshal_with, reqparse, ) -from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config @@ -16,7 +13,6 @@ from 
controllers.console.wraps import ( setup_required, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.oauth import OAuthHandler from extensions.ext_database import db from libs.login import login_required @@ -33,10 +29,9 @@ class DatasourcePluginOauthApi(Resource): if not current_user.is_editor: raise Forbidden() # get all plugin oauth configs - plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( - provider=provider, - plugin_id=plugin_id - ).first() + plugin_oauth_config = ( + db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() + ) if not plugin_oauth_config: raise NotFound() oauth_handler = OAuthHandler() @@ -45,24 +40,20 @@ class DatasourcePluginOauthApi(Resource): if system_credentials: system_credentials["redirect_url"] = redirect_url response = oauth_handler.get_authorization_url( - current_user.current_tenant.id, - current_user.id, - plugin_id, - provider, - system_credentials=system_credentials + current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials ) return response.model_dump() + class DatasourceOauthCallback(Resource): @setup_required @login_required @account_initialization_required def get(self, provider, plugin_id): oauth_handler = OAuthHandler() - plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( - provider=provider, - plugin_id=plugin_id - ).first() + plugin_oauth_config = ( + db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() + ) if not plugin_oauth_config: raise NotFound() credentials = oauth_handler.get_credentials( @@ -71,18 +62,16 @@ class DatasourceOauthCallback(Resource): plugin_id, provider, system_credentials=plugin_oauth_config.system_credentials, - request=request + request=request, ) datasource_provider = DatasourceProvider( - plugin_id=plugin_id, - provider=provider, - auth_type="oauth", - encrypted_credentials=credentials + plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials ) db.session.add(datasource_provider) db.session.commit() return redirect(f"{dify_config.CONSOLE_WEB_URL}") + class DatasourceAuth(Resource): @setup_required @login_required @@ -99,28 +88,27 @@ class DatasourceAuth(Resource): try: datasource_provider_service.datasource_provider_credentials_validate( - tenant_id=current_user.current_tenant_id, - provider=provider, - plugin_id=plugin_id, - credentials=args["credentials"] + tenant_id=current_user.current_tenant_id, + provider=provider, + plugin_id=plugin_id, + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) return {"result": "success"}, 201 - + @setup_required @login_required @account_initialization_required def get(self, provider, plugin_id): datasource_provider_service = DatasourceProviderService() datasources = datasource_provider_service.get_datasource_credentials( - tenant_id=current_user.current_tenant_id, - provider=provider, - plugin_id=plugin_id + tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id ) return {"result": datasources}, 200 - + + class DatasourceAuthDeleteApi(Resource): @setup_required @login_required @@ -130,12 +118,11 @@ class DatasourceAuthDeleteApi(Resource): raise Forbidden() datasource_provider_service = DatasourceProviderService() 
datasource_provider_service.remove_datasource_credentials( - tenant_id=current_user.current_tenant_id, - provider=provider, - plugin_id=plugin_id + tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id ) return {"result": "success"}, 200 + # Import Rag Pipeline api.add_resource( DatasourcePluginOauthApi, @@ -149,4 +136,3 @@ api.add_resource( DatasourceAuth, "/auth/datasource/provider//plugin/", ) - diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 471ecbf070..1b869d9847 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -110,6 +110,7 @@ class CustomizedPipelineTemplateApi(Resource): dsl = yaml.safe_load(template.yaml_content) return {"data": dsl}, 200 + class CustomizedPipelineTemplateApi(Resource): @setup_required @login_required @@ -142,6 +143,7 @@ class CustomizedPipelineTemplateApi(Resource): RagPipelineService.publish_customized_pipeline_template(pipeline_id, args) return 200 + api.add_resource( PipelineTemplateListApi, "/rag/pipeline/templates", diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index fe91f01af6..d7ed5d475d 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -540,7 +540,6 @@ class RagPipelineConfigApi(Resource): @login_required @account_initialization_required def get(self, pipeline_id): - return { "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, } diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 61d4b723e1..5fb5bff2a9 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -32,7 +32,6 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db -from fields.document_fields import dataset_and_document_fields from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, Pipeline from models.enums import WorkflowRunTriggeredFrom @@ -55,8 +54,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: Literal[True], call_depth: int, workflow_thread_pool_id: Optional[str], - ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: - ... + ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ... @overload def generate( @@ -70,8 +68,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: Literal[False], call_depth: int, workflow_thread_pool_id: Optional[str], - ) -> Mapping[str, Any]: - ... + ) -> Mapping[str, Any]: ... @overload def generate( @@ -85,8 +82,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: bool, call_depth: int, workflow_thread_pool_id: Optional[str], - ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: - ... + ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... 
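# The "...: ..." stubs above are typing overloads: callers passing
# streaming=True are typed as receiving a generator, callers passing
# streaming=False get a plain mapping, and a single real implementation
# serves both. A self-contained sketch of the same pattern with a toy
# payload in place of the pipeline entities:
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload


@overload
def generate(streaming: Literal[True]) -> Generator[str, None, None]: ...


@overload
def generate(streaming: Literal[False]) -> Mapping[str, Any]: ...


def generate(streaming: bool) -> Union[Mapping[str, Any], Generator[str, None, None]]:
    if streaming:
        def _stream() -> Generator[str, None, None]:
            yield "chunk-1"
            yield "chunk-2"

        return _stream()
    return {"answer": "chunk-1chunk-2"}


if __name__ == "__main__":
    print(list(generate(True)))  # typed as Generator[str, None, None]
    print(generate(False))  # typed as Mapping[str, Any]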
def generate( self, @@ -233,17 +229,19 @@ class PipelineGenerator(BaseAppGenerator): description=dataset.description, chunk_structure=dataset.chunk_structure, ).model_dump(), - "documents": [PipelineDocument( - id=document.id, - position=document.position, - data_source_type=document.data_source_type, - data_source_info=json.loads(document.data_source_info) if document.data_source_info else None, - name=document.name, - indexing_status=document.indexing_status, - error=document.error, - enabled=document.enabled, - ).model_dump() for document in documents - ] + "documents": [ + PipelineDocument( + id=document.id, + position=document.position, + data_source_type=document.data_source_type, + data_source_info=json.loads(document.data_source_info) if document.data_source_info else None, + name=document.name, + indexing_status=document.indexing_status, + error=document.error, + enabled=document.enabled, + ).model_dump() + for document in documents + ], } def _generate( @@ -316,9 +314,7 @@ class PipelineGenerator(BaseAppGenerator): ) # new thread - worker_thread = threading.Thread( - target=worker_with_context - ) + worker_thread = threading.Thread(target=worker_with_context) worker_thread.start() diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 8d90e7ee3e..4582dcbb0d 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -111,7 +111,10 @@ class PipelineRunner(WorkflowBasedAppRunner): if workflow.rag_pipeline_variables: for v in workflow.rag_pipeline_variables: rag_pipeline_variable = RAGPipelineVariable(**v) - if rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id and rag_pipeline_variable.variable in inputs: + if ( + rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id + and rag_pipeline_variable.variable in inputs + ): rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable] variable_pool = VariablePool( @@ -195,7 +198,7 @@ class PipelineRunner(WorkflowBasedAppRunner): continue real_run_nodes.append(node) for edge in edges: - if edge.get("source") in exclude_node_ids : + if edge.get("source") in exclude_node_ids: continue real_edges.append(edge) graph_config = dict(graph_config) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index f346994b30..75693be5ea 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -232,6 +232,7 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): """ RAG Pipeline Application Generate Entity. 
""" + # pipeline config pipeline_config: WorkflowUIBasedAppConfig datasource_type: str diff --git a/api/core/datasource/__base/datasource_runtime.py b/api/core/datasource/__base/datasource_runtime.py index 9ddc25a637..264145d261 100644 --- a/api/core/datasource/__base/datasource_runtime.py +++ b/api/core/datasource/__base/datasource_runtime.py @@ -5,7 +5,6 @@ from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import DatasourceInvokeFrom -from core.tools.entities.tool_entities import ToolInvokeFrom class DatasourceRuntime(BaseModel): diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 8c74aeb320..46b36d8349 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -46,7 +46,7 @@ class DatasourceManager: if not provider_entity: raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found") - match (datasource_type): + match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: controller = OnlineDocumentDatasourcePluginProviderController( entity=provider_entity.declaration, @@ -98,5 +98,3 @@ class DatasourceManager: tenant_id, datasource_type, ).get_datasource(datasource_name) - - diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 51d5489c4c..ea357d85b2 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -215,7 +215,6 @@ class PluginDatasourceManager(BasePluginClient): "X-Plugin-ID": datasource_provider_id.plugin_id, "Content-Type": "application/json", }, - ) for resp in response: @@ -233,41 +232,23 @@ class PluginDatasourceManager(BasePluginClient): "identity": { "author": "langgenius", "name": "langgenius/file/file", - "label": { - "zh_Hans": "File", - "en_US": "File", - "pt_BR": "File", - "ja_JP": "File" - }, + "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", - "description": { - "zh_Hans": "File", - "en_US": "File", - "pt_BR": "File", - "ja_JP": "File" - } + "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, }, "credentials_schema": [], "provider_type": "local_file", - "datasources": [{ - "identity": { - "author": "langgenius", - "name": "upload-file", - "provider": "langgenius", - "label": { - "zh_Hans": "File", - "en_US": "File", - "pt_BR": "File", - "ja_JP": "File" - } - }, - "parameters": [], - "description": { - "zh_Hans": "File", - "en_US": "File", - "pt_BR": "File", - "ja_JP": "File" + "datasources": [ + { + "identity": { + "author": "langgenius", + "name": "upload-file", + "provider": "langgenius", + "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, + }, + "parameters": [], + "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, } - }] - } + ], + }, } diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index ca54290796..be1765feee 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -28,12 +28,12 @@ class Jieba(BaseKeyword): with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = 
self._get_dataset_keyword_table() - keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + keyword_number = ( + self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + ) for text in texts: - keywords = keyword_table_handler.extract_keywords( - text.page_content, keyword_number - ) + keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number) if text.metadata is not None: self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) keyword_table = self._add_text_to_keyword_table( @@ -51,19 +51,17 @@ class Jieba(BaseKeyword): keyword_table = self._get_dataset_keyword_table() keywords_list = kwargs.get("keywords_list") - keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + keyword_number = ( + self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + ) for i in range(len(texts)): text = texts[i] if keywords_list: keywords = keywords_list[i] if not keywords: - keywords = keyword_table_handler.extract_keywords( - text.page_content, keyword_number - ) + keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number) else: - keywords = keyword_table_handler.extract_keywords( - text.page_content, keyword_number - ) + keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number) if text.metadata is not None: self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) keyword_table = self._add_text_to_keyword_table( @@ -242,7 +240,9 @@ class Jieba(BaseKeyword): keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"] ) else: - keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + keyword_number = ( + self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + ) keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number) segment.keywords = list(keywords) diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 72e4923b58..ff6f843a28 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -38,7 +38,7 @@ class BaseIndexProcessor(ABC): @abstractmethod def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): raise NotImplementedError - + @abstractmethod def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: raise NotImplementedError diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 559bc5d59b..eee8353214 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -15,7 +15,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import Document, GeneralStructureChunk from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper -from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule +from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Document as DatasetDocument from 
services.entities.knowledge_entities.knowledge_entities import Rule @@ -152,13 +153,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword = Keyword(dataset) keyword.add_texts(documents) - def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: paragraph = GeneralStructureChunk(**chunks) preview = [] for content in paragraph.general_chunks: preview.append({"content": content}) - return { - "preview": preview, - "total_segments": len(paragraph.general_chunks) - } \ No newline at end of file + return {"preview": preview, "total_segments": len(paragraph.general_chunks)} diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 7a3f8f1c63..158fc819ee 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -16,7 +16,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk from extensions.ext_database import db from libs import helper -from models.dataset import ChildChunk, Dataset, Document as DatasetDocument, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -239,14 +240,5 @@ class ParentChildIndexProcessor(BaseIndexProcessor): parent_childs = ParentChildStructureChunk(**chunks) preview = [] for parent_child in parent_childs.parent_child_chunks: - preview.append( - { - "content": parent_child.parent_content, - "child_chunks": parent_child.child_contents - - } - ) - return { - "preview": preview, - "total_segments": len(parent_childs.parent_child_chunks) - } \ No newline at end of file + preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) + return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)} diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index b415596254..407f1b6f6d 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -4,7 +4,8 @@ import logging import re import threading import uuid -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any, Optional import pandas as pd from flask import Flask, current_app @@ -20,7 +21,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import Document from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper -from models.dataset import Dataset, Document as DatasetDocument +from models.dataset import Dataset from services.entities.knowledge_entities.knowledge_entities import Rule @@ -160,10 +161,10 @@ class QAIndexProcessor(BaseIndexProcessor): doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) return docs - + def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): pass - + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: return {"preview": chunks} diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index c0952383a9..1fe0e36a47 
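# Both format_preview implementations above reduce their chunk structures to
# the same payload shape: a list of preview entries plus a total count. A
# compact sketch of the two shapes, using plain dicts and lists in place of
# the GeneralStructureChunk / ParentChildStructureChunk models:
def paragraph_preview(general_chunks: list[str]) -> dict:
    return {
        "preview": [{"content": content} for content in general_chunks],
        "total_segments": len(general_chunks),
    }


def parent_child_preview(parent_child_chunks: list[dict]) -> dict:
    # each entry is assumed to carry "parent_content" and "child_contents"
    return {
        "preview": [
            {"content": chunk["parent_content"], "child_chunks": chunk["child_contents"]}
            for chunk in parent_child_chunks
        ],
        "total_segments": len(parent_child_chunks),
    }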
100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -94,19 +94,26 @@ class FileVariable(FileSegment, Variable): class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass + class RAGPipelineVariable(BaseModel): belong_to_node_id: str = Field(description="belong to which node id, shared means public") type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list") label: str = Field(description="label") description: str | None = Field(description="description", default="") variable: str = Field(description="variable key", default="") - max_length: int | None = Field(description="max length, applicable to text-input, paragraph, and file-list", default=0) + max_length: int | None = Field( + description="max length, applicable to text-input, paragraph, and file-list", default=0 + ) default_value: str | None = Field(description="default value", default="") placeholder: str | None = Field(description="placeholder", default="") unit: str | None = Field(description="unit, applicable to Number", default="") tooltips: str | None = Field(description="helpful text", default="") - allowed_file_types: list[str] | None = Field(description="image, document, audio, video, custom.", default_factory=list) + allowed_file_types: list[str] | None = Field( + description="image, document, audio, video, custom.", default_factory=list + ) allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list) - allowed_file_upload_methods: list[str] | None = Field(description="remote_url, local_file, tool_file.", default_factory=list) + allowed_file_upload_methods: list[str] | None = Field( + description="remote_url, local_file, tool_file.", default_factory=list + ) required: bool = Field(description="optional, default false", default=False) options: list[str] | None = Field(default_factory=list) diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 319833145e..21ea26862a 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -49,7 +49,7 @@ class VariablePool(BaseModel): ) rag_pipeline_variables: Mapping[str, Any] = Field( description="RAG pipeline variables.", - default_factory=dict, + default_factory=dict, ) def __init__( diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/core/workflow/entities/workflow_node_execution.py index 773f5b777b..10271f6062 100644 --- a/api/core/workflow/entities/workflow_node_execution.py +++ b/api/core/workflow/entities/workflow_node_execution.py @@ -28,6 +28,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): AGENT_LOG = "agent_log" ITERATION_ID = "iteration_id" ITERATION_INDEX = "iteration_index" + DATASOURCE_INFO = "datasource_info" LOOP_ID = "loop_id" LOOP_INDEX = "loop_index" PARALLEL_ID = "parallel_id" diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 7062fc4565..16bf847189 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -122,7 +122,6 @@ class Graph(BaseModel): root_node_configs = [] all_node_id_config_mapping: dict[str, dict] = {} - for node_config in node_configs: node_id = node_config.get("id") if not node_id: @@ -142,7 +141,7 @@ class Graph(BaseModel): ( node_config.get("id") for node_config in root_node_configs - if node_config.get("data", {}).get("type", "") == NodeType.START.value + if 
node_config.get("data", {}).get("type", "") == NodeType.START.value or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value ), None, diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index ced1acfdd2..86654e6fac 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -317,10 +317,10 @@ class GraphEngine: raise e # It may not be necessary, but it is necessary. :) - if ( - self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() - in [NodeType.END.value, NodeType.KNOWLEDGE_INDEX.value] - ): + if self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() in [ + NodeType.END.value, + NodeType.KNOWLEDGE_INDEX.value, + ]: break previous_route_node_state = route_node_state diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index b44039298c..92b2daea54 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -11,18 +11,19 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen from core.file import File from core.file.enums import FileTransferMethod, FileType from core.plugin.impl.exc import PluginDaemonClientSideError -from core.variables.segments import ArrayAnySegment, FileSegment +from core.variables.segments import ArrayAnySegment from core.variables.variables import ArrayAnyVariable -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool, VariableValue +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from models.model import UploadFile -from models.workflow import WorkflowNodeExecutionStatus +from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from .entities import DatasourceNodeData from .exc import DatasourceNodeError, DatasourceParameterError @@ -54,7 +55,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): try: from core.datasource.datasource_manager import DatasourceManager - if datasource_type is None: raise DatasourceNodeError("Datasource type is not set") @@ -66,13 +66,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): ) except DatasourceNodeError as e: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - error=f"Failed to get datasource runtime: {str(e)}", - error_type=type(e).__name__, - ) - + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to get datasource runtime: {str(e)}", + error_type=type(e).__name__, + ) # get parameters datasource_parameters = datasource_runtime.entity.parameters @@ -102,7 +101,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + 
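# The datasource node above dispatches on the provider type with a match
# statement and tags every result with a DATASOURCE_INFO metadata entry.
# A stripped-down sketch of that shape; the enum members, metadata key and
# result object below are local stand-ins, not the real workflow entities.
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any


class ProviderType(StrEnum):
    ONLINE_DOCUMENT = "online_document"
    WEBSITE_CRAWL = "website_crawl"
    LOCAL_FILE = "local_file"


class MetadataKey(StrEnum):
    DATASOURCE_INFO = "datasource_info"


@dataclass
class RunResult:
    status: str
    outputs: dict[str, Any]
    metadata: dict[MetadataKey, Any] = field(default_factory=dict)


def run_datasource(provider_type: ProviderType, datasource_info: dict) -> RunResult:
    metadata = {MetadataKey.DATASOURCE_INFO: datasource_info}
    match provider_type:
        case ProviderType.ONLINE_DOCUMENT:
            outputs: dict[str, Any] = {"online_document": datasource_info}
        case ProviderType.WEBSITE_CRAWL:
            outputs = {"website": datasource_info}
        case ProviderType.LOCAL_FILE:
            outputs = {"file_info": datasource_info}
        case _:
            raise ValueError(f"unsupported datasource provider type: {provider_type}")
    outputs["datasource_type"] = str(provider_type)
    return RunResult(status="succeeded", outputs=outputs, metadata=metadata)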
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, outputs={ "online_document": online_document_result.result.model_dump(), "datasource_type": datasource_type, @@ -112,18 +111,16 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, outputs={ - "website": datasource_info, - "datasource_type": datasource_type, + "website": datasource_info, + "datasource_type": datasource_type, }, ) case DatasourceProviderType.LOCAL_FILE: related_id = datasource_info.get("related_id") if not related_id: - raise DatasourceNodeError( - "File is not exist" - ) + raise DatasourceNodeError("File is not exist") upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first() if not upload_file: raise ValueError("Invalid upload file Info") @@ -146,26 +143,27 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): # construct new key list new_key_list = ["file", key] self._append_variables_recursively( - variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value + variable_pool=variable_pool, + node_id=self.node_id, + variable_key_list=new_key_list, + variable_value=value, ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, outputs={ - "file_info": datasource_info, - "datasource_type": datasource_type, - }, - ) - case _: - raise DatasourceNodeError( - f"Unsupported datasource provider: {datasource_type}" + "file_info": datasource_info, + "datasource_type": datasource_type, + }, ) + case _: + raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}") except PluginDaemonClientSideError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, error=f"Failed to transform datasource message: {str(e)}", error_type=type(e).__name__, ) @@ -173,7 +171,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, error=f"Failed to invoke datasource: {str(e)}", error_type=type(e).__name__, ) @@ -227,8 +225,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] - - def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue): + def _append_variables_recursively( + self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue + ): """ Append variables recursively :param node_id: node id diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 41a6c6141e..a1ee3aa823 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ 
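# _append_variables_recursively above writes a value into the variable pool
# under a key path and, judging by its name and its call site (new_key_list =
# ["file", key]), recurses into nested values so deeper selectors resolve too.
# A minimal sketch of that idea, with a plain dict standing in for the
# VariablePool:
from typing import Any


def append_variables_recursively(
    pool: dict[tuple[str, ...], Any], node_id: str, key_list: list[str], value: Any
) -> None:
    # register the value at the current path
    pool[(node_id, *key_list)] = value
    # if it is a mapping, recurse so nested keys become addressable as well
    if isinstance(value, dict):
        for key, nested in value.items():
            append_variables_recursively(pool, node_id, [*key_list, key], nested)


if __name__ == "__main__":
    pool: dict[tuple[str, ...], Any] = {}
    append_variables_recursively(pool, "datasource-node", ["file"], {"name": "a.pdf", "size": 123})
    print(sorted(pool))  # paths for "file", "file/name" and "file/size" are all present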
b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -6,7 +6,6 @@ from typing import Any, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.variables.segments import ObjectSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey @@ -72,8 +71,9 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): process_data=None, outputs=outputs, ) - results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks, - variable_pool=variable_pool) + results = self._invoke_knowledge_index( + dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results ) @@ -96,8 +96,11 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): ) def _invoke_knowledge_index( - self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], - variable_pool: VariablePool + self, + dataset: Dataset, + node_data: KnowledgeIndexNodeData, + chunks: Mapping[str, Any], + variable_pool: VariablePool, ) -> Any: document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) if not document_id: @@ -116,7 +119,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): document.indexing_status = "completed" document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(document) - #update document segment status + # update document segment status db.session.query(DocumentSegment).filter( DocumentSegment.document_id == document.id, DocumentSegment.dataset_id == dataset.id, diff --git a/api/models/dataset.py b/api/models/dataset.py index 86216ffe98..d2fdd5e900 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -208,6 +208,7 @@ class Dataset(Base): "external_knowledge_api_name": external_knowledge_api.name, "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), } + @property def is_published(self): if self.pipeline_id: @@ -1177,7 +1178,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] __tablename__ = "pipeline_customized_templates" __table_args__ = ( diff --git a/api/models/oauth.py b/api/models/oauth.py index 9a070c2fbe..2fb34f0ac9 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -1,4 +1,3 @@ - from datetime import datetime from sqlalchemy.dialects.postgresql import JSONB @@ -21,6 +20,7 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined] provider: Mapped[str] = db.Column(db.String(255), nullable=False) system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) + class DatasourceProvider(Base): __tablename__ = "datasource_providers" __table_args__ = ( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 6d0f8ec6a9..133e3765f7 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1,4 +1,3 @@ -from calendar import day_abbr import copy import datetime import json @@ -7,7 +6,7 @@ import random import time import uuid from collections import 
Counter -from typing import Any, Optional, cast +from typing import Any, Optional from flask_login import current_user from sqlalchemy import func, select @@ -282,7 +281,6 @@ class DatasetService: db.session.commit() return dataset - @staticmethod def get_dataset(dataset_id) -> Optional[Dataset]: dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first() @@ -494,10 +492,9 @@ class DatasetService: return dataset @staticmethod - def update_rag_pipeline_dataset_settings(session: Session, - dataset: Dataset, - knowledge_configuration: KnowledgeConfiguration, - has_published: bool = False): + def update_rag_pipeline_dataset_settings( + session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False + ): dataset = session.merge(dataset) if not has_published: dataset.chunk_structure = knowledge_configuration.chunk_structure @@ -616,7 +613,6 @@ class DatasetService: if action: deal_dataset_index_update_task.delay(dataset.id, action) - @staticmethod def delete_dataset(dataset_id, user): dataset = DatasetService.get_dataset(dataset_id) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 09c4cca706..ccafc5555c 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from flask_login import current_user @@ -22,11 +21,9 @@ class DatasourceProviderService: def __init__(self) -> None: self.provider_manager = PluginDatasourceManager() - def datasource_provider_credentials_validate(self, - tenant_id: str, - provider: str, - plugin_id: str, - credentials: dict) -> None: + def datasource_provider_credentials_validate( + self, tenant_id: str, provider: str, plugin_id: str, credentials: dict + ) -> None: """ validate datasource provider credentials. 
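# The validate/store path above only encrypts the credential fields that the
# provider's credentials_schema marks as secret inputs; other fields are kept
# as-is. A small sketch of that filtering step, with a placeholder in place
# of encrypter.encrypt_token (the real helper is tenant-scoped encryption):
import base64


def encrypt_token(tenant_id: str, value: str) -> str:
    # placeholder only, to keep the sketch self-contained
    return base64.b64encode(f"{tenant_id}:{value}".encode()).decode()


def encrypt_secret_fields(tenant_id: str, credentials: dict[str, str], secret_fields: list[str]) -> dict[str, str]:
    return {
        key: encrypt_token(tenant_id, value) if key in secret_fields else value
        for key, value in credentials.items()
    }


if __name__ == "__main__":
    secret_fields = ["api_key"]  # would come from extract_secret_variables()
    credentials = {"api_key": "sk-xxxx", "workspace": "team-a"}
    print(encrypt_secret_fields("tenant-1", credentials, secret_fields))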
@@ -34,29 +31,30 @@ class DatasourceProviderService: :param provider: :param credentials: """ - credential_valid = self.provider_manager.validate_provider_credentials(tenant_id=tenant_id, - user_id=current_user.id, - provider=provider, - credentials=credentials) + credential_valid = self.provider_manager.validate_provider_credentials( + tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials + ) if credential_valid: # Get all provider configurations of the current workspace - datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id).first() + datasource_provider = ( + db.session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + .first() + ) - provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, - provider=provider - ) + provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider) if not datasource_provider: for key, value in credentials.items(): if key in provider_credential_secret_variables: # if send [__HIDDEN__] in secret input, it will be same as original value credentials[key] = encrypter.encrypt_token(tenant_id, value) - datasource_provider = DatasourceProvider(tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id, - auth_type="api_key", - encrypted_credentials=credentials) + datasource_provider = DatasourceProvider( + tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id, + auth_type="api_key", + encrypted_credentials=credentials, + ) db.session.add(datasource_provider) db.session.commit() else: @@ -101,11 +99,15 @@ class DatasourceProviderService: :return: """ # Get all provider configurations of the current workspace - datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter( - DatasourceProvider.tenant_id == tenant_id, - DatasourceProvider.provider == provider, - DatasourceProvider.plugin_id == plugin_id - ).all() + datasource_providers: list[DatasourceProvider] = ( + db.session.query(DatasourceProvider) + .filter( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + .all() + ) if not datasource_providers: return [] copy_credentials_list = [] @@ -128,10 +130,7 @@ class DatasourceProviderService: return copy_credentials_list - def remove_datasource_credentials(self, - tenant_id: str, - provider: str, - plugin_id: str) -> None: + def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None: """ remove datasource credentials. 
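# When stored credentials are read back for the console, the service above
# copies them and obfuscates the secret fields instead of returning raw
# values. The exact masking format below is illustrative; only the
# copy-then-mask shape mirrors the code in this patch.
def obfuscate_credentials(credentials: dict[str, str], secret_fields: list[str]) -> dict[str, str]:
    masked = dict(credentials)
    for key in secret_fields:
        value = masked.get(key)
        if not value:
            continue
        if len(value) > 4:
            masked[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
        else:
            masked[key] = "*" * len(value)
    return masked


if __name__ == "__main__":
    print(obfuscate_credentials({"api_key": "sk-1234567890", "region": "us"}, ["api_key"]))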
@@ -140,9 +139,11 @@ class DatasourceProviderService: :param plugin_id: plugin id :return: """ - datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id).first() + datasource_provider = ( + db.session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + .first() + ) if datasource_provider: db.session.delete(datasource_provider) db.session.commit() diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index 8da2e4aade..620fb2426a 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -107,6 +107,7 @@ class KnowledgeConfiguration(BaseModel): """ Knowledge Base Configuration. """ + chunk_structure: str indexing_technique: Literal["high_quality", "economy"] embedding_model_provider: Optional[str] = "" diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 911086066a..da67801877 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -3,7 +3,6 @@ from typing import Any, Union from configs import dify_config from core.app.apps.pipeline.pipeline_generator import PipelineGenerator -from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from models.dataset import Pipeline from models.model import Account, App, EndUser diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index b6670b70cd..3ede75309d 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,13 +1,12 @@ from typing import Optional -from flask_login import current_user import yaml +from flask_login import current_user from extensions.ext_database import db from models.dataset import PipelineCustomizedTemplate from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType -from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): @@ -43,7 +42,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): ) recommended_pipelines_results = [] for pipeline_customized_template in pipeline_customized_templates: - recommended_pipeline_result = { "id": pipeline_customized_template.id, "name": pipeline_customized_template.name, @@ -56,7 +54,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return {"pipeline_templates": recommended_pipelines_results} - @classmethod def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: """ diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 8019dac0a8..741384afc2 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ 
b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -38,7 +38,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): recommended_pipelines_results = [] for pipeline_built_in_template in pipeline_built_in_templates: - recommended_pipeline_result = { "id": pipeline_built_in_template.id, "name": pipeline_built_in_template.name, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 43451528db..07697c9851 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -35,7 +35,7 @@ from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account -from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore +from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.model import EndUser from models.workflow import ( @@ -57,9 +57,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi class RagPipelineService: @classmethod - def get_pipeline_templates( - cls, type: str = "built-in", language: str = "en-US" - ) -> dict: + def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() @@ -308,7 +306,7 @@ class RagPipelineService: session=session, dataset=dataset, knowledge_configuration=knowledge_configuration, - has_published=pipeline.is_published + has_published=pipeline.is_published, ) # return new workflow return workflow @@ -444,12 +442,10 @@ class RagPipelineService: ) if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: GetOnlineDocumentPagesResponse = ( - datasource_runtime._get_online_document_pages( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - ) + online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), ) return { "result": [page.model_dump() for page in online_document_result.result], @@ -470,7 +466,6 @@ class RagPipelineService: else: raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") - def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: @@ -689,8 +684,8 @@ class RagPipelineService: WorkflowRun.app_id == pipeline.id, or_( WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value, - WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value - ) + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value, + ), ) if args.get("last_id"): @@ -763,18 +758,17 @@ class RagPipelineService: # Use the repository to get the node execution repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, - app_id=pipeline.id, - user=user, - 
triggered_from=None + session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None ) # Use the repository to get the node executions with ordering order_config = OrderConfig(order_by=["index"], order_direction="desc") - node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, - order_config=order_config, - triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN) - # Convert domain models to database models + node_executions = repository.get_by_workflow_run( + workflow_run_id=run_id, + order_config=order_config, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, + ) + # Convert domain models to database models workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions] return workflow_node_executions diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 189ba0973f..2e1ed57908 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -279,7 +279,11 @@ class RagPipelineDslService: if node.get("data", {}).get("type") == "knowledge_index": knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) - if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure: + if ( + dataset + and pipeline.is_published + and dataset.chunk_structure != knowledge_configuration.chunk_structure + ): raise ValueError("Chunk structure is not compatible with the published pipeline") else: dataset = Dataset( @@ -304,8 +308,7 @@ class RagPipelineDslService: .filter( DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, - DatasetCollectionBinding.model_name - == knowledge_configuration.embedding_model, + DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, DatasetCollectionBinding.type == "dataset", ) .order_by(DatasetCollectionBinding.created_at) @@ -323,12 +326,8 @@ class RagPipelineDslService: db.session.commit() dataset_collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding_id - dataset.embedding_model = ( - knowledge_configuration.embedding_model - ) - dataset.embedding_model_provider = ( - knowledge_configuration.embedding_model_provider - ) + dataset.embedding_model = knowledge_configuration.embedding_model + dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider elif knowledge_configuration.indexing_technique == "economy": dataset.keyword_number = knowledge_configuration.keyword_number dataset.pipeline_id = pipeline.id @@ -443,8 +442,7 @@ class RagPipelineDslService: .filter( DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, - DatasetCollectionBinding.model_name - == knowledge_configuration.embedding_model, + DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, DatasetCollectionBinding.type == "dataset", ) .order_by(DatasetCollectionBinding.created_at) @@ -462,12 +460,8 @@ class RagPipelineDslService: db.session.commit() dataset_collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding_id - dataset.embedding_model = ( - knowledge_configuration.embedding_model - ) - dataset.embedding_model_provider = ( - 
knowledge_configuration.embedding_model_provider - ) + dataset.embedding_model = knowledge_configuration.embedding_model + dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider elif knowledge_configuration.indexing_technique == "economy": dataset.keyword_number = knowledge_configuration.keyword_number dataset.pipeline_id = pipeline.id @@ -538,7 +532,6 @@ class RagPipelineDslService: icon_type = "emoji" icon = str(pipeline_data.get("icon", "")) - # Initialize pipeline based on mode workflow_data = data.get("workflow") if not workflow_data or not isinstance(workflow_data, dict): @@ -554,7 +547,6 @@ class RagPipelineDslService: ] rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) - graph = workflow_data.get("graph", {}) for node in graph.get("nodes", []): if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: @@ -576,7 +568,6 @@ class RagPipelineDslService: pipeline.description = pipeline_data.get("description", pipeline.description) pipeline.updated_by = account.id - else: if account.current_tenant_id is None: raise ValueError("Current tenant is not set") @@ -636,7 +627,6 @@ class RagPipelineDslService: # commit db session changes db.session.commit() - return pipeline @classmethod @@ -874,7 +864,6 @@ class RagPipelineDslService: except Exception: return None - @staticmethod def create_rag_pipeline_dataset( tenant_id: str, @@ -886,9 +875,7 @@ class RagPipelineDslService: .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) .first() ): - raise ValueError( - f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." - ) + raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.") with Session(db.engine) as session: rag_pipeline_dsl_service = RagPipelineDslService(session) diff --git a/api/services/rag_pipeline/rag_pipeline_manage_service.py b/api/services/rag_pipeline/rag_pipeline_manage_service.py index df6085fafa..0908d30c12 100644 --- a/api/services/rag_pipeline/rag_pipeline_manage_service.py +++ b/api/services/rag_pipeline/rag_pipeline_manage_service.py @@ -12,12 +12,12 @@ class RagPipelineManageService: # get all builtin providers manager = PluginDatasourceManager() - datasources = manager.fetch_datasource_providers(tenant_id) + datasources = manager.fetch_datasource_providers(tenant_id) for datasource in datasources: datasource_provider_service = DatasourceProviderService() - credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id, - provider=datasource.provider, - plugin_id=datasource.plugin_id) + credentials = datasource_provider_service.get_datasource_credentials( + tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id + ) if credentials: datasource.is_authorized = True return datasources From abcca11479b6ea9f3c757d3697aaf53e6b22e1f9 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 3 Jun 2025 19:10:40 +0800 Subject: [PATCH 065/155] r2 --- api/core/app/apps/pipeline/pipeline_generator.py | 4 ++-- .../nodes/knowledge_index/knowledge_index_node.py | 2 +- api/services/rag_pipeline/rag_pipeline.py | 8 +++++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 5fb5bff2a9..82560175b7 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -29,8 +29,8 @@ from 
core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.index_processor.constant.built_in_field import BuiltInField from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, Pipeline diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index a1ee3aa823..1195df1b7f 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -12,7 +12,7 @@ from core.workflow.enums import SystemVariableKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -from models.workflow import WorkflowNodeExecutionStatus +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from ..base import BaseNode from .entities import KnowledgeIndexNodeData diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 07697c9851..9c4a054184 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -23,6 +23,10 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base.node import BaseNode @@ -30,7 +34,7 @@ from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING -from core.workflow.repository.workflow_node_execution_repository import OrderConfig +from core.workflow.repositories.workflow_node_execution_repository import OrderConfig from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -40,8 +44,6 @@ from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.model import EndUser from models.workflow import ( Workflow, - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, WorkflowType, From c09c8c6e5b0655fe7a2a636c6d407aa62a6e66fe Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 4 Jun 2025 15:12:05 +0800 Subject: [PATCH 066/155] r2 --- 
.../datasets/rag_pipeline/datasource_auth.py | 52 ++++++++++++++----- api/core/datasource/datasource_manager.py | 12 ++--- api/core/plugin/impl/datasource.py | 25 ++++++--- api/models/oauth.py | 4 +- api/services/datasource_provider_service.py | 18 ++++--- 5 files changed, 75 insertions(+), 36 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 96cb3f5602..c78b36c3b9 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -24,7 +24,13 @@ class DatasourcePluginOauthApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider, plugin_id): + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() + provider = args["provider"] + plugin_id = args["plugin_id"] # Check user role first if not current_user.is_editor: raise Forbidden() @@ -35,7 +41,7 @@ class DatasourcePluginOauthApi(Resource): if not plugin_oauth_config: raise NotFound() oauth_handler = OAuthHandler() - redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/provider/{provider}/plugin/{plugin_id}/callback" + redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}" system_credentials = plugin_oauth_config.system_credentials if system_credentials: system_credentials["redirect_url"] = redirect_url @@ -49,7 +55,13 @@ class DatasourceOauthCallback(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider, plugin_id): + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() + provider = args["provider"] + plugin_id = args["plugin_id"] oauth_handler = OAuthHandler() plugin_oauth_config = ( db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() @@ -76,11 +88,13 @@ class DatasourceAuth(Resource): @setup_required @login_required @account_initialization_required - def post(self, provider, plugin_id): + def post(self): if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() @@ -89,8 +103,8 @@ class DatasourceAuth(Resource): try: datasource_provider_service.datasource_provider_credentials_validate( tenant_id=current_user.current_tenant_id, - provider=provider, - plugin_id=plugin_id, + provider=args["provider"], + plugin_id=args["plugin_id"], credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: @@ -101,10 +115,16 @@ class DatasourceAuth(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider, plugin_id): + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, 
location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() datasource_provider_service = DatasourceProviderService() datasources = datasource_provider_service.get_datasource_credentials( - tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id + tenant_id=current_user.current_tenant_id, + provider=args["provider"], + plugin_id=args["plugin_id"] ) return {"result": datasources}, 200 @@ -113,12 +133,18 @@ class DatasourceAuthDeleteApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self, provider, plugin_id): + def delete(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() if not current_user.is_editor: raise Forbidden() datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_datasource_credentials( - tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id + tenant_id=current_user.current_tenant_id, + provider=args["provider"], + plugin_id=args["plugin_id"] ) return {"result": "success"}, 200 @@ -126,13 +152,13 @@ class DatasourceAuthDeleteApi(Resource): # Import Rag Pipeline api.add_resource( DatasourcePluginOauthApi, - "/oauth/datasource/provider//plugin/", + "/oauth/plugin/datasource", ) api.add_resource( DatasourceOauthCallback, - "/oauth/datasource/provider//plugin//callback", + "/oauth/plugin/datasource/callback", ) api.add_resource( DatasourceAuth, - "/auth/datasource/provider//plugin/", + "/auth/plugin/datasource", ) diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 46b36d8349..838fee5b96 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -24,7 +24,7 @@ class DatasourceManager: @classmethod def get_datasource_plugin_provider( - cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType + cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType ) -> DatasourcePluginProviderController: """ get the datasource plugin provider @@ -38,13 +38,13 @@ class DatasourceManager: with contexts.datasource_plugin_providers_lock.get(): datasource_plugin_providers = contexts.datasource_plugin_providers.get() - if provider in datasource_plugin_providers: - return datasource_plugin_providers[provider] + if provider_id in datasource_plugin_providers: + return datasource_plugin_providers[provider_id] manager = PluginDatasourceManager() - provider_entity = manager.fetch_datasource_provider(tenant_id, provider) + provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id) if not provider_entity: - raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found") + raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found") match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: @@ -71,7 +71,7 @@ class DatasourceManager: case _: raise ValueError(f"Unsupported datasource type: {datasource_type}") - datasource_plugin_providers[provider] = controller + datasource_plugin_providers[provider_id] = controller return controller diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index ea357d85b2..f469b51224 100644 --- a/api/core/plugin/impl/datasource.py +++ 
b/api/core/plugin/impl/datasource.py @@ -40,16 +40,25 @@ class PluginDatasourceManager(BasePluginClient): ) local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) - return [local_file_datasource_provider] + response + all_response = [local_file_datasource_provider] + response - def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: + for provider in all_response: + provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" + + # override the provider name for each tool to plugin_id/provider_name + for tool in provider.declaration.datasources: + tool.identity.provider = provider.declaration.identity.name + + return all_response + + def fetch_datasource_provider(self, tenant_id: str, provider_id: str) -> PluginDatasourceProviderEntity: """ Fetch datasource provider for the given tenant and plugin. """ - if provider == "langgenius/file/file": + if provider_id == "langgenius/file/file": return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) - tool_provider_id = ToolProviderID(provider) + tool_provider_id = ToolProviderID(provider_id) def transformer(json_response: dict[str, Any]) -> dict: data = json_response.get("data") @@ -225,13 +234,13 @@ class PluginDatasourceManager(BasePluginClient): def _get_local_file_datasource_provider(self) -> dict[str, Any]: return { "id": "langgenius/file/file", - "plugin_id": "langgenius/file/file", - "provider": "langgenius", + "plugin_id": "langgenius/file", + "provider": "file", "plugin_unique_identifier": "langgenius/file:0.0.1@dify", "declaration": { "identity": { "author": "langgenius", - "name": "langgenius/file/file", + "name": "file", "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, @@ -243,7 +252,7 @@ class PluginDatasourceManager(BasePluginClient): "identity": { "author": "langgenius", "name": "upload-file", - "provider": "langgenius", + "provider": "file", "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, }, "parameters": [], diff --git a/api/models/oauth.py b/api/models/oauth.py index 2fb34f0ac9..d823bcae16 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -25,12 +25,12 @@ class DatasourceProvider(Base): __tablename__ = "datasource_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"), - db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"), + db.UniqueConstraint("plugin_id", "provider", "auth_type", name="datasource_provider_auth_type_provider_idx"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) - plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) + plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False) auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) diff --git a/api/services/datasource_provider_service.py 
b/api/services/datasource_provider_service.py index ccafc5555c..71edec760f 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -38,11 +38,14 @@ class DatasourceProviderService: # Get all provider configurations of the current workspace datasource_provider = ( db.session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, auth_type="api_key") .first() ) - provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider) + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, + provider_id=f"{plugin_id}/{provider}" + ) if not datasource_provider: for key, value in credentials.items(): if key in provider_credential_secret_variables: @@ -73,14 +76,16 @@ class DatasourceProviderService: else: raise CredentialsValidateFailedError() - def extract_secret_variables(self, tenant_id: str, provider: str) -> list[str]: + def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]: """ Extract secret input form variables. :param credential_form_schemas: :return: """ - datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id, provider=provider) + datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id, + provider_id=provider_id + ) credential_form_schemas = datasource_provider.declaration.credentials_schema secret_input_form_variables = [] for credential_form_schema in credential_form_schemas: @@ -94,8 +99,7 @@ class DatasourceProviderService: get datasource credentials. :param tenant_id: workspace id - :param provider: provider name - :param plugin_id: plugin id + :param provider_id: provider id :return: """ # Get all provider configurations of the current workspace @@ -114,7 +118,7 @@ class DatasourceProviderService: for datasource_provider in datasource_providers: encrypted_credentials = datasource_provider.encrypted_credentials # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider) + credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=provider) # Obfuscate provider credentials copy_credentials = encrypted_credentials.copy() From b95ecaf8a3f54424b559dcb4e3f65254f79a6865 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Wed, 4 Jun 2025 15:17:39 +0800 Subject: [PATCH 067/155] Update build-push.yml --- .github/workflows/build-push.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index b5628fcbad..0442a121c0 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -6,8 +6,7 @@ on: - "main" - "deploy/dev" - "deploy/enterprise" - - "feat/r2" - - "feat/rag-pipeline" + - "deploy/rag-dev" tags: - "*" From 133193e7d06ed2a1ac0d3f306f37eb3f8c5bb510 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 4 Jun 2025 16:23:12 +0800 Subject: [PATCH 068/155] r2 --- .../rag_pipeline/rag_pipeline_workflow.py | 55 ++++++++++++++++ .../nodes/datasource/datasource_node.py | 2 +- .../workflow/nodes/datasource/entities.py | 2 +- api/services/rag_pipeline/rag_pipeline.py | 62 ++++++++++++++++++- 4 files changed, 118 insertions(+), 3 deletions(-) diff --git 
a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index d7ed5d475d..b66a747121 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -677,6 +677,53 @@ class PublishedRagPipelineSecondStepApi(Resource): "variables": variables, } +class PublishedRagPipelineFirstStepApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline): + """ + Get first step parameters of rag pipeline + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("node_id", type=str, required=True, location="args") + args = parser.parse_args() + node_id = args.get("node_id") + if not node_id: + raise ValueError("Node ID is required") + rag_pipeline_service = RagPipelineService() + variables = rag_pipeline_service.get_published_first_step_parameters(pipeline=pipeline, node_id=node_id) + return { + "variables": variables, + } + +class DraftRagPipelineFirstStepApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline): + """ + Get first step parameters of rag pipeline + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("node_id", type=str, required=True, location="args") + args = parser.parse_args() + node_id = args.get("node_id") + if not node_id: + raise ValueError("Node ID is required") + rag_pipeline_service = RagPipelineService() + variables = rag_pipeline_service.get_draft_first_step_parameters(pipeline=pipeline, node_id=node_id) + return { + "variables": variables, + } class DraftRagPipelineSecondStepApi(Resource): @setup_required @@ -862,7 +909,15 @@ api.add_resource( PublishedRagPipelineSecondStepApi, "/rag/pipelines//workflows/published/processing/parameters", ) +api.add_resource( + PublishedRagPipelineFirstStepApi, + "/rag/pipelines//workflows/published/pre-processing/parameters", +) api.add_resource( DraftRagPipelineSecondStepApi, "/rag/pipelines//workflows/draft/processing/parameters", ) +api.add_resource( + DraftRagPipelineFirstStepApi, + "/rag/pipelines//workflows/draft/pre-processing/parameters", +) diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 92b2daea54..8c76fc161d 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -59,7 +59,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): raise DatasourceNodeError("Datasource type is not set") datasource_runtime = DatasourceManager.get_datasource_runtime( - provider_id=node_data.provider_id, + provider_id=f"{node_data.plugin_id}/{node_data.provider_name}", datasource_name=node_data.datasource_name or "", tenant_id=self.tenant_id, datasource_type=DatasourceProviderType.value_of(datasource_type), diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index dee3c1d2fb..b182928baa 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -7,7 +7,7 @@ from 
core.workflow.nodes.base.entities import BaseNodeData class DatasourceEntity(BaseModel): - provider_id: str + plugin_id: str provider_name: str # redundancy provider_type: str datasource_name: Optional[str] = "local_file" diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 9c4a054184..80b961851a 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1,4 +1,5 @@ import json +import re import threading import time from collections.abc import Callable, Generator, Sequence @@ -434,14 +435,19 @@ class RagPipelineService: datasource_node_data = published_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {}) if not datasource_node_data: raise ValueError("Datasource node data not found") + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + for key, value in datasource_parameters.items(): + if not user_inputs.get(key): + user_inputs[key] = value["value"] from core.datasource.datasource_manager import DatasourceManager datasource_runtime = DatasourceManager.get_datasource_runtime( - provider_id=datasource_node_data.get("provider_id"), + provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", datasource_name=datasource_node_data.get("datasource_name"), tenant_id=pipeline.tenant_id, datasource_type=DatasourceProviderType(datasource_type), ) + if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( @@ -648,6 +654,60 @@ class RagPipelineService: if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" ] return datasource_provider_variables + + def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: + """ + Get first step parameters of rag pipeline + """ + + published_workflow = self.get_published_workflow(pipeline=pipeline) + if not published_workflow: + raise ValueError("Workflow not initialized") + + # get second step node + datasource_node_data = published_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {}) + if not datasource_node_data: + raise ValueError("Datasource node data not found") + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + if datasource_parameters: + datasource_parameters_map = { + item["variable"]: item for item in datasource_parameters + } + else: + datasource_parameters_map = {} + variables = datasource_node_data.get("variables", {}) + user_input_variables = [] + for key, value in variables.items(): + if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): + user_input_variables.append(datasource_parameters_map.get(key, {})) + return user_input_variables + + def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: + """ + Get first step parameters of rag pipeline + """ + + draft_workflow = self.get_draft_workflow(pipeline=pipeline) + if not draft_workflow: + raise ValueError("Workflow not initialized") + + # get second step node + datasource_node_data = draft_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {}) + if not datasource_node_data: + raise ValueError("Datasource node data not found") + datasource_parameters = 
datasource_node_data.get("datasource_parameters", {}) + if datasource_parameters: + datasource_parameters_map = { + item["variable"]: item for item in datasource_parameters + } + else: + datasource_parameters_map = {} + variables = datasource_node_data.get("variables", {}) + user_input_variables = [] + for key, value in variables.items(): + if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): + user_input_variables.append(datasource_parameters_map.get(key, {})) + return user_input_variables def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: """ From a82ab1d152d44ae6d14e34de24ab2f53d7b0486f Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 4 Jun 2025 16:51:23 +0800 Subject: [PATCH 069/155] r2 --- api/controllers/console/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index c55d3fbb66..f17c28dcd4 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -84,6 +84,7 @@ from .datasets import ( website, ) from .datasets.rag_pipeline import ( + datasource_auth, rag_pipeline, rag_pipeline_datasets, rag_pipeline_import, From 8a147a00e8f5ac4dfdf623acada2e136629f239c Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 4 Jun 2025 17:29:39 +0800 Subject: [PATCH 070/155] r2 --- .../datasets/rag_pipeline/datasource_auth.py | 43 ++++++++-- api/models/oauth.py | 2 +- api/services/datasource_provider_service.py | 83 ++++++++++++------- 3 files changed, 91 insertions(+), 37 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index c78b36c3b9..bc91343c71 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -122,18 +122,18 @@ class DatasourceAuth(Resource): args = parser.parse_args() datasource_provider_service = DatasourceProviderService() datasources = datasource_provider_service.get_datasource_credentials( - tenant_id=current_user.current_tenant_id, - provider=args["provider"], + tenant_id=current_user.current_tenant_id, + provider=args["provider"], plugin_id=args["plugin_id"] ) return {"result": datasources}, 200 -class DatasourceAuthDeleteApi(Resource): +class DatasourceAuthUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self): + def delete(self, auth_id: str): parser = reqparse.RequestParser() parser.add_argument("provider", type=str, required=True, nullable=False, location="args") parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") @@ -142,12 +142,38 @@ class DatasourceAuthDeleteApi(Resource): raise Forbidden() datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_datasource_credentials( - tenant_id=current_user.current_tenant_id, - provider=args["provider"], + tenant_id=current_user.current_tenant_id, + auth_id=auth_id, + provider=args["provider"], plugin_id=args["plugin_id"] ) return {"result": "success"}, 200 + @setup_required + @login_required + @account_initialization_required + def patch(self, auth_id: str): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, 
location="args") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + if not current_user.is_editor: + raise Forbidden() + try: + datasource_provider_service = DatasourceProviderService() + datasource_provider_service.update_datasource_credentials( + tenant_id=current_user.current_tenant_id, + auth_id=auth_id, + provider=args["provider"], + plugin_id=args["plugin_id"], + credentials=args["credentials"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"}, 201 + # Import Rag Pipeline api.add_resource( @@ -162,3 +188,8 @@ api.add_resource( DatasourceAuth, "/auth/plugin/datasource", ) + +api.add_resource( + DatasourceAuth, + "/auth/plugin/datasource/", +) diff --git a/api/models/oauth.py b/api/models/oauth.py index d823bcae16..938a309069 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -25,7 +25,7 @@ class DatasourceProvider(Base): __tablename__ = "datasource_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"), - db.UniqueConstraint("plugin_id", "provider", "auth_type", name="datasource_provider_auth_type_provider_idx"), + db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_auth_type_provider_idx"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 71edec760f..1344dfa9fe 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -38,7 +38,7 @@ class DatasourceProviderService: # Get all provider configurations of the current workspace datasource_provider = ( db.session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, auth_type="api_key") + .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider, auth_type="api_key") .first() ) @@ -46,33 +46,19 @@ class DatasourceProviderService: tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" ) - if not datasource_provider: - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - credentials[key] = encrypter.encrypt_token(tenant_id, value) - datasource_provider = DatasourceProvider( - tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id, - auth_type="api_key", - encrypted_credentials=credentials, - ) - db.session.add(datasource_provider) - db.session.commit() - else: - original_credentials = datasource_provider.encrypted_credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - original_value = encrypter.encrypt_token(tenant_id, original_credentials[key]) - credentials[key] = encrypter.encrypt_token(tenant_id, original_value) - else: - credentials[key] = encrypter.encrypt_token(tenant_id, value) - - datasource_provider.encrypted_credentials = credentials - db.session.commit() + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + credentials[key] = encrypter.encrypt_token(tenant_id, value) + datasource_provider = DatasourceProvider( + tenant_id=tenant_id, + 
provider=provider, + plugin_id=plugin_id, + auth_type="api_key", + encrypted_credentials=credentials, + ) + db.session.add(datasource_provider) + db.session.commit() else: raise CredentialsValidateFailedError() @@ -133,8 +119,45 @@ class DatasourceProviderService: ) return copy_credentials_list + + def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None: + """ + update datasource credentials. + """ + credential_valid = self.provider_manager.validate_provider_credentials( + tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials + ) + if credential_valid: + # Get all provider configurations of the current workspace + datasource_provider = ( + db.session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) + .first() + ) - def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None: + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, + provider_id=f"{plugin_id}/{provider}" + ) + if not datasource_provider: + raise ValueError("Datasource provider not found") + else: + original_credentials = datasource_provider.encrypted_credentials + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + if value == HIDDEN_VALUE and key in original_credentials: + original_value = encrypter.encrypt_token(tenant_id, original_credentials[key]) + credentials[key] = encrypter.encrypt_token(tenant_id, original_value) + else: + credentials[key] = encrypter.encrypt_token(tenant_id, value) + + datasource_provider.encrypted_credentials = credentials + db.session.commit() + else: + raise CredentialsValidateFailedError() + + def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None: """ remove datasource credentials. 
@@ -145,7 +168,7 @@ class DatasourceProviderService: """ datasource_provider = ( db.session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) .first() ) if datasource_provider: From fbca9010f3ea0c81eef25efcc65a0973e3e36f52 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 4 Jun 2025 17:39:31 +0800 Subject: [PATCH 071/155] r2 --- .../console/datasets/rag_pipeline/datasource_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index bc91343c71..912981db01 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -190,6 +190,6 @@ api.add_resource( ) api.add_resource( - DatasourceAuth, + DatasourceAuthUpdateDeleteApi, "/auth/plugin/datasource/", ) From 5fe5da7c1d12e13107f971cb4c06036a46a08dd2 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 5 Jun 2025 11:12:06 +0800 Subject: [PATCH 072/155] r2 --- api/core/plugin/impl/datasource.py | 8 ++++---- api/services/datasource_provider_service.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index f469b51224..775f36b506 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -202,12 +202,12 @@ class PluginDatasourceManager(BasePluginClient): raise Exception("No response from plugin daemon") def validate_provider_credentials( - self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] + self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any] ) -> bool: """ validate the credentials of the provider """ - datasource_provider_id = GenericProviderID(provider) + # datasource_provider_id = GenericProviderID(provider_id) response = self._request_with_plugin_daemon_response_stream( "POST", @@ -216,12 +216,12 @@ class PluginDatasourceManager(BasePluginClient): data={ "user_id": user_id, "data": { - "provider": datasource_provider_id.provider_name, + "provider": provider, "credentials": credentials, }, }, headers={ - "X-Plugin-ID": datasource_provider_id.plugin_id, + "X-Plugin-ID": plugin_id, "Content-Type": "application/json", }, ) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 1344dfa9fe..ffc056921b 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -32,7 +32,7 @@ class DatasourceProviderService: :param credentials: """ credential_valid = self.provider_manager.validate_provider_credentials( - tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials + tenant_id=tenant_id, user_id=current_user.id, provider=provider, plugin_id=plugin_id, credentials=credentials ) if credential_valid: # Get all provider configurations of the current workspace @@ -119,13 +119,13 @@ class DatasourceProviderService: ) return copy_credentials_list - + def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None: """ update datasource credentials. 
""" credential_valid = self.provider_manager.validate_provider_credentials( - tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials + tenant_id=tenant_id, user_id=current_user.id, provider=provider,plugin_id=plugin_id, credentials=credentials ) if credential_valid: # Get all provider configurations of the current workspace @@ -156,7 +156,7 @@ class DatasourceProviderService: db.session.commit() else: raise CredentialsValidateFailedError() - + def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None: """ remove datasource credentials. From 3e0a10b7ed0b86480d3e9842b74f863c9e7e30e5 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 5 Jun 2025 11:45:53 +0800 Subject: [PATCH 073/155] r2 --- api/services/rag_pipeline/rag_pipeline.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 80b961851a..b6de92f201 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -430,11 +430,16 @@ class RagPipelineService: raise ValueError("Workflow not initialized") # run draft workflow node + datasource_node_data = None start_at = time.perf_counter() - - datasource_node_data = published_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {}) + datasource_nodes = published_workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break if not datasource_node_data: raise ValueError("Datasource node data not found") + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) for key, value in datasource_parameters.items(): if not user_inputs.get(key): @@ -665,7 +670,12 @@ class RagPipelineService: raise ValueError("Workflow not initialized") # get second step node - datasource_node_data = published_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {}) + datasource_node_data = None + datasource_nodes = published_workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break if not datasource_node_data: raise ValueError("Datasource node data not found") datasource_parameters = datasource_node_data.get("datasource_parameters", {}) @@ -692,7 +702,12 @@ class RagPipelineService: raise ValueError("Workflow not initialized") # get second step node - datasource_node_data = draft_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {}) + datasource_node_data = None + datasource_nodes = draft_workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break if not datasource_node_data: raise ValueError("Datasource node data not found") datasource_parameters = datasource_node_data.get("datasource_parameters", {}) From 8a86a2c81727e3c499fbbd5d6287c8e20c19abab Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 5 Jun 2025 14:09:50 +0800 Subject: [PATCH 074/155] r2 --- .../app/apps/pipeline/pipeline_generator.py | 6 ++-- api/core/entities/knowledge_entities.py | 1 + ...5_1356-d466c551816f_add_pipeline_info_5.py | 35 +++++++++++++++++++ api/models/oauth.py | 2 +- api/services/datasource_provider_service.py | 2 
+- 5 files changed, 41 insertions(+), 5 deletions(-) create mode 100644 api/migrations/versions/2025_06_05_1356-d466c551816f_add_pipeline_info_5.py diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 82560175b7..54491609bb 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -157,7 +157,7 @@ class PipelineGenerator(BaseAppGenerator): stream=streaming, invoke_from=invoke_from, call_depth=call_depth, - workflow_run_id=workflow_run_id, + workflow_execution_id=workflow_run_id, ) contexts.plugin_tool_providers.set({}) @@ -379,7 +379,7 @@ class PipelineGenerator(BaseAppGenerator): stream=streaming, invoke_from=InvokeFrom.DEBUGGER, call_depth=0, - workflow_run_id=str(uuid.uuid4()), + workflow_execution_id=str(uuid.uuid4()), ) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -461,7 +461,7 @@ class PipelineGenerator(BaseAppGenerator): invoke_from=InvokeFrom.DEBUGGER, extras={"auto_generate_conversation_name": False}, single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), - workflow_run_id=str(uuid.uuid4()), + workflow_execution_id=str(uuid.uuid4()), ) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index 3beea56e15..63fce06005 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -29,6 +29,7 @@ class PipelineDataset(BaseModel): class PipelineDocument(BaseModel): id: str position: int + data_source_type: str data_source_info: Optional[dict] = None name: str indexing_status: str diff --git a/api/migrations/versions/2025_06_05_1356-d466c551816f_add_pipeline_info_5.py b/api/migrations/versions/2025_06_05_1356-d466c551816f_add_pipeline_info_5.py new file mode 100644 index 0000000000..56860d1f80 --- /dev/null +++ b/api/migrations/versions/2025_06_05_1356-d466c551816f_add_pipeline_info_5.py @@ -0,0 +1,35 @@ +"""add_pipeline_info_5 + +Revision ID: d466c551816f +Revises: e4fb49a4fe86 +Create Date: 2025-06-05 13:56:05.962215 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd466c551816f' +down_revision = 'e4fb49a4fe86' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('datasource_provider_plugin_id_provider_idx'), type_='unique') + batch_op.create_index('datasource_provider_auth_type_provider_idx', ['tenant_id', 'plugin_id', 'provider'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.drop_index('datasource_provider_auth_type_provider_idx') + batch_op.create_unique_constraint(batch_op.f('datasource_provider_plugin_id_provider_idx'), ['plugin_id', 'provider']) + + # ### end Alembic commands ### diff --git a/api/models/oauth.py b/api/models/oauth.py index 938a309069..b1b09e5d45 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -25,7 +25,7 @@ class DatasourceProvider(Base): __tablename__ = "datasource_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"), - db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_auth_type_provider_idx"), + db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index ffc056921b..5bede09a64 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -104,7 +104,7 @@ class DatasourceProviderService: for datasource_provider in datasource_providers: encrypted_credentials = datasource_provider.encrypted_credentials # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=provider) + credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}") # Obfuscate provider credentials copy_credentials = encrypted_credentials.copy() From 9e84a5321df3239886d0c1eae311bbb597677535 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 5 Jun 2025 14:55:09 +0800 Subject: [PATCH 075/155] r2 --- api/core/app/apps/pipeline/pipeline_generator.py | 2 +- api/core/app/apps/pipeline/pipeline_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 54491609bb..a2123fdc49 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -403,7 +403,7 @@ class PipelineGenerator(BaseAppGenerator): return self._generate( flask_app=current_app._get_current_object(), # type: ignore pipeline=pipeline, - workflow=workflow, + workflow_id=workflow.id, user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 4582dcbb0d..50dc8d8fad 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -99,7 +99,7 @@ class PipelineRunner(WorkflowBasedAppRunner): SystemVariableKey.USER_ID: user_id, SystemVariableKey.APP_ID: app_config.app_id, SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, + SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id, SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id, SystemVariableKey.BATCH: self.application_generate_entity.batch, SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id, From c084b57933da1fb9ab19639053d4f1183d183f06 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> 
Date: Thu, 5 Jun 2025 15:28:44 +0800 Subject: [PATCH 076/155] r2 --- api/services/rag_pipeline/rag_pipeline.py | 81 ++++++++++------------- 1 file changed, 36 insertions(+), 45 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index b6de92f201..5541349edb 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -373,9 +373,6 @@ class RagPipelineService: tenant_id=pipeline.tenant_id, node_id=node_id, ) - - workflow_node_execution.app_id = pipeline.id - workflow_node_execution.created_by = account.id workflow_node_execution.workflow_id = draft_workflow.id db.session.add(workflow_node_execution) @@ -409,8 +406,6 @@ class RagPipelineService: node_id=node_id, ) - workflow_node_execution.app_id = pipeline.id - workflow_node_execution.created_by = account.id workflow_node_execution.workflow_id = published_workflow.id db.session.add(workflow_node_execution) @@ -568,18 +563,17 @@ class RagPipelineService: node_run_result = None error = e.error - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.id = str(uuid4()) - workflow_node_execution.tenant_id = tenant_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value - workflow_node_execution.index = 1 - workflow_node_execution.node_id = node_id - workflow_node_execution.node_type = node_instance.node_type - workflow_node_execution.title = node_instance.node_data.title - workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value - workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) + workflow_node_execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=node_instance.workflow_id, + index=1, + node_id=node_id, + node_type=node_instance.node_type, + title=node_instance.node_data.title, + elapsed_time=time.perf_counter() - start_at, + finished_at=datetime.now(UTC).replace(tzinfo=None), + created_at=datetime.now(UTC).replace(tzinfo=None), + ) if run_succeeded and node_run_result: # create workflow node execution inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None @@ -590,20 +584,18 @@ class RagPipelineService: ) outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None - workflow_node_execution.inputs = json.dumps(inputs) - workflow_node_execution.process_data = json.dumps(process_data) - workflow_node_execution.outputs = json.dumps(outputs) - workflow_node_execution.execution_metadata = ( - json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None - ) + workflow_node_execution.inputs = inputs + workflow_node_execution.process_data = process_data + workflow_node_execution.outputs = outputs + workflow_node_execution.metadata = node_run_result.metadata if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: - workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value + workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION workflow_node_execution.error = node_run_result.error else: # create 
workflow node execution - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED workflow_node_execution.error = error return workflow_node_execution @@ -678,18 +670,18 @@ class RagPipelineService: break if not datasource_node_data: raise ValueError("Datasource node data not found") - datasource_parameters = datasource_node_data.get("datasource_parameters", {}) - if datasource_parameters: - datasource_parameters_map = { - item["variable"]: item for item in datasource_parameters + variables = datasource_node_data.get("variables", {}) + if variables: + variables_map = { + item["variable"]: item for item in variables } else: - datasource_parameters_map = {} - variables = datasource_node_data.get("variables", {}) + variables_map = {} + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) user_input_variables = [] - for key, value in variables.items(): + for key, value in datasource_parameters.items(): if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): - user_input_variables.append(datasource_parameters_map.get(key, {})) + user_input_variables.append(variables_map.get(key, {})) return user_input_variables def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: @@ -710,18 +702,19 @@ class RagPipelineService: break if not datasource_node_data: raise ValueError("Datasource node data not found") - datasource_parameters = datasource_node_data.get("datasource_parameters", {}) - if datasource_parameters: - datasource_parameters_map = { - item["variable"]: item for item in datasource_parameters + variables = datasource_node_data.get("variables", {}) + if variables: + variables_map = { + item["variable"]: item for item in variables } else: - datasource_parameters_map = {} - variables = datasource_node_data.get("variables", {}) + variables = {} + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + user_input_variables = [] - for key, value in variables.items(): + for key, value in datasource_parameters.items(): if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): - user_input_variables.append(datasource_parameters_map.get(key, {})) + user_input_variables.append(variables_map.get(key, {})) return user_input_variables def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: @@ -845,10 +838,8 @@ class RagPipelineService: order_config=order_config, triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, ) - # Convert domain models to database models - workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions] - return workflow_node_executions + return list(node_executions) @classmethod def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict): From b8ef1d958524941151a1c50679c3b26f1c26eb40 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 5 Jun 2025 16:43:47 +0800 Subject: [PATCH 077/155] r2 --- .../console/datasets/rag_pipeline/rag_pipeline_workflow.py | 1 + api/services/rag_pipeline/rag_pipeline.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index b66a747121..7710a6770b 100644 --- 
a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -39,6 +39,7 @@ from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.account import Account from models.dataset import Pipeline +from models.model import EndUser from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 5541349edb..79925a1c1b 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -676,7 +676,7 @@ class RagPipelineService: item["variable"]: item for item in variables } else: - variables_map = {} + return [] datasource_parameters = datasource_node_data.get("datasource_parameters", {}) user_input_variables = [] for key, value in datasource_parameters.items(): @@ -708,7 +708,7 @@ class RagPipelineService: item["variable"]: item for item in variables } else: - variables = {} + return [] datasource_parameters = datasource_node_data.get("datasource_parameters", {}) user_input_variables = [] From 70432952fd2513851d7909b7837b5128cffbe6a1 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 6 Jun 2025 10:40:06 +0800 Subject: [PATCH 078/155] r2 --- .../rag_pipeline/rag_pipeline_workflow.py | 51 +++++++++++++++++-- api/services/rag_pipeline/rag_pipeline.py | 17 ++++--- 2 files changed, 58 insertions(+), 10 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 7710a6770b..fef10b79a7 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -301,7 +301,7 @@ class PublishedRagPipelineRunApi(Resource): raise InvokeRateLimitHttpError(ex.description) -class RagPipelineDatasourceNodeRunApi(Resource): +class RagPipelinePublishedDatasourceNodeRunApi(Resource): @setup_required @login_required @account_initialization_required @@ -336,10 +336,50 @@ class RagPipelineDatasourceNodeRunApi(Resource): user_inputs=inputs, account=current_user, datasource_type=datasource_type, + is_published=True ) return result +class RagPipelineDrafDatasourceNodeRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run rag pipeline datasource + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + if inputs == None: + raise ValueError("missing inputs") + datasource_type = args.get("datasource_type") + if datasource_type == None: + raise ValueError("missing datasource_type") + + rag_pipeline_service = RagPipelineService() + result = rag_pipeline_service.run_datasource_workflow_node( + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + 
account=current_user, + datasource_type=datasource_type, + is_published=False + ) + + return result class RagPipelinePublishedNodeRunApi(Resource): @setup_required @@ -851,8 +891,13 @@ api.add_resource( "/rag/pipelines//workflows/draft/nodes//run", ) api.add_resource( - RagPipelineDatasourceNodeRunApi, - "/rag/pipelines//workflows/datasource/nodes//run", + RagPipelinePublishedDatasourceNodeRunApi, + "/rag/pipelines//workflows/published/datasource/nodes//run", +) + +api.add_resource( + RagPipelineDrafDatasourceNodeRunApi, + "/rag/pipelines//workflows/draft/datasource/nodes//run", ) api.add_resource( diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 79925a1c1b..d899e89b02 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -414,27 +414,30 @@ class RagPipelineService: return workflow_node_execution def run_datasource_workflow_node( - self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str + self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool ) -> dict: """ Run published workflow datasource """ + if is_published: # fetch published workflow by app_model - published_workflow = self.get_published_workflow(pipeline=pipeline) - if not published_workflow: + workflow = self.get_published_workflow(pipeline=pipeline) + else: + workflow = self.get_draft_workflow(pipeline=pipeline) + if not workflow: raise ValueError("Workflow not initialized") # run draft workflow node datasource_node_data = None start_at = time.perf_counter() - datasource_nodes = published_workflow.graph_dict.get("nodes", []) + datasource_nodes = workflow.graph_dict.get("nodes", []) for datasource_node in datasource_nodes: if datasource_node.get("id") == node_id: datasource_node_data = datasource_node.get("data", {}) break if not datasource_node_data: raise ValueError("Datasource node data not found") - + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) for key, value in datasource_parameters.items(): if not user_inputs.get(key): @@ -651,7 +654,7 @@ class RagPipelineService: if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" ] return datasource_provider_variables - + def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: """ Get first step parameters of rag pipeline @@ -683,7 +686,7 @@ class RagPipelineService: if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): user_input_variables.append(variables_map.get(key, {})) return user_input_variables - + def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: """ Get first step parameters of rag pipeline From 0ff746ebf6f1fc01f20002a863babfac5dd4889d Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 6 Jun 2025 12:08:09 +0800 Subject: [PATCH 079/155] r2 --- api/services/rag_pipeline/rag_pipeline.py | 50 +++++++++++------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index d899e89b02..ac7df87586 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -450,32 +450,32 @@ class RagPipelineService: tenant_id=pipeline.tenant_id, datasource_type=DatasourceProviderType(datasource_type), ) + match datasource_type: + case 
DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), + ) + return { + "result": [page.model_dump() for page in online_document_result.result], + "provider_type": datasource_node_data.get("provider_type"), + } - if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: - datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - ) - return { - "result": [page.model_dump() for page in online_document_result.result], - "provider_type": datasource_node_data.get("provider_type"), - } - - elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL: - datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - ) - return { - "result": [result.model_dump() for result in website_crawl_result.result], - "provider_type": datasource_node_data.get("provider_type"), - } - else: - raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + case DatasourceProviderType.WEBSITE_CRAWL: + datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) + website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), + ) + return { + "result": [result.model_dump() for result in website_crawl_result.result], + "provider_type": datasource_node_data.get("provider_type"), + } + case _: + raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] From d2750f1a028a92732119d0e10fab8869c35cc31c Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 6 Jun 2025 14:22:00 +0800 Subject: [PATCH 080/155] r2 --- .../website_crawl/website_crawl_provider.py | 2 + api/services/datasource_provider_service.py | 41 +++++++++++++ api/services/rag_pipeline/rag_pipeline.py | 61 +++++++++++-------- 3 files changed, 79 insertions(+), 25 deletions(-) diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index 8c0f20ce2d..a65efb750e 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -1,3 +1,4 @@ +from core.datasource.__base import datasource_provider from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType @@ -43,6 +44,7 @@ class 
WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon if not datasource_entity: raise ValueError(f"Datasource with name {datasource_name} not found") + return WebsiteCrawlDatasourcePlugin( entity=datasource_entity, runtime=DatasourceRuntime(tenant_id=self.tenant_id), diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 5bede09a64..64fa97197d 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -120,6 +120,47 @@ class DatasourceProviderService: return copy_credentials_list + def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]: + """ + get datasource credentials. + + :param tenant_id: workspace id + :param provider_id: provider id + :return: + """ + # Get all provider configurations of the current workspace + datasource_providers: list[DatasourceProvider] = ( + db.session.query(DatasourceProvider) + .filter( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + .all() + ) + if not datasource_providers: + return [] + copy_credentials_list = [] + for datasource_provider in datasource_providers: + encrypted_credentials = datasource_provider.encrypted_credentials + # Get provider credential secret variables + credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}") + + # Obfuscate provider credentials + copy_credentials = encrypted_credentials.copy() + for key, value in copy_credentials.items(): + if key in credential_secret_variables: + copy_credentials[key] = encrypter.decrypt_token(tenant_id, value) + copy_credentials_list.append( + { + "credentials": copy_credentials, + "type": datasource_provider.auth_type, + } + ) + + return copy_credentials_list + + def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None: """ update datasource credentials. 
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index d899e89b02..cb42224c60 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -43,6 +43,7 @@ from models.account import Account from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.model import EndUser +from models.oauth import DatasourceProvider from models.workflow import ( Workflow, WorkflowNodeExecutionTriggeredFrom, @@ -50,6 +51,7 @@ from models.workflow import ( WorkflowType, ) from services.dataset_service import DatasetService +from services.datasource_provider_service import DatasourceProviderService from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, PipelineTemplateInfoEntity, @@ -442,6 +444,7 @@ class RagPipelineService: for key, value in datasource_parameters.items(): if not user_inputs.get(key): user_inputs[key] = value["value"] + from core.datasource.datasource_manager import DatasourceManager datasource_runtime = DatasourceManager.get_datasource_runtime( @@ -450,32 +453,40 @@ class RagPipelineService: tenant_id=pipeline.tenant_id, datasource_type=DatasourceProviderType(datasource_type), ) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=pipeline.tenant_id, + provider=datasource_node_data.get('provider_name'), + plugin_id=datasource_node_data.get('plugin_id'), + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), + ) + return { + "result": [page.model_dump() for page in online_document_result.result], + "provider_type": datasource_node_data.get("provider_type"), + } - if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: - datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - ) - return { - "result": [page.model_dump() for page in online_document_result.result], - "provider_type": datasource_node_data.get("provider_type"), - } - - elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL: - datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - ) - return { - "result": [result.model_dump() for result in website_crawl_result.result], - "provider_type": datasource_node_data.get("provider_type"), - } - else: - raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + case DatasourceProviderType.WEBSITE_CRAWL: + datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, 
datasource_runtime) + website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), + ) + return { + "result": [result.model_dump() for result in website_crawl_result.result], + "provider_type": datasource_node_data.get("provider_type"), + } + case _: + raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] From a82a9fb9d485222f81225270554c5904a7620701 Mon Sep 17 00:00:00 2001 From: twwu Date: Fri, 6 Jun 2025 11:43:00 +0800 Subject: [PATCH 081/155] fix: update condition for handling datasource selection in DataSourceOptions --- .../create-from-pipeline/data-source-options/index.tsx | 2 +- .../components/panel/test-run/data-source-options/index.tsx | 2 +- .../panel/test-run/data-source/local-file/file-uploader.tsx | 2 +- .../panel/test-run/data-source/website-crawl/base/crawler.tsx | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source-options/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source-options/index.tsx index b56e393fca..d0a410f5e0 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source-options/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source-options/index.tsx @@ -26,7 +26,7 @@ const DataSourceOptions = ({ }, [datasources, onSelect]) useEffect(() => { - if (options.length > 0) + if (options.length > 0 && !datasourceNodeId) handelSelect(options[0].value) // eslint-disable-next-line react-hooks/exhaustive-deps }, []) diff --git a/web/app/components/rag-pipeline/components/panel/test-run/data-source-options/index.tsx b/web/app/components/rag-pipeline/components/panel/test-run/data-source-options/index.tsx index afbba09594..f9a0b4c9b8 100644 --- a/web/app/components/rag-pipeline/components/panel/test-run/data-source-options/index.tsx +++ b/web/app/components/rag-pipeline/components/panel/test-run/data-source-options/index.tsx @@ -22,7 +22,7 @@ const DataSourceOptions = ({ }, [datasources, onSelect]) useEffect(() => { - if (options.length > 0) + if (options.length > 0 && !dataSourceNodeId) handelSelect(options[0].value) // eslint-disable-next-line react-hooks/exhaustive-deps }, []) diff --git a/web/app/components/rag-pipeline/components/panel/test-run/data-source/local-file/file-uploader.tsx b/web/app/components/rag-pipeline/components/panel/test-run/data-source/local-file/file-uploader.tsx index a8568266c4..125ba15983 100644 --- a/web/app/components/rag-pipeline/components/panel/test-run/data-source/local-file/file-uploader.tsx +++ b/web/app/components/rag-pipeline/components/panel/test-run/data-source/local-file/file-uploader.tsx @@ -242,7 +242,7 @@ const FileUploader = ({ }, [handleDrop]) return ( -
+
{!hideUpload && ( +
{!isInit && ( -
+
{isRunning && ( Date: Fri, 6 Jun 2025 14:21:23 +0800 Subject: [PATCH 082/155] confirm publish --- .../rag-pipeline-header/publisher/popup.tsx | 73 ++++++++++++++----- .../workflow/nodes/data-source/constants.ts | 22 +++++- .../workflow/nodes/data-source/default.ts | 19 +++-- .../workflow/nodes/data-source/panel.tsx | 23 ++++-- .../workflow/nodes/data-source/types.ts | 5 +- web/i18n/en-US/pipeline.ts | 2 + web/i18n/zh-Hans/pipeline.ts | 2 + 7 files changed, 113 insertions(+), 33 deletions(-) diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx index 85a1cf9266..3349192087 100644 --- a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx +++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx @@ -9,7 +9,10 @@ import { RiPlayCircleLine, RiTerminalBoxLine, } from '@remixicon/react' -import { useKeyPress } from 'ahooks' +import { + useBoolean, + useKeyPress, +} from 'ahooks' import { useTranslation } from 'react-i18next' import { useStore, @@ -29,6 +32,7 @@ import { useParams, useRouter } from 'next/navigation' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useInvalid } from '@/service/use-base' import { publishedPipelineInfoQueryKeyPrefix } from '@/service/use-pipeline' +import Confirm from '@/app/components/base/confirm' const PUBLISH_SHORTCUT = ['⌘', '⇧', 'P'] @@ -46,29 +50,52 @@ const Popup = () => { const { mutateAsync: publishWorkflow } = usePublishWorkflow() const { notify } = useToastContext() const workflowStore = useWorkflowStore() + const [confirmVisible, { + setFalse: hideConfirm, + setTrue: showConfirm, + }] = useBoolean(false) + const [publishing, { + setFalse: hidePublishing, + setTrue: showPublishing, + }] = useBoolean(false) const invalidPublishedPipelineInfo = useInvalid([...publishedPipelineInfoQueryKeyPrefix, pipelineId]) const handlePublish = useCallback(async (params?: PublishWorkflowParams) => { - if (await handleCheckBeforePublish()) { - const res = await publishWorkflow({ - url: `/rag/pipelines/${pipelineId}/workflows/publish`, - title: params?.title || '', - releaseNotes: params?.releaseNotes || '', - }) - setPublished(true) + if (publishing) + return + try { + const checked = await handleCheckBeforePublish() - if (res) { - notify({ type: 'success', message: t('common.api.actionSuccess') }) - workflowStore.getState().setPublishedAt(res.created_at) - mutateDatasetRes?.() - invalidPublishedPipelineInfo() + if (checked) { + if (!publishedAt && !confirmVisible) { + showConfirm() + return + } + showPublishing() + const res = await publishWorkflow({ + url: `/rag/pipelines/${pipelineId}/workflows/publish`, + title: params?.title || '', + releaseNotes: params?.releaseNotes || '', + }) + setPublished(true) + if (res) { + notify({ type: 'success', message: t('common.api.actionSuccess') }) + workflowStore.getState().setPublishedAt(res.created_at) + mutateDatasetRes?.() + invalidPublishedPipelineInfo() + } } } - else { - throw new Error('Checklist failed') + catch { } - }, [handleCheckBeforePublish, publishWorkflow, pipelineId, notify, t, workflowStore, mutateDatasetRes, invalidPublishedPipelineInfo]) + finally { + if (publishing) + hidePublishing() + if (confirmVisible) + hideConfirm() + } + }, [handleCheckBeforePublish, publishWorkflow, pipelineId, notify, t, workflowStore, mutateDatasetRes, invalidPublishedPipelineInfo, showConfirm, publishedAt, 
confirmVisible, hidePublishing, showPublishing, hideConfirm, publishing]) useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.shift.p`, (e) => { e.preventDefault() @@ -108,7 +135,7 @@ const Popup = () => { variant='primary' className='mt-3 w-full' onClick={() => handlePublish()} - disabled={published} + disabled={published || publishing} > { published @@ -163,6 +190,18 @@ const Popup = () => {
+ { + confirmVisible && ( + + ) + }
) } diff --git a/web/app/components/workflow/nodes/data-source/constants.ts b/web/app/components/workflow/nodes/data-source/constants.ts index c80b7d4f62..e18879be9e 100644 --- a/web/app/components/workflow/nodes/data-source/constants.ts +++ b/web/app/components/workflow/nodes/data-source/constants.ts @@ -30,7 +30,7 @@ export const COMMON_OUTPUT = [ }, ] -export const FILE_OUTPUT = [ +export const LOCAL_FILE_OUTPUT = [ { name: 'file', type: VarType.file, @@ -80,7 +80,7 @@ export const FILE_OUTPUT = [ }, ] -export const WEBSITE_OUTPUT = [ +export const WEBSITE_CRAWL_OUTPUT = [ { name: 'source_url', type: VarType.string, @@ -102,3 +102,21 @@ export const WEBSITE_OUTPUT = [ description: 'The description of the crawled website', }, ] + +export const ONLINE_DOCUMENT_OUTPUT = [ + { + name: 'workspace_id', + type: VarType.string, + description: 'The ID of the workspace where the document is stored', + }, + { + name: 'page_id', + type: VarType.string, + description: 'The ID of the page in the document', + }, + { + name: 'content', + type: VarType.string, + description: 'The content of the online document', + }, +] diff --git a/web/app/components/workflow/nodes/data-source/default.ts b/web/app/components/workflow/nodes/data-source/default.ts index bb5c9375b1..2319757d48 100644 --- a/web/app/components/workflow/nodes/data-source/default.ts +++ b/web/app/components/workflow/nodes/data-source/default.ts @@ -5,8 +5,9 @@ import { genNodeMetaData } from '@/app/components/workflow/utils' import { BlockEnum } from '@/app/components/workflow/types' import { COMMON_OUTPUT, - FILE_OUTPUT, - WEBSITE_OUTPUT, + LOCAL_FILE_OUTPUT, + ONLINE_DOCUMENT_OUTPUT, + WEBSITE_CRAWL_OUTPUT, } from './constants' import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types' @@ -58,18 +59,24 @@ const nodeDefault: NodeDefault = { const { provider_type, } = payload - const isLocalFile = provider_type === DataSourceClassification.file - const isWebsiteCrawl = provider_type === DataSourceClassification.website + const isLocalFile = provider_type === DataSourceClassification.localFile + const isWebsiteCrawl = provider_type === DataSourceClassification.websiteCrawl + const isOnlineDocument = provider_type === DataSourceClassification.onlineDocument return [ ...COMMON_OUTPUT.map(item => ({ variable: item.name, type: item.type })), ...( isLocalFile - ? FILE_OUTPUT.map(item => ({ variable: item.name, type: item.type })) + ? LOCAL_FILE_OUTPUT.map(item => ({ variable: item.name, type: item.type })) : [] ), ...( isWebsiteCrawl - ? WEBSITE_OUTPUT.map(item => ({ variable: item.name, type: item.type })) + ? WEBSITE_CRAWL_OUTPUT.map(item => ({ variable: item.name, type: item.type })) + : [] + ), + ...( + isOnlineDocument + ? 
ONLINE_DOCUMENT_OUTPUT.map(item => ({ variable: item.name, type: item.type })) : [] ), ...ragVars, diff --git a/web/app/components/workflow/nodes/data-source/panel.tsx b/web/app/components/workflow/nodes/data-source/panel.tsx index c043de9df1..1148efc2bd 100644 --- a/web/app/components/workflow/nodes/data-source/panel.tsx +++ b/web/app/components/workflow/nodes/data-source/panel.tsx @@ -20,8 +20,9 @@ import { useNodesReadOnly } from '@/app/components/workflow/hooks' import { useConfig } from './hooks/use-config' import { COMMON_OUTPUT, - FILE_OUTPUT, - WEBSITE_OUTPUT, + LOCAL_FILE_OUTPUT, + ONLINE_DOCUMENT_OUTPUT, + WEBSITE_CRAWL_OUTPUT, } from './constants' import { useStore } from '@/app/components/workflow/store' import Button from '@/app/components/base/button' @@ -48,8 +49,9 @@ const Panel: FC> = ({ id, data }) => { handleFileExtensionsChange, handleParametersChange, } = useConfig(id) - const isLocalFile = provider_type === DataSourceClassification.file - const isWebsiteCrawl = provider_type === DataSourceClassification.website + const isLocalFile = provider_type === DataSourceClassification.localFile + const isWebsiteCrawl = provider_type === DataSourceClassification.websiteCrawl + const isOnlineDocument = provider_type === DataSourceClassification.onlineDocument const currentDataSource = dataSourceList?.find(ds => ds.plugin_id === plugin_id) const isAuthorized = !!currentDataSource?.is_authorized const [showAuthModal, { @@ -166,7 +168,7 @@ const Panel: FC> = ({ id, data }) => { )) } { - isLocalFile && FILE_OUTPUT.map(item => ( + isLocalFile && LOCAL_FILE_OUTPUT.map(item => ( > = ({ id, data }) => { )) } { - isWebsiteCrawl && WEBSITE_OUTPUT.map(item => ( + isWebsiteCrawl && WEBSITE_CRAWL_OUTPUT.map(item => ( + + )) + } + { + isOnlineDocument && ONLINE_DOCUMENT_OUTPUT.map(item => ( Date: Fri, 6 Jun 2025 15:06:26 +0800 Subject: [PATCH 083/155] r2 --- api/core/workflow/nodes/datasource/datasource_node.py | 5 +++-- docker/docker-compose.middleware.yaml | 8 +++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 8c76fc161d..2782f2fb4c 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -103,17 +103,18 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, outputs={ - "online_document": online_document_result.result.model_dump(), + **online_document_result.result.model_dump(), "datasource_type": datasource_type, }, ) case DatasourceProviderType.WEBSITE_CRAWL: + return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, outputs={ - "website": datasource_info, + **datasource_info, "datasource_type": datasource_type, }, ) diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index d4a0b94619..a89a834906 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -71,7 +71,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.1.1-local + image: langgenius/dify-plugin-daemon:deploy-dev-local restart: always env_file: - ./middleware.env @@ -121,6 +121,12 @@ services: ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} ALIYUN_OSS_AUTH_VERSION: 
${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} + VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-} + VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-} + VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} + VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} + THIRD_PARTY_SIGNATURE_VERIFICATION_ENABLED: true + THIRD_PARTY_SIGNATURE_VERIFICATION_PUBLIC_KEYS: /app/keys/publickey.pem ports: - "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}" - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" From d2d5fc62aeeba168d2e57249c2be4189b173984d Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 6 Jun 2025 15:19:53 +0800 Subject: [PATCH 084/155] r2 --- .../console/datasets/rag_pipeline/rag_pipeline_import.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index 853aef2e09..e5c211be93 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -111,7 +111,6 @@ class RagPipelineExportApi(Resource): @login_required @get_rag_pipeline @account_initialization_required - @marshal_with(pipeline_import_check_dependencies_fields) def get(self, pipeline: Pipeline): if not current_user.is_editor: raise Forbidden() From 4ffdf68a2099921ef79f44d38b52a2d18362aa76 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 6 Jun 2025 16:03:35 +0800 Subject: [PATCH 085/155] r2 --- .../rag_pipeline/rag_pipeline_dsl_service.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 2e1ed57908..8787490555 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -197,8 +197,8 @@ class RagPipelineDslService: # Validate and fix DSL version if not data.get("version"): data["version"] = "0.1.0" - if not data.get("kind") or data.get("kind") != "rag-pipeline": - data["kind"] = "rag-pipeline" + if not data.get("kind") or data.get("kind") != "rag_pipeline": + data["kind"] = "rag_pipeline" imported_version = data.get("version", "0.1.0") # check if imported_version is a float-like string @@ -277,8 +277,7 @@ class RagPipelineDslService: dataset_id = None for node in nodes: if node.get("data", {}).get("type") == "knowledge_index": - knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) - knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + knowledge_configuration = KnowledgeConfiguration(**node.get("data", {})) if ( dataset and pipeline.is_published @@ -412,8 +411,7 @@ class RagPipelineDslService: dataset_id = None for node in nodes: if node.get("data", {}).get("type") == "knowledge_index": - knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) - knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + knowledge_configuration = KnowledgeConfiguration(**node.get("data", {})) if not dataset: dataset = Dataset( tenant_id=account.current_tenant_id, @@ -644,7 +642,7 @@ class RagPipelineDslService: export_data = { "version": CURRENT_DSL_VERSION, "kind": "rag_pipeline", - "pipeline": { + "rag_pipeline": { "name": pipeline.name, "icon": icon_info.get("icon", 
"📙") if icon_info else "📙", "icon_type": icon_info.get("icon_type", "emoji") if icon_info else "emoji", From 1aa13bd20d5f382ce5e3b47e509b2e1f410d0e36 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 6 Jun 2025 16:05:49 +0800 Subject: [PATCH 086/155] r2 --- .../rag_pipeline/rag_pipeline_dsl_service.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 2e1ed57908..8787490555 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -197,8 +197,8 @@ class RagPipelineDslService: # Validate and fix DSL version if not data.get("version"): data["version"] = "0.1.0" - if not data.get("kind") or data.get("kind") != "rag-pipeline": - data["kind"] = "rag-pipeline" + if not data.get("kind") or data.get("kind") != "rag_pipeline": + data["kind"] = "rag_pipeline" imported_version = data.get("version", "0.1.0") # check if imported_version is a float-like string @@ -277,8 +277,7 @@ class RagPipelineDslService: dataset_id = None for node in nodes: if node.get("data", {}).get("type") == "knowledge_index": - knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) - knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + knowledge_configuration = KnowledgeConfiguration(**node.get("data", {})) if ( dataset and pipeline.is_published @@ -412,8 +411,7 @@ class RagPipelineDslService: dataset_id = None for node in nodes: if node.get("data", {}).get("type") == "knowledge_index": - knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) - knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + knowledge_configuration = KnowledgeConfiguration(**node.get("data", {})) if not dataset: dataset = Dataset( tenant_id=account.current_tenant_id, @@ -644,7 +642,7 @@ class RagPipelineDslService: export_data = { "version": CURRENT_DSL_VERSION, "kind": "rag_pipeline", - "pipeline": { + "rag_pipeline": { "name": pipeline.name, "icon": icon_info.get("icon", "📙") if icon_info else "📙", "icon_type": icon_info.get("icon_type", "emoji") if icon_info else "emoji", From 55e20d189afc5adeb9c65a56c8f7e347be8f5548 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 6 Jun 2025 16:16:44 +0800 Subject: [PATCH 087/155] Update deploy-dev.yml --- .github/workflows/deploy-dev.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index 47ca03c2eb..409a80f19e 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -4,7 +4,7 @@ on: workflow_run: workflows: ["Build and Push API & Web"] branches: - - "deploy/dev" + - "deploy/rag-dev" types: - completed @@ -17,8 +17,8 @@ jobs: - name: Deploy to server uses: appleboy/ssh-action@v0.1.8 with: - host: ${{ secrets.SSH_HOST }} + host: ${{ secrets.RAG_SSH_HOST }} username: ${{ secrets.SSH_USER }} - key: ${{ secrets.SSH_PRIVATE_KEY }} + key: ${{ secrets.RAG_SSH_PRIVATE_KEY }} script: | ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }} From 21a3509bef52e2362ddc398fe11de5b29992023b Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 6 Jun 2025 17:14:43 +0800 Subject: [PATCH 088/155] r2 --- .../rag_pipeline/rag_pipeline_workflow.py | 88 +++++++++++++++++++ .../entities/datasource_entities.py | 4 +- 
api/services/rag_pipeline/rag_pipeline.py | 61 +++++++++++++ docker/docker-compose.middleware.yaml | 2 +- 4 files changed, 153 insertions(+), 2 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index fef10b79a7..fa4020d7db 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -300,6 +300,86 @@ class PublishedRagPipelineRunApi(Resource): except InvokeRateLimitError as ex: raise InvokeRateLimitHttpError(ex.description) +class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run rag pipeline datasource + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + args = parser.parse_args() + + job_id = args.get("job_id") + if job_id == None: + raise ValueError("missing job_id") + datasource_type = args.get("datasource_type") + if datasource_type == None: + raise ValueError("missing datasource_type") + + rag_pipeline_service = RagPipelineService() + result = rag_pipeline_service.run_datasource_workflow_node_status( + pipeline=pipeline, + node_id=node_id, + job_id=job_id, + account=current_user, + datasource_type=datasource_type, + is_published=True + ) + + return result + +class RagPipelineDraftDatasourceNodeRunStatusApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run rag pipeline datasource + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + args = parser.parse_args() + + job_id = args.get("job_id") + if job_id == None: + raise ValueError("missing job_id") + datasource_type = args.get("datasource_type") + if datasource_type == None: + raise ValueError("missing datasource_type") + + rag_pipeline_service = RagPipelineService() + result = rag_pipeline_service.run_datasource_workflow_node_status( + pipeline=pipeline, + node_id=node_id, + job_id=job_id, + account=current_user, + datasource_type=datasource_type, + is_published=False + ) + + return result + class RagPipelinePublishedDatasourceNodeRunApi(Resource): @setup_required @@ -894,6 +974,14 @@ api.add_resource( RagPipelinePublishedDatasourceNodeRunApi, "/rag/pipelines//workflows/published/datasource/nodes//run", ) +api.add_resource( + RagPipelinePublishedDatasourceNodeRunStatusApi, + "/rag/pipelines//workflows/published/datasource/nodes//run-status", +) +api.add_resource( + RagPipelineDraftDatasourceNodeRunStatusApi, + "/rag/pipelines//workflows/draft/datasource/nodes//run-status", +) api.add_resource( 
RagPipelineDrafDatasourceNodeRunApi, diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 6a9fc5d9f9..647d8f9a8c 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -304,4 +304,6 @@ class GetWebsiteCrawlResponse(BaseModel): Get website crawl response """ - result: list[WebSiteInfo] + result: Optional[list[WebSiteInfo]] = [] + job_id: str = Field(..., description="The job id") + status: str = Field(..., description="The status of the job") diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index cb42224c60..1d4e279d2a 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -415,6 +415,67 @@ class RagPipelineService: return workflow_node_execution + def run_datasource_workflow_node_status( + self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool + ) -> dict: + """ + Run published workflow datasource + """ + if is_published: + # fetch published workflow by app_model + workflow = self.get_published_workflow(pipeline=pipeline) + else: + workflow = self.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + datasource_node_data = None + start_at = time.perf_counter() + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if not datasource_node_data: + raise ValueError("Datasource node data not found") + + + from core.datasource.datasource_manager import DatasourceManager + + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", + datasource_name=datasource_node_data.get("datasource_name"), + tenant_id=pipeline.tenant_id, + datasource_type=DatasourceProviderType(datasource_type), + ) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=pipeline.tenant_id, + provider=datasource_node_data.get('provider_name'), + plugin_id=datasource_node_data.get('plugin_id'), + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + match datasource_type: + + case DatasourceProviderType.WEBSITE_CRAWL: + datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) + website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + user_id=account.id, + datasource_parameters={"job_id": job_id}, + provider_type=datasource_runtime.datasource_provider_type(), + ) + return { + "result": [result.model_dump() for result in website_crawl_result.result], + "job_id": website_crawl_result.job_id, + "status": website_crawl_result.status, + "provider_type": datasource_node_data.get("provider_type"), + } + case _: + raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + + def run_datasource_workflow_node( self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool ) -> dict: diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index a89a834906..6bd0e554ab 100644 --- 
a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -94,7 +94,6 @@ services: PLUGIN_REMOTE_INSTALLING_HOST: ${PLUGIN_DEBUGGING_HOST:-0.0.0.0} PLUGIN_REMOTE_INSTALLING_PORT: ${PLUGIN_DEBUGGING_PORT:-5003} PLUGIN_WORKING_PATH: ${PLUGIN_WORKING_PATH:-/app/storage/cwd} - FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true} PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} @@ -127,6 +126,7 @@ services: VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} THIRD_PARTY_SIGNATURE_VERIFICATION_ENABLED: true THIRD_PARTY_SIGNATURE_VERIFICATION_PUBLIC_KEYS: /app/keys/publickey.pem + FORCE_VERIFYING_SIGNATURE: false ports: - "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}" - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" From fcbd5febeb7851e1444a653e175c85958ea17546 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 6 Jun 2025 17:47:06 +0800 Subject: [PATCH 089/155] r2 --- api/services/rag_pipeline/rag_pipeline.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 1d4e279d2a..7f5d4de7e3 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -422,7 +422,7 @@ class RagPipelineService: Run published workflow datasource """ if is_published: - # fetch published workflow by app_model + # fetch published workflow by app_model workflow = self.get_published_workflow(pipeline=pipeline) else: workflow = self.get_draft_workflow(pipeline=pipeline) @@ -440,7 +440,6 @@ class RagPipelineService: if not datasource_node_data: raise ValueError("Datasource node data not found") - from core.datasource.datasource_manager import DatasourceManager datasource_runtime = DatasourceManager.get_datasource_runtime( @@ -474,16 +473,16 @@ class RagPipelineService: } case _: raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") - def run_datasource_workflow_node( - self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool + self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, + is_published: bool ) -> dict: """ Run published workflow datasource """ if is_published: - # fetch published workflow by app_model + # fetch published workflow by app_model workflow = self.get_published_workflow(pipeline=pipeline) else: workflow = self.get_draft_workflow(pipeline=pipeline) @@ -544,6 +543,8 @@ class RagPipelineService: ) return { "result": [result.model_dump() for result in website_crawl_result.result], + "job_id": website_crawl_result.job_id, + "status": website_crawl_result.status, "provider_type": datasource_node_data.get("provider_type"), } case _: From 47664f8fd395bdb3f9be58ed300d71c63aa765c4 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 9 Jun 2025 14:00:34 +0800 Subject: [PATCH 090/155] r2 --- .../entities/datasource_entities.py | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 647d8f9a8c..4aa8bf75c8 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -213,15 +213,6 
@@ class GetOnlineDocumentPagesRequest(BaseModel): """ -class OnlineDocumentPageIcon(BaseModel): - """ - Online document page icon - """ - - type: str = Field(..., description="The type of the icon") - url: str = Field(..., description="The url of the icon") - - class OnlineDocumentPage(BaseModel): """ Online document page @@ -229,7 +220,7 @@ class OnlineDocumentPage(BaseModel): page_id: str = Field(..., description="The page id") page_title: str = Field(..., description="The page title") - page_icon: Optional[OnlineDocumentPageIcon] = Field(None, description="The page icon") + page_icon: Optional[dict] = Field(None, description="The page icon") type: str = Field(..., description="The type of the page") last_edited_time: str = Field(..., description="The last edited time") @@ -288,22 +279,27 @@ class GetWebsiteCrawlRequest(BaseModel): crawl_parameters: dict = Field(..., description="The crawl parameters") -class WebSiteInfo(BaseModel): - """ - Website info - """ - +class WebSiteInfoDetail(BaseModel): source_url: str = Field(..., description="The url of the website") content: str = Field(..., description="The content of the website") title: str = Field(..., description="The title of the website") description: str = Field(..., description="The description of the website") +class WebSiteInfo(BaseModel): + """ + Website info + """ + job_id: str = Field(..., description="The job id") + status: str = Field(..., description="The status of the job") + web_info_list: Optional[list[WebSiteInfoDetail]] = [] + + + class GetWebsiteCrawlResponse(BaseModel): """ Get website crawl response """ - result: Optional[list[WebSiteInfo]] = [] - job_id: str = Field(..., description="The job id") - status: str = Field(..., description="The status of the job") + result: WebSiteInfo = WebSiteInfo(job_id="", status="", web_info_list=[]) + From ad3d9cf78261539f60b939353199bee59780de98 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 10 Jun 2025 10:00:20 +0800 Subject: [PATCH 091/155] r2 --- api/core/datasource/entities/datasource_entities.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 4aa8bf75c8..adcdcccf83 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -285,6 +285,7 @@ class WebSiteInfoDetail(BaseModel): title: str = Field(..., description="The title of the website") description: str = Field(..., description="The description of the website") + class WebSiteInfo(BaseModel): """ Website info @@ -294,12 +295,9 @@ class WebSiteInfo(BaseModel): web_info_list: Optional[list[WebSiteInfoDetail]] = [] - - class GetWebsiteCrawlResponse(BaseModel): """ Get website crawl response """ result: WebSiteInfo = WebSiteInfo(job_id="", status="", web_info_list=[]) - From c0d3452494e9583795f04837970beb736ea39c8d Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 10 Jun 2025 10:59:44 +0800 Subject: [PATCH 092/155] r2 --- .../console/datasets/rag_pipeline/rag_pipeline.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 1b869d9847..0d882e29e8 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -111,7 +111,7 @@ class 
CustomizedPipelineTemplateApi(Resource): return {"data": dsl}, 200 -class CustomizedPipelineTemplateApi(Resource): +class PublishCustomizedPipelineTemplateApi(Resource): @setup_required @login_required @account_initialization_required @@ -140,7 +140,7 @@ class CustomizedPipelineTemplateApi(Resource): ) args = parser.parse_args() rag_pipeline_service = RagPipelineService() - RagPipelineService.publish_customized_pipeline_template(pipeline_id, args) + rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args) return 200 @@ -156,3 +156,7 @@ api.add_resource( CustomizedPipelineTemplateApi, "/rag/pipeline/customized/templates/", ) +api.add_resource( + CustomizedPipelineTemplateApi, + "/rag/pipeline/customized/templates//publish", +) From 65873aa411ee5348005b60d0ef67f449b568f93f Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 10 Jun 2025 11:44:52 +0800 Subject: [PATCH 093/155] r2 --- api/controllers/console/datasets/rag_pipeline/rag_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 0d882e29e8..3be655a105 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -157,6 +157,6 @@ api.add_resource( "/rag/pipeline/customized/templates/", ) api.add_resource( - CustomizedPipelineTemplateApi, + PublishCustomizedPipelineTemplateApi, "/rag/pipeline/customized/templates//publish", ) From 4d967544f3a427638b19c9743807acb65bc58193 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 10 Jun 2025 14:13:10 +0800 Subject: [PATCH 094/155] r2 --- api/controllers/console/datasets/rag_pipeline/rag_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 3be655a105..b76071229b 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -158,5 +158,5 @@ api.add_resource( ) api.add_resource( PublishCustomizedPipelineTemplateApi, - "/rag/pipeline/customized/templates//publish", + "/rag/pipeline/customized/pipelines//publish", ) From a7ff2ab470274725120378e40bba4d9040881498 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 10 Jun 2025 14:53:07 +0800 Subject: [PATCH 095/155] r2 --- .../console/datasets/rag_pipeline/rag_pipeline.py | 2 +- .../datasets/rag_pipeline/rag_pipeline_workflow.py | 9 ++++++++- api/services/rag_pipeline/rag_pipeline.py | 6 +++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index b76071229b..033a8e5483 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -158,5 +158,5 @@ api.add_resource( ) api.add_resource( PublishCustomizedPipelineTemplateApi, - "/rag/pipeline/customized/pipelines//publish", + "/rag/pipeline//customized/publish", ) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index fa4020d7db..7b8adfe560 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ 
b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -300,6 +300,7 @@ class PublishedRagPipelineRunApi(Resource): except InvokeRateLimitError as ex: raise InvokeRateLimitHttpError(ex.description) + class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource): @setup_required @login_required @@ -340,6 +341,7 @@ class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource): return result + class RagPipelineDraftDatasourceNodeRunStatusApi(Resource): @setup_required @login_required @@ -379,7 +381,7 @@ class RagPipelineDraftDatasourceNodeRunStatusApi(Resource): ) return result - + class RagPipelinePublishedDatasourceNodeRunApi(Resource): @setup_required @@ -421,6 +423,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): return result + class RagPipelineDrafDatasourceNodeRunApi(Resource): @setup_required @login_required @@ -461,6 +464,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource): return result + class RagPipelinePublishedNodeRunApi(Resource): @setup_required @login_required @@ -798,6 +802,7 @@ class PublishedRagPipelineSecondStepApi(Resource): "variables": variables, } + class PublishedRagPipelineFirstStepApi(Resource): @setup_required @login_required @@ -822,6 +827,7 @@ class PublishedRagPipelineFirstStepApi(Resource): "variables": variables, } + class DraftRagPipelineFirstStepApi(Resource): @setup_required @login_required @@ -846,6 +852,7 @@ class DraftRagPipelineFirstStepApi(Resource): "variables": variables, } + class DraftRagPipelineSecondStepApi(Resource): @setup_required @login_required diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 7f5d4de7e3..0d5786ddda 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -542,9 +542,9 @@ class RagPipelineService: provider_type=datasource_runtime.datasource_provider_type(), ) return { - "result": [result.model_dump() for result in website_crawl_result.result], - "job_id": website_crawl_result.job_id, - "status": website_crawl_result.status, + "result": [result.model_dump() for result in website_crawl_result.result.web_info_list] if website_crawl_result.result.web_info_list else [], + "job_id": website_crawl_result.result.job_id, + "status": website_crawl_result.result.status, "provider_type": datasource_node_data.get("provider_type"), } case _: From 7624edd32d5b67528881d774c34889641904b12c Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 10 Jun 2025 14:56:18 +0800 Subject: [PATCH 096/155] r2 --- api/controllers/console/datasets/rag_pipeline/rag_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 033a8e5483..a9ffc42cba 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -158,5 +158,5 @@ api.add_resource( ) api.add_resource( PublishCustomizedPipelineTemplateApi, - "/rag/pipeline//customized/publish", + "/rag/pipelines//customized/publish", ) From 58b5daeef35c7d06d48f3f9f50f85655275a29f8 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 10 Jun 2025 15:56:28 +0800 Subject: [PATCH 097/155] r2 --- api/controllers/console/datasets/rag_pipeline/rag_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py 
b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 033a8e5483..97d9fa5967 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -141,7 +141,7 @@ class PublishCustomizedPipelineTemplateApi(Resource): args = parser.parse_args() rag_pipeline_service = RagPipelineService() rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args) - return 200 + return {"result": "success"} api.add_resource( @@ -158,5 +158,5 @@ api.add_resource( ) api.add_resource( PublishCustomizedPipelineTemplateApi, - "/rag/pipeline//customized/publish", + "/rag/pipelines//customized/publish", ) From 80b219707edb47a385e2ee6aed828e57e350e711 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 10 Jun 2025 17:11:49 +0800 Subject: [PATCH 098/155] r2 --- .../datasets/rag_pipeline/rag_pipeline.py | 4 ++- .../customized/customized_retrieval.py | 1 - api/services/rag_pipeline/rag_pipeline.py | 34 ++++++++++++++++--- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 97d9fa5967..f2c0870f72 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -51,7 +51,9 @@ class PipelineTemplateDetailApi(Resource): @account_initialization_required @enterprise_license_required def get(self, template_id: str): - pipeline_template = RagPipelineService.get_pipeline_template_detail(template_id) + type = request.args.get("type", default="built-in", type=str) + rag_pipeline_service = RagPipelineService() + pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type) return pipeline_template, 200 diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 3ede75309d..d655dc93a1 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -64,7 +64,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): pipeline_template = ( db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() ) - if not pipeline_template: return None diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 0d5786ddda..abbc269cec 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -8,7 +8,7 @@ from typing import Any, Optional, cast from uuid import uuid4 from flask_login import current_user -from sqlalchemy import or_, select +from sqlalchemy import func, or_, select from sqlalchemy.orm import Session import contexts @@ -78,15 +78,20 @@ class RagPipelineService: return result @classmethod - def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]: + def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]: """ Get pipeline template detail. 
:param template_id: template id :return: """ - mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE - retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + if type == "built-in": + mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() + result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + else: + mode = "customized" + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() + result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) return result @classmethod @@ -930,5 +935,24 @@ class RagPipelineService: workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() if not workflow: raise ValueError("Workflow not found") + dataset = pipeline.dataset + if not dataset: + raise ValueError("Dataset not found") + max_position = db.session.query(func.max(PipelineCustomizedTemplate.position)).filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar() + + from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService + dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True) + + pipeline_customized_template = PipelineCustomizedTemplate( + name=args.get("name"), + description=args.get("description"), + icon=args.get("icon_info"), + tenant_id=pipeline.tenant_id, + yaml_content=dsl, + position=max_position + 1 if max_position else 1, + chunk_structure=dataset.chunk_structure, + language="en-US", + ) + db.session.add(pipeline_customized_template) db.session.commit() From aeb1d1946cade60ca2edddeea574cd2525c43fce Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 10 Jun 2025 17:59:14 +0800 Subject: [PATCH 099/155] r2 --- api/fields/dataset_fields.py | 2 ++ api/models/dataset.py | 13 +++++++++++++ 2 files changed, 15 insertions(+) diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 2871b3ec16..aa147331d4 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -88,6 +88,8 @@ dataset_detail_fields = { "chunk_structure": fields.String, "icon_info": fields.Nested(icon_info_fields), "is_published": fields.Boolean, + "total_documents": fields.Integer, + "total_available_documents": fields.Integer, } dataset_query_detail_fields = { diff --git a/api/models/dataset.py b/api/models/dataset.py index d2fdd5e900..6795f719df 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -312,6 +312,19 @@ class DatasetProcessRule(Base): except JSONDecodeError: return None + @property + def total_documents(self): + return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.dataset_id).scalar() + + @property + def total_available_documents(self): + return db.session.query(func.count(Document.id)).filter( + Document.dataset_id == self.dataset_id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ).scalar() + class Document(Base): __tablename__ = "documents" From e934503fa05c62e0423397f4fe565ffbb24f1ce9 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 10 Jun 2025 18:16:30 +0800 Subject: [PATCH 100/155] r2 --- api/controllers/console/datasets/datasets.py | 9 +++++++++ .../nodes/knowledge_index/knowledge_index_node.py | 15 ++++++++++++--- 
api/fields/dataset_fields.py | 2 ++ api/models/dataset.py | 13 +++++++++++++ api/services/dataset_service.py | 3 +++ 5 files changed, 39 insertions(+), 3 deletions(-) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index e68273afa6..ceaa9ec4fa 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -283,6 +283,15 @@ class DatasetApi(Resource): location="json", help="Invalid external knowledge api id.", ) + + parser.add_argument( + "icon_info", + type=dict, + required=False, + nullable=True, + location="json", + help="Invalid icon info.", + ) args = parser.parse_args() data = request.get_json() diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 1195df1b7f..c63d837106 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,18 +1,21 @@ import datetime import logging from collections.abc import Mapping +import time from typing import Any, cast +from sqlalchemy import func + from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from ..base import BaseNode from .entities import KnowledgeIndexNodeData @@ -111,13 +114,19 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): document = db.session.query(Document).filter_by(id=document_id.value).first() if not document: raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.") - + # chunk nodes by chunk size + indexing_start_at = time.perf_counter() index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor() index_processor.index(dataset, document, chunks) - + indexing_end_at = time.perf_counter() + document.indexing_latency = indexing_end_at - indexing_start_at # update document status document.indexing_status = "completed" document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.word_count = db.session.query(func.sum(DocumentSegment.word_count)).filter( + DocumentSegment.document_id == document.id, + DocumentSegment.dataset_id == dataset.id, + ).scalar() db.session.add(document) # update document segment status db.session.query(DocumentSegment).filter( diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 2871b3ec16..aa147331d4 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -88,6 +88,8 @@ dataset_detail_fields = { "chunk_structure": fields.String, "icon_info": fields.Nested(icon_info_fields), "is_published": fields.Boolean, + "total_documents": fields.Integer, + "total_available_documents": fields.Integer, } dataset_query_detail_fields = { diff --git a/api/models/dataset.py b/api/models/dataset.py index d2fdd5e900..85c10c06d7 100644 --- a/api/models/dataset.py 
+++ b/api/models/dataset.py @@ -311,6 +311,19 @@ class DatasetProcessRule(Base): return json.loads(self.rules) if self.rules else None except JSONDecodeError: return None + + @property + def total_documents(self): + return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.dataset_id).scalar() + + @property + def total_available_documents(self): + return db.session.query(func.count(Document.id)).filter( + Document.dataset_id == self.dataset_id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ).scalar() class Document(Base): diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 133e3765f7..ab16081afc 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -484,6 +484,9 @@ class DatasetService: # update Retrieval model filtered_data["retrieval_model"] = data["retrieval_model"] + # update icon info + if data.get("icon_info"): + filtered_data["icon_info"] = data.get("icon_info") db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data) db.session.commit() From 95a24156de4e2b5cd23ca0041677e8ef68ecd745 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 10 Jun 2025 18:20:32 +0800 Subject: [PATCH 101/155] r2 --- api/core/rag/index_processor/constant/built_in_field.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py index c8ad53e3dd..05fbf9003b 100644 --- a/api/core/rag/index_processor/constant/built_in_field.py +++ b/api/core/rag/index_processor/constant/built_in_field.py @@ -13,3 +13,5 @@ class MetadataDataSource(Enum): upload_file = "file_upload" website_crawl = "website" notion_import = "notion" + local_file = "file_upload" + online_document = "online_document" From 127a77d80784da6219fe3b1567fbe969e9238a3f Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 10 Jun 2025 19:22:08 +0800 Subject: [PATCH 102/155] r2 --- api/services/rag_pipeline/rag_pipeline.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index abbc269cec..a5c6021df3 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -13,6 +13,7 @@ from sqlalchemy.orm import Session import contexts from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( DatasourceProviderType, GetOnlineDocumentPagesResponse, @@ -28,6 +29,7 @@ from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) +from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base.node import BaseNode @@ -40,7 +42,7 @@ from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account -from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore +from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.model import EndUser from models.oauth import 
DatasourceProvider @@ -678,6 +680,20 @@ class RagPipelineService: # create workflow node execution workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED workflow_node_execution.error = error + # update document status + variable_pool = node_instance.graph_runtime_state.variable_pool + invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + if invoke_from: + if invoke_from.value == InvokeFrom.PUBLISHED.value: + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if document_id: + document = db.session.query(Document).filter(Document.id == document_id.value).first() + if document: + document.indexing_status = "error" + document.error = error + db.session.add(document) + db.session.commit() + return workflow_node_execution From 2972a06f16ab0b1ac2a9ee72c9e1a9e4c797c288 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 11 Jun 2025 11:21:17 +0800 Subject: [PATCH 103/155] r2 --- api/services/rag_pipeline/rag_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index a5c6021df3..f9bd5bbc51 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -104,7 +104,7 @@ class RagPipelineService: :param template_info: template info """ customized_template: PipelineCustomizedTemplate | None = ( - db.query(PipelineCustomizedTemplate) + db.session.query(PipelineCustomizedTemplate) .filter( PipelineCustomizedTemplate.id == template_id, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, @@ -116,7 +116,7 @@ class RagPipelineService: customized_template.name = template_info.name customized_template.description = template_info.description customized_template.icon = template_info.icon_info.model_dump() - db.commit() + db.session.commit() return customized_template @classmethod From 874e1bc41d171431fa10449f2cf5daace532156e Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 11 Jun 2025 13:12:18 +0800 Subject: [PATCH 104/155] r2 --- ...1_1155-224fba149d48_add_pipeline_info_6.py | 43 +++++++++++++++++++ api/models/dataset.py | 18 ++++++++ .../customized/customized_retrieval.py | 5 ++- .../database/database_retrieval.py | 4 +- api/services/rag_pipeline/rag_pipeline.py | 3 +- 5 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 api/migrations/versions/2025_06_11_1155-224fba149d48_add_pipeline_info_6.py diff --git a/api/migrations/versions/2025_06_11_1155-224fba149d48_add_pipeline_info_6.py b/api/migrations/versions/2025_06_11_1155-224fba149d48_add_pipeline_info_6.py new file mode 100644 index 0000000000..d2cd61f9ec --- /dev/null +++ b/api/migrations/versions/2025_06_11_1155-224fba149d48_add_pipeline_info_6.py @@ -0,0 +1,43 @@ +"""add_pipeline_info_6 + +Revision ID: 224fba149d48 +Revises: d466c551816f +Create Date: 2025-06-11 11:55:01.179201 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '224fba149d48' +down_revision = 'd466c551816f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), nullable=False)) + batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), nullable=False)) + batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.drop_column('updated_by') + batch_op.drop_column('created_by') + + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.drop_column('updated_by') + batch_op.drop_column('created_by') + + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 85c10c06d7..da6e58f113 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1189,6 +1189,15 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] language = db.Column(db.String(255), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_by = db.Column(StringUUID, nullable=False) + updated_by = db.Column(StringUUID, nullable=True) + + @property + def created_user_name(self): + account = db.session.query(Account).filter(Account.id == self.created_by).first() + if account: + return account.name + return "" class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] @@ -1208,9 +1217,18 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] yaml_content = db.Column(db.Text, nullable=False) install_count = db.Column(db.Integer, nullable=False, default=0) language = db.Column(db.String(255), nullable=False) + created_by = db.Column(StringUUID, nullable=False) + updated_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + @property + def created_user_name(self): + account = db.session.query(Account).filter(Account.id == self.created_by).first() + if account: + return account.name + return "" + class Pipeline(Base): # type: ignore[name-defined] __tablename__ = "pipelines" diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index d655dc93a1..5d0130bb8e 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -70,6 +70,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return { "id": pipeline_template.id, "name": pipeline_template.name, - "icon": pipeline_template.icon, + "icon_info": pipeline_template.icon, + "description": pipeline_template.description, + "chunk_structure": pipeline_template.chunk_structure, "export_data": yaml.safe_load(pipeline_template.yaml_content), + "created_by": pipeline_template.created_user_name, } diff --git 
a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 741384afc2..91bc9f6cdd 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -70,7 +70,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return { "id": pipeline_template.id, "name": pipeline_template.name, - "icon": pipeline_template.icon, + "icon_info": pipeline_template.icon, "chunk_structure": pipeline_template.chunk_structure, "export_data": yaml.safe_load(pipeline_template.yaml_content), + "created_by": pipeline_template.created_user_name, + "description": pipeline_template.description, } diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f9bd5bbc51..f2ffce99b3 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -116,6 +116,7 @@ class RagPipelineService: customized_template.name = template_info.name customized_template.description = template_info.description customized_template.icon = template_info.icon_info.model_dump() + customized_template.updated_by = current_user.id db.session.commit() return customized_template @@ -694,7 +695,6 @@ class RagPipelineService: db.session.add(document) db.session.commit() - return workflow_node_execution def update_workflow( @@ -969,6 +969,7 @@ class RagPipelineService: position=max_position + 1 if max_position else 1, chunk_structure=dataset.chunk_structure, language="en-US", + created_by=current_user.id, ) db.session.add(pipeline_customized_template) db.session.commit() From a6f7560d2fe396856ff5786db61d10b367fabc6d Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 11 Jun 2025 14:03:32 +0800 Subject: [PATCH 105/155] r2 --- .../pipeline_template/customized/customized_retrieval.py | 2 +- .../pipeline_template/database/database_retrieval.py | 2 +- api/services/rag_pipeline/rag_pipeline.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 5d0130bb8e..bb4db81a70 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -73,6 +73,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): "icon_info": pipeline_template.icon, "description": pipeline_template.description, "chunk_structure": pipeline_template.chunk_structure, - "export_data": yaml.safe_load(pipeline_template.yaml_content), + "export_data": pipeline_template.yaml_content, "created_by": pipeline_template.created_user_name, } diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 91bc9f6cdd..c296111f46 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -72,7 +72,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): "name": pipeline_template.name, "icon_info": pipeline_template.icon, "chunk_structure": pipeline_template.chunk_structure, - "export_data": 
yaml.safe_load(pipeline_template.yaml_content), + "export_data": pipeline_template.yaml_content, "created_by": pipeline_template.created_user_name, "description": pipeline_template.description, } diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f2ffce99b3..429c53d829 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -474,9 +474,9 @@ class RagPipelineService: provider_type=datasource_runtime.datasource_provider_type(), ) return { - "result": [result.model_dump() for result in website_crawl_result.result], - "job_id": website_crawl_result.job_id, - "status": website_crawl_result.status, + "result": [result for result in website_crawl_result.result], + "job_id": website_crawl_result.result.job_id, + "status": website_crawl_result.result.status, "provider_type": datasource_node_data.get("provider_type"), } case _: From 66fa68fa18fd19e7b2ed62ae8e27d33b3df5e36a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 11 Jun 2025 16:36:36 +0800 Subject: [PATCH 106/155] r2 --- api/core/app/apps/common/workflow_response_converter.py | 7 +++++++ .../pipeline_template/customized/customized_retrieval.py | 4 ++++ .../pipeline_template/database/database_retrieval.py | 6 ++++-- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 6f524a5872..99bb597721 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -43,10 +43,12 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.file import FILE_MODEL_IDENTITY, File +from core.plugin.impl.datasource import PluginDatasourceManager from core.tools.tool_manager import ToolManager from core.workflow.entities.workflow_execution import WorkflowExecution from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus from core.workflow.nodes import NodeType +from core.workflow.nodes.datasource.entities import DatasourceNodeData from core.workflow.nodes.tool.entities import ToolNodeData from models import ( Account, @@ -181,6 +183,11 @@ class WorkflowResponseConverter: provider_type=node_data.provider_type, provider_id=node_data.provider_id, ) + elif event.node_type == NodeType.DATASOURCE: + node_data = cast(DatasourceNodeData, event.node_data) + manager = PluginDatasourceManager() + provider_entity = manager.fetch_datasource_provider(self._application_generate_entity.app_config.tenant_id, f"{node_data.plugin_id}/{node_data.provider_name}") + response.data.extras["icon"] = provider_entity.declaration.identity.icon return response diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index bb4db81a70..ca94f7f47a 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -66,6 +66,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): ) if not pipeline_template: return None + + dsl_data = yaml.safe_load(pipeline_template.yaml_content) + graph_data = dsl_data.get("workflow", {}).get("graph", {}) return { "id": pipeline_template.id, @@ -74,5 +77,6 @@ class 
CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): "description": pipeline_template.description, "chunk_structure": pipeline_template.chunk_structure, "export_data": pipeline_template.yaml_content, + "graph": graph_data, "created_by": pipeline_template.created_user_name, } diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index c296111f46..b69a857c3a 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -66,13 +66,15 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): if not pipeline_template: return None - + dsl_data = yaml.safe_load(pipeline_template.yaml_content) + graph_data = dsl_data.get("workflow", {}).get("graph", {}) return { "id": pipeline_template.id, "name": pipeline_template.name, "icon_info": pipeline_template.icon, + "description": pipeline_template.description, "chunk_structure": pipeline_template.chunk_structure, "export_data": pipeline_template.yaml_content, + "graph": graph_data, "created_by": pipeline_template.created_user_name, - "description": pipeline_template.description, } From 5f08a9314ca0920577a70ca9161528bc55d4d10c Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 11 Jun 2025 17:10:20 +0800 Subject: [PATCH 107/155] r2 --- api/core/plugin/impl/datasource.py | 2 +- api/services/rag_pipeline/rag_pipeline.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 775f36b506..98ee0bb11e 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -242,7 +242,7 @@ class PluginDatasourceManager(BasePluginClient): "author": "langgenius", "name": "file", "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, - "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", + "icon": "https://assets.dify.ai/images/File%20Upload.svg", "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, }, "credentials_schema": [], diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 429c53d829..2ed44c4a36 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -126,7 +126,7 @@ class RagPipelineService: Delete customized pipeline template. 
""" customized_template: PipelineCustomizedTemplate | None = ( - db.query(PipelineCustomizedTemplate) + db.session.query(PipelineCustomizedTemplate) .filter( PipelineCustomizedTemplate.id == template_id, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, @@ -135,8 +135,8 @@ class RagPipelineService: ) if not customized_template: raise ValueError("Customized pipeline template not found.") - db.delete(customized_template) - db.commit() + db.session.delete(customized_template) + db.session.commit() def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: """ From 9eb8597957334a474b3f016bb4bbb5866867d969 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 11 Jun 2025 17:29:14 +0800 Subject: [PATCH 108/155] r2 --- api/models/dataset.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/api/models/dataset.py b/api/models/dataset.py index da6e58f113..5d18eaff49 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -69,6 +69,19 @@ class Dataset(Base): pipeline_id = db.Column(StringUUID, nullable=True) chunk_structure = db.Column(db.String(255), nullable=True) + @property + def total_documents(self): + return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() + + @property + def total_available_documents(self): + return db.session.query(func.count(Document.id)).filter( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ).scalar() + @property def dataset_keyword_table(self): dataset_keyword_table = ( @@ -311,20 +324,6 @@ class DatasetProcessRule(Base): return json.loads(self.rules) if self.rules else None except JSONDecodeError: return None - - @property - def total_documents(self): - return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.dataset_id).scalar() - - @property - def total_available_documents(self): - return db.session.query(func.count(Document.id)).filter( - Document.dataset_id == self.dataset_id, - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, - ).scalar() - class Document(Base): __tablename__ = "documents" From 14dc3e86427ecac1f6bd67c9d09c5b8e54d2753a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 11 Jun 2025 18:03:21 +0800 Subject: [PATCH 109/155] r2 --- api/services/rag_pipeline/rag_pipeline.py | 6 +++--- api/services/rag_pipeline/rag_pipeline_dsl_service.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 2ed44c4a36..df9fea805c 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -50,7 +50,7 @@ from models.workflow import ( Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, - WorkflowType, + WorkflowType, WorkflowNodeExecutionModel, ) from services.dataset_service import DatasetService from services.datasource_provider_service import DatasourceProviderService @@ -911,7 +911,7 @@ class RagPipelineService: pipeline: Pipeline, run_id: str, user: Account | EndUser, - ) -> list[WorkflowNodeExecution]: + ) -> list[WorkflowNodeExecutionModel]: """ Get workflow run node execution list """ @@ -930,7 +930,7 @@ class RagPipelineService: # Use the repository to get the node executions with ordering order_config = OrderConfig(order_by=["index"], order_direction="desc") - node_executions = 
repository.get_by_workflow_run( + node_executions = repository.get_db_models_by_workflow_run( workflow_run_id=run_id, order_config=order_config, triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 8787490555..fb311482d8 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -276,7 +276,7 @@ class RagPipelineDslService: nodes = graph.get("nodes", []) dataset_id = None for node in nodes: - if node.get("data", {}).get("type") == "knowledge_index": + if node.get("data", {}).get("type") == "knowledge-index": knowledge_configuration = KnowledgeConfiguration(**node.get("data", {})) if ( dataset From da3a3ce165337df6694afd4d6359c9ad5fdc885d Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 11 Jun 2025 18:07:06 +0800 Subject: [PATCH 110/155] r2 --- api/controllers/console/datasets/rag_pipeline/rag_pipeline.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index f2c0870f72..23d402f914 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -109,8 +109,7 @@ class CustomizedPipelineTemplateApi(Resource): if not template: raise ValueError("Customized pipeline template not found.") - dsl = yaml.safe_load(template.yaml_content) - return {"data": dsl}, 200 + return {"data": template.yaml_content}, 200 class PublishCustomizedPipelineTemplateApi(Resource): From dee7b6eb22106c924e26d26817a7c69807447a02 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Thu, 12 Jun 2025 16:19:12 +0800 Subject: [PATCH 111/155] Update deploy-dev.yml --- .github/workflows/deploy-dev.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index 409a80f19e..12bf6cfbc9 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -19,6 +19,6 @@ jobs: with: host: ${{ secrets.RAG_SSH_HOST }} username: ${{ secrets.SSH_USER }} - key: ${{ secrets.RAG_SSH_PRIVATE_KEY }} + key: ${{ secrets.SSH_PRIVATE_KEY }} script: | ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }} From 7bd2509ad593227fdb2fc27c9dc48100ae21dddd Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 13 Jun 2025 14:50:38 +0800 Subject: [PATCH 112/155] Update deploy-dev.yml --- .github/workflows/deploy-dev.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index 12bf6cfbc9..0d99c6fa58 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -12,7 +12,8 @@ jobs: deploy: runs-on: ubuntu-latest if: | - github.event.workflow_run.conclusion == 'success' + github.event.workflow_run.conclusion == 'success' && + github.event.workflow_run.head_branch == 'deploy/rag-dev' steps: - name: Deploy to server uses: appleboy/ssh-action@v0.1.8 From b2b95412b9ab7c94fc9b9a36ebe4bc089133e503 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 13 Jun 2025 15:04:22 +0800 Subject: [PATCH 113/155] r2 --- .../entities/datasource_entities.py | 28 ++- .../online_document/online_document_plugin.py | 7 +- 
.../datasource/utils/message_transformer.py | 6 +- api/core/file/enums.py | 1 + api/core/plugin/impl/datasource.py | 30 ++- .../nodes/datasource/datasource_node.py | 180 ++++++++++++++++-- api/factories/file_factory.py | 47 +++++ 7 files changed, 253 insertions(+), 46 deletions(-) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index adcdcccf83..8d68c80c81 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -15,7 +15,7 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolLabelEnum +from core.tools.entities.tool_entities import ToolLabelEnum, ToolInvokeMessage class DatasourceProviderType(enum.StrEnum): @@ -301,3 +301,29 @@ class GetWebsiteCrawlResponse(BaseModel): """ result: WebSiteInfo = WebSiteInfo(job_id="", status="", web_info_list=[]) + + +class DatasourceInvokeMessage(ToolInvokeMessage): + """ + Datasource Invoke Message. + """ + + class WebsiteCrawlMessage(BaseModel): + """ + Website crawl message + """ + + job_id: str = Field(..., description="The job id") + status: str = Field(..., description="The status of the job") + web_info_list: Optional[list[WebSiteInfoDetail]] = [] + + class OnlineDocumentMessage(BaseModel): + """ + Online document message + """ + + workspace_id: str = Field(..., description="The workspace id") + workspace_name: str = Field(..., description="The workspace name") + workspace_icon: str = Field(..., description="The workspace icon") + total: int = Field(..., description="The total number of documents") + pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") \ No newline at end of file diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index f94031656e..2ab60cae1e 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,10 +1,11 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, Generator from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, + DatasourceInvokeMessage, DatasourceProviderType, GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, @@ -38,7 +39,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): user_id: str, datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> GetOnlineDocumentPagesResponse: + ) -> Generator[DatasourceInvokeMessage, None, None]: manager = PluginDatasourceManager() return manager.get_online_document_pages( @@ -56,7 +57,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): user_id: str, datasource_parameters: GetOnlineDocumentPageContentRequest, provider_type: str, - ) -> GetOnlineDocumentPageContentResponse: + ) -> Generator[DatasourceInvokeMessage, None, None]: manager = PluginDatasourceManager() return manager.get_online_document_page_content( diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index a10030d93b..bd99387e8d 100644 --- a/api/core/datasource/utils/message_transformer.py +++ 
b/api/core/datasource/utils/message_transformer.py @@ -39,7 +39,7 @@ class DatasourceFileMessageTransformer: conversation_id=conversation_id, ) - url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}" + url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}" yield DatasourceInvokeMessage( type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, @@ -77,7 +77,7 @@ class DatasourceFileMessageTransformer: filename=filename, ) - url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype)) + url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type)) # check if file is image if "image" in mimetype: @@ -98,7 +98,7 @@ class DatasourceFileMessageTransformer: if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: assert file.related_id is not None - url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) + url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension) if file.type == FileType.IMAGE: yield DatasourceInvokeMessage( type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, diff --git a/api/core/file/enums.py b/api/core/file/enums.py index a50a651dd3..170eb4fc23 100644 --- a/api/core/file/enums.py +++ b/api/core/file/enums.py @@ -20,6 +20,7 @@ class FileTransferMethod(StrEnum): REMOTE_URL = "remote_url" LOCAL_FILE = "local_file" TOOL_FILE = "tool_file" + DATASOURCE_FILE = "datasource_file" @staticmethod def value_of(value): diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 98ee0bb11e..83b1a5760b 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,7 +1,8 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, Generator from core.datasource.entities.datasource_entities import ( + DatasourceInvokeMessage, GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, GetOnlineDocumentPagesResponse, @@ -93,7 +94,7 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> GetWebsiteCrawlResponse: + ) -> Generator[DatasourceInvokeMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ @@ -103,7 +104,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl", - GetWebsiteCrawlResponse, + DatasourceInvokeMessage, data={ "user_id": user_id, "data": { @@ -118,10 +119,7 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: - return resp - - raise Exception("No response from plugin daemon") + yield from response def get_online_document_pages( self, @@ -132,7 +130,7 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> GetOnlineDocumentPagesResponse: + ) -> Generator[DatasourceInvokeMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
""" @@ -142,7 +140,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages", - GetOnlineDocumentPagesResponse, + DatasourceInvokeMessage, data={ "user_id": user_id, "data": { @@ -157,10 +155,7 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: - return resp - - raise Exception("No response from plugin daemon") + yield from response def get_online_document_page_content( self, @@ -171,7 +166,7 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], datasource_parameters: GetOnlineDocumentPageContentRequest, provider_type: str, - ) -> GetOnlineDocumentPageContentResponse: + ) -> Generator[DatasourceInvokeMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ @@ -181,7 +176,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content", - GetOnlineDocumentPageContentResponse, + DatasourceInvokeMessage, data={ "user_id": user_id, "data": { @@ -196,10 +191,7 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: - return resp - - raise Exception("No response from plugin daemon") + yield from response def validate_provider_credentials( self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any] diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 2782f2fb4c..bd4a6e3a56 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,13 +1,18 @@ from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Generator, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session from core.datasource.entities.datasource_entities import ( + DatasourceInvokeMessage, DatasourceParameter, DatasourceProviderType, GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer from core.file import File from core.file.enums import FileTransferMethod, FileType from core.plugin.impl.exc import PluginDaemonClientSideError @@ -19,8 +24,11 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.tool.exc import ToolFileError from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db +from factories import file_factory from models.model import UploadFile from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey @@ -36,7 +44,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): _node_data_cls = DatasourceNodeData _node_type = NodeType.DATASOURCE - def _run(self) -> NodeRunResult: + def _run(self) -> Generator: """ Run the datasource node """ 
@@ -65,13 +73,15 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): datasource_type=DatasourceProviderType.value_of(datasource_type), ) except DatasourceNodeError as e: - return NodeRunResult( + yield RunCompletedEvent( + run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, error=f"Failed to get datasource runtime: {str(e)}", error_type=type(e).__name__, ) + ) # get parameters datasource_parameters = datasource_runtime.entity.parameters @@ -91,25 +101,22 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: GetOnlineDocumentPageContentResponse = ( + online_document_result: Generator[DatasourceInvokeMessage, None, None] = ( datasource_runtime._get_online_document_page_content( user_id=self.user_id, datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), provider_type=datasource_type, ) ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - **online_document_result.result.model_dump(), - "datasource_type": datasource_type, - }, + yield from self._transform_message( + messages=online_document_result, + parameters_for_log=parameters_for_log, + datasource_info=datasource_info, ) + case DatasourceProviderType.WEBSITE_CRAWL: - return NodeRunResult( + yield RunCompletedEvent(run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, @@ -117,7 +124,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): **datasource_info, "datasource_type": datasource_type, }, - ) + )) case DatasourceProviderType.LOCAL_FILE: related_id = datasource_info.get("related_id") if not related_id: @@ -149,7 +156,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): variable_key_list=new_key_list, variable_value=value, ) - return NodeRunResult( + yield RunCompletedEvent(run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, @@ -157,25 +164,25 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): "file_info": datasource_info, "datasource_type": datasource_type, }, - ) + )) case _: raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}") except PluginDaemonClientSideError as e: - return NodeRunResult( + yield RunCompletedEvent(run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, error=f"Failed to transform datasource message: {str(e)}", error_type=type(e).__name__, - ) + )) except DatasourceNodeError as e: - return NodeRunResult( + yield RunCompletedEvent(run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, error=f"Failed to invoke datasource: {str(e)}", error_type=type(e).__name__, - ) + )) def _generate_parameters( self, @@ -279,3 +286,136 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): result = {node_id + "." 
+ key: value for key, value in result.items()} return result + + + + def _transform_message( + self, + messages: Generator[DatasourceInvokeMessage, None, None], + parameters_for_log: dict[str, Any], + datasource_info: dict[str, Any], + ) -> Generator: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=messages, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json: list[dict] = [] + + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + DatasourceInvokeMessage.MessageType.IMAGE_LINK, + DatasourceInvokeMessage.MessageType.BINARY_LINK, + DatasourceInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.DATASOURCE_FILE) + else: + transfer_method = FileTransferMethod.DATASOURCE_FILE + + datasource_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(UploadFile).where(UploadFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"Tool file {datasource_file_id} does not exist") + + mapping = { + "datasource_file_id": datasource_file_id, + "type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + files.append(file) + elif message.type == DatasourceInvokeMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + assert message.meta + + datasource_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(UploadFile).where(UploadFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"datasource file {datasource_file_id} not exists") + + mapping = { + "datasource_file_id": datasource_file_id, + "transfer_method": FileTransferMethod.DATASOURCE_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + ) + elif message.type == DatasourceInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + text += message.message.text + yield RunStreamChunkEvent( + chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] + ) + elif message.type == DatasourceInvokeMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage) + if self.node_type == NodeType.AGENT: + msg_metadata = message.message.json_object.pop("execution_metadata", {}) + agent_execution_metadata = { + key: value + for key, value in msg_metadata.items() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() + } + json.append(message.message.json_object) + elif message.type == DatasourceInvokeMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, 
"text"]) + elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield RunStreamChunkEvent( + chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] + ) + else: + variables[variable_name] = variable_value + elif message.type == DatasourceInvokeMessage.MessageType.FILE: + assert message.meta is not None + files.append(message.meta["file"]) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"json": json, "files": files, **variables, "text": text}, + metadata={ + WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info, + }, + inputs=parameters_for_log, + ) + ) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 52f119936f..128041a27d 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -60,6 +60,7 @@ def build_from_mapping( FileTransferMethod.LOCAL_FILE: _build_from_local_file, FileTransferMethod.REMOTE_URL: _build_from_remote_url, FileTransferMethod.TOOL_FILE: _build_from_tool_file, + FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, } build_func = build_functions.get(transfer_method) @@ -302,6 +303,52 @@ def _build_from_tool_file( ) +def _build_from_datasource_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, +) -> File: + datasource_file = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == mapping.get("datasource_file_id"), + UploadFile.tenant_id == tenant_id, + ) + .first() + ) + + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + + detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + + specified_type = mapping.get("type") + + if strict_type_validation and specified_type and detected_file_type.value != specified_type: + raise ValueError("Detected file type does not match the specified type. 
Please verify the file.") + + file_type = ( + FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type + ) + + return File( + id=mapping.get("id"), + tenant_id=tenant_id, + filename=datasource_file.name, + type=file_type, + transfer_method=transfer_method, + remote_url=datasource_file.source_url, + related_id=datasource_file.id, + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + ) + def _is_file_valid_with_config( *, input_file_type: str, From 0908f310fc0b02cdba8b3d6a9d4c65988a6f1306 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 13 Jun 2025 17:36:24 +0800 Subject: [PATCH 114/155] feat: webcrawl --- .../datasets/rag_pipeline/rag_pipeline.py | 1 - .../entities/datasource_entities.py | 39 +++---------------- .../website_crawl/website_crawl_plugin.py | 8 ++-- .../website_crawl/website_crawl_provider.py | 1 - api/core/plugin/impl/datasource.py | 13 +++---- .../knowledge_index/knowledge_index_node.py | 2 +- api/services/rag_pipeline/rag_pipeline.py | 27 ++++++------- 7 files changed, 31 insertions(+), 60 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 23d402f914..93976bd6f5 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -1,6 +1,5 @@ import logging -import yaml from flask import request from flask_restful import Resource, reqparse from sqlalchemy.orm import Session diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 8d68c80c81..dd65c85cbc 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -15,7 +15,7 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolLabelEnum, ToolInvokeMessage +from core.tools.entities.tool_entities import ToolLabelEnum class DatasourceProviderType(enum.StrEnum): @@ -290,40 +290,13 @@ class WebSiteInfo(BaseModel): """ Website info """ - job_id: str = Field(..., description="The job id") - status: str = Field(..., description="The status of the job") + status: Optional[str] = Field(..., description="crawl job status") web_info_list: Optional[list[WebSiteInfoDetail]] = [] + total: Optional[int] = Field(default=0, description="The total number of websites") + completed: Optional[int] = Field(default=0, description="The number of completed websites") - -class GetWebsiteCrawlResponse(BaseModel): +class WebsiteCrawlMessage(BaseModel): """ Get website crawl response """ - - result: WebSiteInfo = WebSiteInfo(job_id="", status="", web_info_list=[]) - - -class DatasourceInvokeMessage(ToolInvokeMessage): - """ - Datasource Invoke Message. 
- """ - - class WebsiteCrawlMessage(BaseModel): - """ - Website crawl message - """ - - job_id: str = Field(..., description="The job id") - status: str = Field(..., description="The status of the job") - web_info_list: Optional[list[WebSiteInfoDetail]] = [] - - class OnlineDocumentMessage(BaseModel): - """ - Online document message - """ - - workspace_id: str = Field(..., description="The workspace id") - workspace_name: str = Field(..., description="The workspace name") - workspace_icon: str = Field(..., description="The workspace icon") - total: int = Field(..., description="The total number of documents") - pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") \ No newline at end of file + result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0) diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index e8256b3282..d0e442f31a 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Generator, Mapping from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin @@ -6,7 +6,7 @@ from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, DatasourceProviderType, - GetWebsiteCrawlResponse, + WebsiteCrawlMessage, ) from core.plugin.impl.datasource import PluginDatasourceManager @@ -31,12 +31,12 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - def _get_website_crawl( + def get_website_crawl( self, user_id: str, datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> GetWebsiteCrawlResponse: + ) -> Generator[WebsiteCrawlMessage, None, None]: manager = PluginDatasourceManager() return manager.get_website_crawl( diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index a65efb750e..0567f1a480 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -1,4 +1,3 @@ -from core.datasource.__base import datasource_provider from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 83b1a5760b..54325a545f 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,12 +1,12 @@ -from collections.abc import Mapping -from typing import Any, Generator +from collections.abc import Generator, Mapping +from typing import Any from core.datasource.entities.datasource_entities import ( DatasourceInvokeMessage, GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, GetOnlineDocumentPagesResponse, - GetWebsiteCrawlResponse, + WebsiteCrawlMessage, ) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import ( @@ -94,17 +94,17 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], 
datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[WebsiteCrawlMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ datasource_provider_id = GenericProviderID(datasource_provider) - response = self._request_with_plugin_daemon_response_stream( + return self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl", - DatasourceInvokeMessage, + WebsiteCrawlMessage, data={ "user_id": user_id, "data": { @@ -119,7 +119,6 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - yield from response def get_online_document_pages( self, diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index c63d837106..49c8ec1e69 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,7 +1,7 @@ import datetime import logging -from collections.abc import Mapping import time +from collections.abc import Mapping from typing import Any, cast from sqlalchemy import func diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index df9fea805c..43b68b3b97 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -17,11 +17,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( DatasourceProviderType, GetOnlineDocumentPagesResponse, - GetWebsiteCrawlResponse, + WebsiteCrawlMessage, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin -from core.model_runtime.utils.encoders import jsonable_encoder from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable from core.workflow.entities.node_entities import NodeRunResult @@ -43,14 +42,14 @@ from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore -from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.enums import WorkflowRunTriggeredFrom from models.model import EndUser -from models.oauth import DatasourceProvider from models.workflow import ( Workflow, + WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, - WorkflowType, WorkflowNodeExecutionModel, + WorkflowType, ) from services.dataset_service import DatasetService from services.datasource_provider_service import DatasourceProviderService @@ -468,15 +467,16 @@ class RagPipelineService: case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + website_crawl_results: list[WebsiteCrawlMessage] = [] + for website_message in datasource_runtime.get_website_crawl( user_id=account.id, datasource_parameters={"job_id": job_id}, provider_type=datasource_runtime.datasource_provider_type(), - ) + ): + 
website_crawl_results.append(website_message) return { - "result": [result for result in website_crawl_result.result], - "job_id": website_crawl_result.result.job_id, - "status": website_crawl_result.result.status, + "result": [result for result in website_crawl_results.result], + "status": website_crawl_results.result.status, "provider_type": datasource_node_data.get("provider_type"), } case _: @@ -544,14 +544,15 @@ class RagPipelineService: case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + website_crawl_results: list[WebsiteCrawlMessage] = [] + for website_crawl_result in datasource_runtime.get_website_crawl( user_id=account.id, datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), - ) + ): + website_crawl_results.append(website_crawl_result) return { "result": [result.model_dump() for result in website_crawl_result.result.web_info_list] if website_crawl_result.result.web_info_list else [], - "job_id": website_crawl_result.result.job_id, "status": website_crawl_result.result.status, "provider_type": datasource_node_data.get("provider_type"), } From 5ccb8d9736eb3948c81f232920e0f06a2141660d Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 13 Jun 2025 18:22:15 +0800 Subject: [PATCH 115/155] feat: online document --- .../entities/datasource_entities.py | 39 +++++++++++++++---- .../online_document/online_document_plugin.py | 9 ++--- api/core/plugin/impl/datasource.py | 7 ++-- .../nodes/datasource/datasource_node.py | 5 +-- api/services/rag_pipeline/rag_pipeline.py | 4 +- 5 files changed, 42 insertions(+), 22 deletions(-) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index dd65c85cbc..b9a0c1f150 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -15,7 +15,7 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolLabelEnum +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum class DatasourceProviderType(enum.StrEnum): @@ -207,12 +207,6 @@ class DatasourceInvokeFrom(Enum): RAG_PIPELINE = "rag_pipeline" -class GetOnlineDocumentPagesRequest(BaseModel): - """ - Get online document pages request - """ - - class OnlineDocumentPage(BaseModel): """ Online document page @@ -237,7 +231,7 @@ class OnlineDocumentInfo(BaseModel): pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") -class GetOnlineDocumentPagesResponse(BaseModel): +class OnlineDocumentPagesMessage(BaseModel): """ Get online document pages response """ @@ -300,3 +294,32 @@ class WebsiteCrawlMessage(BaseModel): Get website crawl response """ result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0) + +class DatasourceMessage(ToolInvokeMessage): + pass + + +class DatasourceInvokeMessage(ToolInvokeMessage): + """ + Datasource Invoke Message. 
+ """ + + class WebsiteCrawlMessage(BaseModel): + """ + Website crawl message + """ + + job_id: str = Field(..., description="The job id") + status: str = Field(..., description="The status of the job") + web_info_list: Optional[list[WebSiteInfoDetail]] = [] + + class OnlineDocumentMessage(BaseModel): + """ + Online document message + """ + + workspace_id: str = Field(..., description="The workspace id") + workspace_name: str = Field(..., description="The workspace name") + workspace_icon: str = Field(..., description="The workspace icon") + total: int = Field(..., description="The total number of documents") + pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index 2ab60cae1e..db73d9a64b 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,5 +1,5 @@ -from collections.abc import Mapping -from typing import Any, Generator +from collections.abc import Generator, Mapping +from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime @@ -8,8 +8,7 @@ from core.datasource.entities.datasource_entities import ( DatasourceInvokeMessage, DatasourceProviderType, GetOnlineDocumentPageContentRequest, - GetOnlineDocumentPageContentResponse, - GetOnlineDocumentPagesResponse, + OnlineDocumentPagesMessage, ) from core.plugin.impl.datasource import PluginDatasourceManager @@ -39,7 +38,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): user_id: str, datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[OnlineDocumentPagesMessage, None, None]: manager = PluginDatasourceManager() return manager.get_online_document_pages( diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 54325a545f..06ee00c688 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -4,8 +4,7 @@ from typing import Any from core.datasource.entities.datasource_entities import ( DatasourceInvokeMessage, GetOnlineDocumentPageContentRequest, - GetOnlineDocumentPageContentResponse, - GetOnlineDocumentPagesResponse, + OnlineDocumentPagesMessage, WebsiteCrawlMessage, ) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID @@ -129,7 +128,7 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[OnlineDocumentPagesMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
""" @@ -139,7 +138,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages", - DatasourceInvokeMessage, + OnlineDocumentPagesMessage, data={ "user_id": user_id, "data": { diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index bd4a6e3a56..240eeeb725 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,5 +1,5 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Generator, cast +from collections.abc import Generator, Mapping, Sequence +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -9,7 +9,6 @@ from core.datasource.entities.datasource_entities import ( DatasourceParameter, DatasourceProviderType, GetOnlineDocumentPageContentRequest, - GetOnlineDocumentPageContentResponse, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 43b68b3b97..7af607a96b 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -16,7 +16,7 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( DatasourceProviderType, - GetOnlineDocumentPagesResponse, + OnlineDocumentPagesMessage, WebsiteCrawlMessage, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin @@ -532,7 +532,7 @@ class RagPipelineService: match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( + online_document_result: OnlineDocumentPagesMessage = datasource_runtime._get_online_document_pages( user_id=account.id, datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), From 41fef8a21fe9463e98f7b7ad29cad2cd7b5c45fd Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 16 Jun 2025 13:48:43 +0800 Subject: [PATCH 116/155] r2 --- .../rag_pipeline/rag_pipeline_workflow.py | 6 +- .../website_crawl/website_crawl_plugin.py | 5 +- api/core/tools/entities/tool_entities.py | 2 + .../workflow/graph_engine/entities/event.py | 5 ++ api/core/workflow/nodes/tool/tool_node.py | 2 +- api/services/rag_pipeline/rag_pipeline.py | 83 +++---------------- 6 files changed, 25 insertions(+), 78 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 7b8adfe560..c97b3b1d92 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -8,6 +8,7 @@ from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from core.app.apps.pipeline.pipeline_generator import PipelineGenerator import services from configs import dify_config from 
controllers.console import api @@ -453,7 +454,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource): raise ValueError("missing datasource_type") rag_pipeline_service = RagPipelineService() - result = rag_pipeline_service.run_datasource_workflow_node( + return helper.compact_generate_response(rag_pipeline_service.run_datasource_workflow_node( pipeline=pipeline, node_id=node_id, user_inputs=inputs, @@ -461,8 +462,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource): datasource_type=datasource_type, is_published=False ) - - return result + ) class RagPipelinePublishedNodeRunApi(Resource): diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index e8256b3282..87612fff44 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -1,10 +1,11 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, Generator from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, + DatasourceInvokeMessage, DatasourceProviderType, GetWebsiteCrawlResponse, ) @@ -36,7 +37,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): user_id: str, datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> GetWebsiteCrawlResponse: + ) -> Generator[DatasourceInvokeMessage, None, None]: manager = PluginDatasourceManager() return manager.get_website_crawl( diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 03047c0545..34a86555f7 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -188,6 +188,8 @@ class ToolInvokeMessage(BaseModel): FILE = "file" LOG = "log" BLOB_CHUNK = "blob_chunk" + WEBSITE_CRAWL = "website_crawl" + ONLINE_DOCUMENT = "online_document" type: MessageType = MessageType.TEXT """ diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 9a4939502e..0d8a4ee821 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -273,3 +273,8 @@ class AgentLogEvent(BaseAgentEvent): InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent + + +class DatasourceRunEvent(BaseModel): + status: str = Field(..., description="status") + result: dict[str, Any] = Field(..., description="result") diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index aaecc7b989..9a37f0e51c 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -127,7 +127,7 @@ class ToolNode(BaseNode[ToolNodeData]): inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, error=f"Failed to transform tool message: {str(e)}", - error_type=type(e).__name__, + error_type=type(e).__name__, PipelineGenerator.convert_to_event_strea ) ) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index df9fea805c..a3978c9a5a 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -15,6 +15,7 @@ import contexts from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from 
core.datasource.entities.datasource_entities import ( + DatasourceInvokeMessage, DatasourceProviderType, GetOnlineDocumentPagesResponse, GetWebsiteCrawlResponse, @@ -31,7 +32,7 @@ from core.workflow.entities.workflow_node_execution import ( ) from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.graph_engine.entities.event import InNodeEvent +from core.workflow.graph_engine.entities.event import DatasourceRunEvent, InNodeEvent from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event.event import RunCompletedEvent @@ -423,69 +424,11 @@ class RagPipelineService: return workflow_node_execution - def run_datasource_workflow_node_status( - self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool - ) -> dict: - """ - Run published workflow datasource - """ - if is_published: - # fetch published workflow by app_model - workflow = self.get_published_workflow(pipeline=pipeline) - else: - workflow = self.get_draft_workflow(pipeline=pipeline) - if not workflow: - raise ValueError("Workflow not initialized") - - # run draft workflow node - datasource_node_data = None - start_at = time.perf_counter() - datasource_nodes = workflow.graph_dict.get("nodes", []) - for datasource_node in datasource_nodes: - if datasource_node.get("id") == node_id: - datasource_node_data = datasource_node.get("data", {}) - break - if not datasource_node_data: - raise ValueError("Datasource node data not found") - - from core.datasource.datasource_manager import DatasourceManager - - datasource_runtime = DatasourceManager.get_datasource_runtime( - provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", - datasource_name=datasource_node_data.get("datasource_name"), - tenant_id=pipeline.tenant_id, - datasource_type=DatasourceProviderType(datasource_type), - ) - datasource_provider_service = DatasourceProviderService() - credentials = datasource_provider_service.get_real_datasource_credentials( - tenant_id=pipeline.tenant_id, - provider=datasource_node_data.get('provider_name'), - plugin_id=datasource_node_data.get('plugin_id'), - ) - if credentials: - datasource_runtime.runtime.credentials = credentials[0].get("credentials") - match datasource_type: - - case DatasourceProviderType.WEBSITE_CRAWL: - datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( - user_id=account.id, - datasource_parameters={"job_id": job_id}, - provider_type=datasource_runtime.datasource_provider_type(), - ) - return { - "result": [result for result in website_crawl_result.result], - "job_id": website_crawl_result.result.job_id, - "status": website_crawl_result.result.status, - "provider_type": datasource_node_data.get("provider_type"), - } - case _: - raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") def run_datasource_workflow_node( self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool - ) -> dict: + ) -> Generator[DatasourceRunEvent, None, None]: """ Run published workflow datasource """ @@ -532,29 +475,25 @@ class RagPipelineService: match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, 
datasource_runtime) - online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( + online_document_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_online_document_pages( user_id=account.id, datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), ) - return { - "result": [page.model_dump() for page in online_document_result.result], - "provider_type": datasource_node_data.get("provider_type"), - } + for message in online_document_result: + yield DatasourceRunEvent( + status="success", + result=message.model_dump(), + ) case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + website_crawl_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_website_crawl( user_id=account.id, datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), ) - return { - "result": [result.model_dump() for result in website_crawl_result.result.web_info_list] if website_crawl_result.result.web_info_list else [], - "job_id": website_crawl_result.result.job_id, - "status": website_crawl_result.result.status, - "provider_type": datasource_node_data.get("provider_type"), - } + yield from website_crawl_result case _: raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") From c5976f5a0920ca508d4cac233820ccf185648580 Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Tue, 17 Jun 2025 13:51:41 +0800 Subject: [PATCH 117/155] feat(datasource): change datasource result type to event-stream --- .../datasets/rag_pipeline/datasource_auth.py | 3 +- .../rag_pipeline/rag_pipeline_workflow.py | 208 +++++++++--------- .../common/workflow_response_converter.py | 5 +- .../datasource/datasource_file_manager.py | 6 +- .../entities/datasource_entities.py | 25 --- .../online_document/online_document_plugin.py | 8 +- .../datasource/utils/message_transformer.py | 56 ++--- .../website_crawl/website_crawl_plugin.py | 1 - api/core/plugin/impl/datasource.py | 6 +- api/core/tools/entities/tool_entities.py | 2 - .../workflow/graph_engine/entities/event.py | 6 +- .../nodes/datasource/datasource_node.py | 38 ++-- api/core/workflow/nodes/tool/tool_node.py | 2 +- api/services/datasource_provider_service.py | 19 +- api/services/rag_pipeline/rag_pipeline.py | 163 ++++++++------ 15 files changed, 281 insertions(+), 267 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 912981db01..d2136f771b 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -41,7 +41,8 @@ class DatasourcePluginOauthApi(Resource): if not plugin_oauth_config: raise NotFound() oauth_handler = OAuthHandler() - redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}" + redirect_url = (f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?" 
+ f"provider={provider}&plugin_id={plugin_id}") system_credentials = plugin_oauth_config.system_credentials if system_credentials: system_credentials["redirect_url"] = redirect_url diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index c97b3b1d92..616803247c 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -8,7 +8,6 @@ from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound -from core.app.apps.pipeline.pipeline_generator import PipelineGenerator import services from configs import dify_config from controllers.console import api @@ -24,6 +23,7 @@ from controllers.console.wraps import ( ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db @@ -302,87 +302,87 @@ class PublishedRagPipelineRunApi(Resource): raise InvokeRateLimitHttpError(ex.description) -class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource): - @setup_required - @login_required - @account_initialization_required - @get_rag_pipeline - def post(self, pipeline: Pipeline, node_id: str): - """ - Run rag pipeline datasource - """ - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): - raise Forbidden() - - parser = reqparse.RequestParser() - parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") - parser.add_argument("datasource_type", type=str, required=True, location="json") - args = parser.parse_args() - - job_id = args.get("job_id") - if job_id == None: - raise ValueError("missing job_id") - datasource_type = args.get("datasource_type") - if datasource_type == None: - raise ValueError("missing datasource_type") - - rag_pipeline_service = RagPipelineService() - result = rag_pipeline_service.run_datasource_workflow_node_status( - pipeline=pipeline, - node_id=node_id, - job_id=job_id, - account=current_user, - datasource_type=datasource_type, - is_published=True - ) - - return result +# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource): +# @setup_required +# @login_required +# @account_initialization_required +# @get_rag_pipeline +# def post(self, pipeline: Pipeline, node_id: str): +# """ +# Run rag pipeline datasource +# """ +# # The role of the current user in the ta table must be admin, owner, or editor +# if not current_user.is_editor: +# raise Forbidden() +# +# if not isinstance(current_user, Account): +# raise Forbidden() +# +# parser = reqparse.RequestParser() +# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") +# parser.add_argument("datasource_type", type=str, required=True, location="json") +# args = parser.parse_args() +# +# job_id = args.get("job_id") +# if job_id == None: +# raise ValueError("missing job_id") +# datasource_type = args.get("datasource_type") +# if datasource_type == None: +# raise ValueError("missing datasource_type") +# +# 
rag_pipeline_service = RagPipelineService() +# result = rag_pipeline_service.run_datasource_workflow_node_status( +# pipeline=pipeline, +# node_id=node_id, +# job_id=job_id, +# account=current_user, +# datasource_type=datasource_type, +# is_published=True +# ) +# +# return result -class RagPipelineDraftDatasourceNodeRunStatusApi(Resource): - @setup_required - @login_required - @account_initialization_required - @get_rag_pipeline - def post(self, pipeline: Pipeline, node_id: str): - """ - Run rag pipeline datasource - """ - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): - raise Forbidden() - - parser = reqparse.RequestParser() - parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") - parser.add_argument("datasource_type", type=str, required=True, location="json") - args = parser.parse_args() - - job_id = args.get("job_id") - if job_id == None: - raise ValueError("missing job_id") - datasource_type = args.get("datasource_type") - if datasource_type == None: - raise ValueError("missing datasource_type") - - rag_pipeline_service = RagPipelineService() - result = rag_pipeline_service.run_datasource_workflow_node_status( - pipeline=pipeline, - node_id=node_id, - job_id=job_id, - account=current_user, - datasource_type=datasource_type, - is_published=False - ) - - return result - +# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource): +# @setup_required +# @login_required +# @account_initialization_required +# @get_rag_pipeline +# def post(self, pipeline: Pipeline, node_id: str): +# """ +# Run rag pipeline datasource +# """ +# # The role of the current user in the ta table must be admin, owner, or editor +# if not current_user.is_editor: +# raise Forbidden() +# +# if not isinstance(current_user, Account): +# raise Forbidden() +# +# parser = reqparse.RequestParser() +# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") +# parser.add_argument("datasource_type", type=str, required=True, location="json") +# args = parser.parse_args() +# +# job_id = args.get("job_id") +# if job_id == None: +# raise ValueError("missing job_id") +# datasource_type = args.get("datasource_type") +# if datasource_type == None: +# raise ValueError("missing datasource_type") +# +# rag_pipeline_service = RagPipelineService() +# result = rag_pipeline_service.run_datasource_workflow_node_status( +# pipeline=pipeline, +# node_id=node_id, +# job_id=job_id, +# account=current_user, +# datasource_type=datasource_type, +# is_published=False +# ) +# +# return result +# class RagPipelinePublishedDatasourceNodeRunApi(Resource): @setup_required @@ -425,7 +425,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): return result -class RagPipelineDrafDatasourceNodeRunApi(Resource): +class RagPipelineDraftDatasourceNodeRunApi(Resource): @setup_required @login_required @account_initialization_required @@ -447,22 +447,28 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource): args = parser.parse_args() inputs = args.get("inputs") - if inputs == None: + if inputs is None: raise ValueError("missing inputs") datasource_type = args.get("datasource_type") - if datasource_type == None: + if datasource_type is None: raise ValueError("missing datasource_type") rag_pipeline_service = RagPipelineService() - return helper.compact_generate_response(rag_pipeline_service.run_datasource_workflow_node( - pipeline=pipeline, - node_id=node_id, - 
user_inputs=inputs, - account=current_user, - datasource_type=datasource_type, - is_published=False - ) - ) + try: + return helper.compact_generate_response( + PipelineGenerator.convert_to_event_stream( + rag_pipeline_service.run_datasource_workflow_node( + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=datasource_type, + is_published=False + ) + ) + ) + except Exception as e: + print(e) class RagPipelinePublishedNodeRunApi(Resource): @@ -981,17 +987,17 @@ api.add_resource( RagPipelinePublishedDatasourceNodeRunApi, "/rag/pipelines//workflows/published/datasource/nodes//run", ) -api.add_resource( - RagPipelinePublishedDatasourceNodeRunStatusApi, - "/rag/pipelines//workflows/published/datasource/nodes//run-status", -) -api.add_resource( - RagPipelineDraftDatasourceNodeRunStatusApi, - "/rag/pipelines//workflows/draft/datasource/nodes//run-status", -) +# api.add_resource( +# RagPipelinePublishedDatasourceNodeRunStatusApi, +# "/rag/pipelines//workflows/published/datasource/nodes//run-status", +# ) +# api.add_resource( +# RagPipelineDraftDatasourceNodeRunStatusApi, +# "/rag/pipelines//workflows/draft/datasource/nodes//run-status", +# ) api.add_resource( - RagPipelineDrafDatasourceNodeRunApi, + RagPipelineDraftDatasourceNodeRunApi, "/rag/pipelines//workflows/draft/datasource/nodes//run", ) diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 99bb597721..f170d0ee3f 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -186,7 +186,10 @@ class WorkflowResponseConverter: elif event.node_type == NodeType.DATASOURCE: node_data = cast(DatasourceNodeData, event.node_data) manager = PluginDatasourceManager() - provider_entity = manager.fetch_datasource_provider(self._application_generate_entity.app_config.tenant_id, f"{node_data.plugin_id}/{node_data.provider_name}") + provider_entity = manager.fetch_datasource_provider( + self._application_generate_entity.app_config.tenant_id, + f"{node_data.plugin_id}/{node_data.provider_name}" + ) response.data.extras["icon"] = provider_entity.declaration.identity.icon return response diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index 6704d4e73a..9a31f682fd 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -14,7 +14,7 @@ from configs import dify_config from core.helper import ssrf_proxy from extensions.ext_database import db from extensions.ext_storage import storage -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import MessageFile, UploadFile from models.tools import ToolFile @@ -86,7 +86,7 @@ class DatasourceFileManager: size=len(file_binary), extension=extension, mime_type=mimetype, - created_by_role=CreatedByRole.ACCOUNT, + created_by_role=CreatorUserRole.ACCOUNT, created_by=user_id, used=False, hash=hashlib.sha3_256(file_binary).hexdigest(), @@ -133,7 +133,7 @@ class DatasourceFileManager: size=len(blob), extension=extension, mime_type=mimetype, - created_by_role=CreatedByRole.ACCOUNT, + created_by_role=CreatorUserRole.ACCOUNT, created_by=user_id, used=False, hash=hashlib.sha3_256(blob).hexdigest(), diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index b9a0c1f150..9b72966b50 100644 --- 
a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -298,28 +298,3 @@ class WebsiteCrawlMessage(BaseModel): class DatasourceMessage(ToolInvokeMessage): pass - -class DatasourceInvokeMessage(ToolInvokeMessage): - """ - Datasource Invoke Message. - """ - - class WebsiteCrawlMessage(BaseModel): - """ - Website crawl message - """ - - job_id: str = Field(..., description="The job id") - status: str = Field(..., description="The status of the job") - web_info_list: Optional[list[WebSiteInfoDetail]] = [] - - class OnlineDocumentMessage(BaseModel): - """ - Online document message - """ - - workspace_id: str = Field(..., description="The workspace id") - workspace_name: str = Field(..., description="The workspace name") - workspace_icon: str = Field(..., description="The workspace icon") - total: int = Field(..., description="The total number of documents") - pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index db73d9a64b..c1e015fd3a 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -5,7 +5,7 @@ from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, - DatasourceInvokeMessage, + DatasourceMessage, DatasourceProviderType, GetOnlineDocumentPageContentRequest, OnlineDocumentPagesMessage, @@ -33,7 +33,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - def _get_online_document_pages( + def get_online_document_pages( self, user_id: str, datasource_parameters: Mapping[str, Any], @@ -51,12 +51,12 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): provider_type=provider_type, ) - def _get_online_document_page_content( + def get_online_document_page_content( self, user_id: str, datasource_parameters: GetOnlineDocumentPageContentRequest, provider_type: str, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[DatasourceMessage, None, None]: manager = PluginDatasourceManager() return manager.get_online_document_page_content( diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index bd99387e8d..9bc57235d8 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -4,7 +4,7 @@ from mimetypes import guess_extension from typing import Optional from core.datasource.datasource_file_manager import DatasourceFileManager -from core.datasource.entities.datasource_entities import DatasourceInvokeMessage +from core.datasource.entities.datasource_entities import DatasourceMessage from core.file import File, FileTransferMethod, FileType logger = logging.getLogger(__name__) @@ -14,23 +14,23 @@ class DatasourceFileMessageTransformer: @classmethod def transform_datasource_invoke_messages( cls, - messages: Generator[DatasourceInvokeMessage, None, None], + messages: Generator[DatasourceMessage, None, None], user_id: str, tenant_id: str, conversation_id: Optional[str] = None, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[DatasourceMessage, None, None]: """ Transform 
datasource message and handle file download """ for message in messages: - if message.type in {DatasourceInvokeMessage.MessageType.TEXT, DatasourceInvokeMessage.MessageType.LINK}: + if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}: yield message - elif message.type == DatasourceInvokeMessage.MessageType.IMAGE and isinstance( - message.message, DatasourceInvokeMessage.TextMessage + elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance( + message.message, DatasourceMessage.TextMessage ): # try to download image try: - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + assert isinstance(message.message, DatasourceMessage.TextMessage) file = DatasourceFileManager.create_file_by_url( user_id=user_id, @@ -41,20 +41,20 @@ class DatasourceFileMessageTransformer: url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}" - yield DatasourceInvokeMessage( - type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, - message=DatasourceInvokeMessage.TextMessage(text=url), + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text=url), meta=message.meta.copy() if message.meta is not None else {}, ) except Exception as e: - yield DatasourceInvokeMessage( - type=DatasourceInvokeMessage.MessageType.TEXT, - message=DatasourceInvokeMessage.TextMessage( + yield DatasourceMessage( + type=DatasourceMessage.MessageType.TEXT, + message=DatasourceMessage.TextMessage( text=f"Failed to download image: {message.message.text}: {e}" ), meta=message.meta.copy() if message.meta is not None else {}, ) - elif message.type == DatasourceInvokeMessage.MessageType.BLOB: + elif message.type == DatasourceMessage.MessageType.BLOB: # get mime type and save blob to storage meta = message.meta or {} @@ -63,7 +63,7 @@ class DatasourceFileMessageTransformer: filename = meta.get("file_name", None) # if message is str, encode it to bytes - if not isinstance(message.message, DatasourceInvokeMessage.BlobMessage): + if not isinstance(message.message, DatasourceMessage.BlobMessage): raise ValueError("unexpected message type") # FIXME: should do a type check here. 
@@ -81,18 +81,18 @@ class DatasourceFileMessageTransformer: # check if file is image if "image" in mimetype: - yield DatasourceInvokeMessage( - type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, - message=DatasourceInvokeMessage.TextMessage(text=url), + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text=url), meta=meta.copy() if meta is not None else {}, ) else: - yield DatasourceInvokeMessage( - type=DatasourceInvokeMessage.MessageType.BINARY_LINK, - message=DatasourceInvokeMessage.TextMessage(text=url), + yield DatasourceMessage( + type=DatasourceMessage.MessageType.BINARY_LINK, + message=DatasourceMessage.TextMessage(text=url), meta=meta.copy() if meta is not None else {}, ) - elif message.type == DatasourceInvokeMessage.MessageType.FILE: + elif message.type == DatasourceMessage.MessageType.FILE: meta = message.meta or {} file = meta.get("file", None) if isinstance(file, File): @@ -100,15 +100,15 @@ class DatasourceFileMessageTransformer: assert file.related_id is not None url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension) if file.type == FileType.IMAGE: - yield DatasourceInvokeMessage( - type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, - message=DatasourceInvokeMessage.TextMessage(text=url), + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text=url), meta=meta.copy() if meta is not None else {}, ) else: - yield DatasourceInvokeMessage( - type=DatasourceInvokeMessage.MessageType.LINK, - message=DatasourceInvokeMessage.TextMessage(text=url), + yield DatasourceMessage( + type=DatasourceMessage.MessageType.LINK, + message=DatasourceMessage.TextMessage(text=url), meta=meta.copy() if meta is not None else {}, ) else: diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index 1625670165..d0e442f31a 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -5,7 +5,6 @@ from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, - DatasourceInvokeMessage, DatasourceProviderType, WebsiteCrawlMessage, ) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 06ee00c688..66469b43b4 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -2,7 +2,7 @@ from collections.abc import Generator, Mapping from typing import Any from core.datasource.entities.datasource_entities import ( - DatasourceInvokeMessage, + DatasourceMessage, GetOnlineDocumentPageContentRequest, OnlineDocumentPagesMessage, WebsiteCrawlMessage, @@ -164,7 +164,7 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], datasource_parameters: GetOnlineDocumentPageContentRequest, provider_type: str, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[DatasourceMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
""" @@ -174,7 +174,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content", - DatasourceInvokeMessage, + DatasourceMessage, data={ "user_id": user_id, "data": { diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 34a86555f7..03047c0545 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -188,8 +188,6 @@ class ToolInvokeMessage(BaseModel): FILE = "file" LOG = "log" BLOB_CHUNK = "blob_chunk" - WEBSITE_CRAWL = "website_crawl" - ONLINE_DOCUMENT = "online_document" type: MessageType = MessageType.TEXT """ diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 0d8a4ee821..fbf591eb8f 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -277,4 +277,8 @@ InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | Bas class DatasourceRunEvent(BaseModel): status: str = Field(..., description="status") - result: dict[str, Any] = Field(..., description="result") + data: Mapping[str,Any] | list = Field(..., description="result") + total: Optional[int] = Field(..., description="total") + completed: Optional[int] = Field(..., description="completed") + time_consuming: Optional[float] = Field(..., description="time consuming") + diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 240eeeb725..0e3decc7b4 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -5,7 +5,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.datasource.entities.datasource_entities import ( - DatasourceInvokeMessage, + DatasourceMessage, DatasourceParameter, DatasourceProviderType, GetOnlineDocumentPageContentRequest, @@ -100,8 +100,8 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: Generator[DatasourceInvokeMessage, None, None] = ( - datasource_runtime._get_online_document_page_content( + online_document_result: Generator[DatasourceMessage, None, None] = ( + datasource_runtime.get_online_document_page_content( user_id=self.user_id, datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), provider_type=datasource_type, @@ -290,7 +290,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): def _transform_message( self, - messages: Generator[DatasourceInvokeMessage, None, None], + messages: Generator[DatasourceMessage, None, None], parameters_for_log: dict[str, Any], datasource_info: dict[str, Any], ) -> Generator: @@ -313,11 +313,11 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): for message in message_stream: if message.type in { - DatasourceInvokeMessage.MessageType.IMAGE_LINK, - DatasourceInvokeMessage.MessageType.BINARY_LINK, - DatasourceInvokeMessage.MessageType.IMAGE, + DatasourceMessage.MessageType.IMAGE_LINK, + DatasourceMessage.MessageType.BINARY_LINK, + DatasourceMessage.MessageType.IMAGE, }: - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + assert isinstance(message.message, DatasourceMessage.TextMessage) url = message.message.text if 
message.meta: @@ -344,9 +344,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): tenant_id=self.tenant_id, ) files.append(file) - elif message.type == DatasourceInvokeMessage.MessageType.BLOB: + elif message.type == DatasourceMessage.MessageType.BLOB: # get tool file id - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + assert isinstance(message.message, DatasourceMessage.TextMessage) assert message.meta datasource_file_id = message.message.text.split("/")[-1].split(".")[0] @@ -367,14 +367,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): tenant_id=self.tenant_id, ) ) - elif message.type == DatasourceInvokeMessage.MessageType.TEXT: - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + elif message.type == DatasourceMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceMessage.TextMessage) text += message.message.text yield RunStreamChunkEvent( chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] ) - elif message.type == DatasourceInvokeMessage.MessageType.JSON: - assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage) + elif message.type == DatasourceMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceMessage.JsonMessage) if self.node_type == NodeType.AGENT: msg_metadata = message.message.json_object.pop("execution_metadata", {}) agent_execution_metadata = { @@ -383,13 +383,13 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): if key in WorkflowNodeExecutionMetadataKey.__members__.values() } json.append(message.message.json_object) - elif message.type == DatasourceInvokeMessage.MessageType.LINK: - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + elif message.type == DatasourceMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceMessage.TextMessage) stream_text = f"Link: {message.message.text}\n" text += stream_text yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) - elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage) + elif message.type == DatasourceMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceMessage.VariableMessage) variable_name = message.message.variable_name variable_value = message.message.variable_value if message.message.stream: @@ -404,7 +404,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): ) else: variables[variable_name] = variable_value - elif message.type == DatasourceInvokeMessage.MessageType.FILE: + elif message.type == DatasourceMessage.MessageType.FILE: assert message.meta is not None files.append(message.meta["file"]) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 9a37f0e51c..aaecc7b989 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -127,7 +127,7 @@ class ToolNode(BaseNode[ToolNodeData]): inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, error=f"Failed to transform tool message: {str(e)}", - error_type=type(e).__name__, PipelineGenerator.convert_to_event_strea + error_type=type(e).__name__, ) ) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 64fa97197d..80e903bd46 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -32,7 +32,11 
@@ class DatasourceProviderService: :param credentials: """ credential_valid = self.provider_manager.validate_provider_credentials( - tenant_id=tenant_id, user_id=current_user.id, provider=provider, plugin_id=plugin_id, credentials=credentials + tenant_id=tenant_id, + user_id=current_user.id, + provider=provider, + plugin_id=plugin_id, + credentials=credentials ) if credential_valid: # Get all provider configurations of the current workspace @@ -104,7 +108,8 @@ class DatasourceProviderService: for datasource_provider in datasource_providers: encrypted_credentials = datasource_provider.encrypted_credentials # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}") + credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, + provider_id=f"{plugin_id}/{provider}") # Obfuscate provider credentials copy_credentials = encrypted_credentials.copy() @@ -144,7 +149,8 @@ class DatasourceProviderService: for datasource_provider in datasource_providers: encrypted_credentials = datasource_provider.encrypted_credentials # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}") + credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, + provider_id=f"{plugin_id}/{provider}") # Obfuscate provider credentials copy_credentials = encrypted_credentials.copy() @@ -161,7 +167,12 @@ class DatasourceProviderService: return copy_credentials_list - def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None: + def update_datasource_credentials(self, + tenant_id: str, + auth_id: str, + provider: str, + plugin_id: str, + credentials: dict) -> None: """ update datasource credentials. 
""" diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 1d61677bea..a5f2135100 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -15,7 +15,6 @@ import contexts from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( - DatasourceInvokeMessage, DatasourceProviderType, OnlineDocumentPagesMessage, WebsiteCrawlMessage, @@ -423,70 +422,71 @@ class RagPipelineService: return workflow_node_execution - def run_datasource_workflow_node_status( - self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool - ) -> dict: - """ - Run published workflow datasource - """ - if is_published: - # fetch published workflow by app_model - workflow = self.get_published_workflow(pipeline=pipeline) - else: - workflow = self.get_draft_workflow(pipeline=pipeline) - if not workflow: - raise ValueError("Workflow not initialized") - - # run draft workflow node - datasource_node_data = None - start_at = time.perf_counter() - datasource_nodes = workflow.graph_dict.get("nodes", []) - for datasource_node in datasource_nodes: - if datasource_node.get("id") == node_id: - datasource_node_data = datasource_node.get("data", {}) - break - if not datasource_node_data: - raise ValueError("Datasource node data not found") - - from core.datasource.datasource_manager import DatasourceManager - - datasource_runtime = DatasourceManager.get_datasource_runtime( - provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", - datasource_name=datasource_node_data.get("datasource_name"), - tenant_id=pipeline.tenant_id, - datasource_type=DatasourceProviderType(datasource_type), - ) - datasource_provider_service = DatasourceProviderService() - credentials = datasource_provider_service.get_real_datasource_credentials( - tenant_id=pipeline.tenant_id, - provider=datasource_node_data.get('provider_name'), - plugin_id=datasource_node_data.get('plugin_id'), - ) - if credentials: - datasource_runtime.runtime.credentials = credentials[0].get("credentials") - match datasource_type: - - case DatasourceProviderType.WEBSITE_CRAWL: - datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_results: list[WebsiteCrawlMessage] = [] - for website_message in datasource_runtime.get_website_crawl( - user_id=account.id, - datasource_parameters={"job_id": job_id}, - provider_type=datasource_runtime.datasource_provider_type(), - ): - website_crawl_results.append(website_message) - return { - "result": [result for result in website_crawl_results.result], - "status": website_crawl_results.result.status, - "provider_type": datasource_node_data.get("provider_type"), - } - case _: - raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + # def run_datasource_workflow_node_status( + # self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, + # datasource_type: str, is_published: bool + # ) -> dict: + # """ + # Run published workflow datasource + # """ + # if is_published: + # # fetch published workflow by app_model + # workflow = self.get_published_workflow(pipeline=pipeline) + # else: + # workflow = self.get_draft_workflow(pipeline=pipeline) + # if not workflow: + # raise ValueError("Workflow not initialized") + # + # # run draft workflow node + # datasource_node_data = None + # 
start_at = time.perf_counter() + # datasource_nodes = workflow.graph_dict.get("nodes", []) + # for datasource_node in datasource_nodes: + # if datasource_node.get("id") == node_id: + # datasource_node_data = datasource_node.get("data", {}) + # break + # if not datasource_node_data: + # raise ValueError("Datasource node data not found") + # + # from core.datasource.datasource_manager import DatasourceManager + # + # datasource_runtime = DatasourceManager.get_datasource_runtime( + # provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", + # datasource_name=datasource_node_data.get("datasource_name"), + # tenant_id=pipeline.tenant_id, + # datasource_type=DatasourceProviderType(datasource_type), + # ) + # datasource_provider_service = DatasourceProviderService() + # credentials = datasource_provider_service.get_real_datasource_credentials( + # tenant_id=pipeline.tenant_id, + # provider=datasource_node_data.get('provider_name'), + # plugin_id=datasource_node_data.get('plugin_id'), + # ) + # if credentials: + # datasource_runtime.runtime.credentials = credentials[0].get("credentials") + # match datasource_type: + # + # case DatasourceProviderType.WEBSITE_CRAWL: + # datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) + # website_crawl_results: list[WebsiteCrawlMessage] = [] + # for website_message in datasource_runtime.get_website_crawl( + # user_id=account.id, + # datasource_parameters={"job_id": job_id}, + # provider_type=datasource_runtime.datasource_provider_type(), + # ): + # website_crawl_results.append(website_message) + # return { + # "result": [result for result in website_crawl_results.result], + # "status": website_crawl_results.result.status, + # "provider_type": datasource_node_data.get("provider_type"), + # } + # case _: + # raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") def run_datasource_workflow_node( self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool - ) -> Generator[DatasourceRunEvent, None, None]: + ) -> Generator[str, None, None]: """ Run published workflow datasource """ @@ -533,25 +533,40 @@ class RagPipelineService: match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_online_document_pages( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - ) - for message in online_document_result: - yield DatasourceRunEvent( - status="success", - result=message.model_dump(), + online_document_result: Generator[OnlineDocumentPagesMessage, None, None] =\ + datasource_runtime.get_online_document_pages( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), ) + start_time = time.time() + for message in online_document_result: + end_time = time.time() + online_document_event = DatasourceRunEvent( + status="completed", + data=message.result, + time_consuming=round(end_time - start_time, 2) + ) + yield json.dumps(online_document_event.model_dump()) case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_website_crawl( + 
website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = datasource_runtime.get_website_crawl( user_id=account.id, datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), ) - yield from website_crawl_result + start_time = time.time() + for message in website_crawl_result: + end_time = time.time() + crawl_event = DatasourceRunEvent( + status=message.result.status, + data=message.result.web_info_list, + total=message.result.total, + completed=message.result.completed, + time_consuming = round(end_time - start_time, 2) + ) + yield json.dumps(crawl_event.model_dump()) case _: raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") @@ -952,7 +967,9 @@ class RagPipelineService: if not dataset: raise ValueError("Dataset not found") - max_position = db.session.query(func.max(PipelineCustomizedTemplate.position)).filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar() + max_position = db.session.query( + func.max(PipelineCustomizedTemplate.position)).filter( + PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar() from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True) From f37e28a3689474a8b1ea7d1f47113a81081c1b2b Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Tue, 17 Jun 2025 13:54:25 +0800 Subject: [PATCH 118/155] feat(datasource): Comment out the datasource_file_manager. --- api/core/datasource/datasource_file_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index 9a31f682fd..51296b64d2 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -239,6 +239,6 @@ class DatasourceFileManager: # init tool_file_parser -from core.file.datasource_file_parser import datasource_file_manager - -datasource_file_manager["manager"] = DatasourceFileManager +# from core.file.datasource_file_parser import datasource_file_manager +# +# datasource_file_manager["manager"] = DatasourceFileManager From 1d2ee9020c3413db8883097740f8df14f954c971 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 17 Jun 2025 14:04:55 +0800 Subject: [PATCH 119/155] r2 --- api/core/variables/variables.py | 4 ++-- api/fields/workflow_fields.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 1fe0e36a47..54aeec61e9 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import cast +from typing import Any, cast from uuid import uuid4 from pydantic import BaseModel, Field @@ -104,7 +104,7 @@ class RAGPipelineVariable(BaseModel): max_length: int | None = Field( description="max length, applicable to text-input, paragraph, and file-list", default=0 ) - default_value: str | None = Field(description="default value", default="") + default_value: Any = Field(description="default value", default="") placeholder: str | None = Field(description="placeholder", default="") unit: str | None = Field(description="unit, applicable to Number", default="") tooltips: str | None = Field(description="helpful text", default="") diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 
c138266b14..36249d2ae9 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -47,6 +47,7 @@ pipeline_variable_fields = { "belong_to_node_id": fields.String, "max_length": fields.Integer, "required": fields.Boolean, + "unit": fields.String, "default_value": fields.Raw, "options": fields.List(fields.String), "placeholder": fields.String, From 7c41f71248967cfacb0a653076191ff6667d8bfd Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 17 Jun 2025 18:11:38 +0800 Subject: [PATCH 120/155] r2 --- api/services/dataset_service.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 5384228b84..8719eb3be4 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -333,6 +333,13 @@ class DatasetService: if external_retrieval_model: dataset.retrieval_model = external_retrieval_model dataset.name = data.get("name", dataset.name) + # check if dataset name is exists + if db.session.query(Dataset).filter( + Dataset.id != dataset_id, + Dataset.name == dataset.name, + Dataset.tenant_id == dataset.tenant_id, + ).first(): + raise ValueError("Dataset name already exists") dataset.description = data.get("description", "") permission = data.get("permission") if permission: From 739ebf211799e15b217c06012757c493e95583ec Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Tue, 17 Jun 2025 18:24:09 +0800 Subject: [PATCH 121/155] feat(datasource): change datasource result type to event-stream --- .../rag_pipeline/rag_pipeline_workflow.py | 4 +-- api/core/rag/entities/event.py | 30 ++++++++++++++++++ .../workflow/graph_engine/entities/event.py | 7 ----- api/services/rag_pipeline/rag_pipeline.py | 31 +++++++++++-------- 4 files changed, 50 insertions(+), 22 deletions(-) create mode 100644 api/core/rag/entities/event.py diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 616803247c..7909bb9609 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -406,10 +406,10 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): args = parser.parse_args() inputs = args.get("inputs") - if inputs == None: + if inputs is None: raise ValueError("missing inputs") datasource_type = args.get("datasource_type") - if datasource_type == None: + if datasource_type is None: raise ValueError("missing datasource_type") rag_pipeline_service = RagPipelineService() diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py new file mode 100644 index 0000000000..8e644fcf85 --- /dev/null +++ b/api/core/rag/entities/event.py @@ -0,0 +1,30 @@ +from collections.abc import Mapping +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel, Field + + +class DatasourceStreamEvent(Enum): + """ + Datasource Stream event + """ + PROCESSING = "processing" + COMPLETED = "completed" + + +class BaseDatasourceEvent(BaseModel): + pass + +class DatasourceCompletedEvent(BaseDatasourceEvent): + event: str = DatasourceStreamEvent.COMPLETED.value + data: Mapping[str,Any] | list = Field(..., description="result") + total: Optional[int] = Field(..., description="total") + completed: Optional[int] = Field(..., description="completed") + time_consuming: Optional[float] = Field(..., description="time consuming") + +class 
DatasourceProcessingEvent(BaseDatasourceEvent): + event: str = DatasourceStreamEvent.PROCESSING.value + total: Optional[int] = Field(..., description="total") + completed: Optional[int] = Field(..., description="completed") + diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index fbf591eb8f..063216dd49 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -275,10 +275,3 @@ class AgentLogEvent(BaseAgentEvent): InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent -class DatasourceRunEvent(BaseModel): - status: str = Field(..., description="status") - data: Mapping[str,Any] | list = Field(..., description="result") - total: Optional[int] = Field(..., description="total") - completed: Optional[int] = Field(..., description="completed") - time_consuming: Optional[float] = Field(..., description="time consuming") - diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index a5f2135100..ccb920238d 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -21,6 +21,7 @@ from core.datasource.entities.datasource_entities import ( ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin +from core.rag.entities.event import BaseDatasourceEvent, DatasourceCompletedEvent, DatasourceProcessingEvent from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable from core.workflow.entities.node_entities import NodeRunResult @@ -30,7 +31,7 @@ from core.workflow.entities.workflow_node_execution import ( ) from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.graph_engine.entities.event import DatasourceRunEvent, InNodeEvent +from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event.event import RunCompletedEvent @@ -486,7 +487,7 @@ class RagPipelineService: def run_datasource_workflow_node( self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool - ) -> Generator[str, None, None]: + ) -> Generator[BaseDatasourceEvent, None, None]: """ Run published workflow datasource """ @@ -542,12 +543,11 @@ class RagPipelineService: start_time = time.time() for message in online_document_result: end_time = time.time() - online_document_event = DatasourceRunEvent( - status="completed", + online_document_event = DatasourceCompletedEvent( data=message.result, time_consuming=round(end_time - start_time, 2) ) - yield json.dumps(online_document_event.model_dump()) + yield online_document_event.model_dump() case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) @@ -559,14 +559,19 @@ class RagPipelineService: start_time = time.time() for message in website_crawl_result: end_time = time.time() - crawl_event = DatasourceRunEvent( - status=message.result.status, - data=message.result.web_info_list, - total=message.result.total, - completed=message.result.completed, - time_consuming 
= round(end_time - start_time, 2) - ) - yield json.dumps(crawl_event.model_dump()) + if message.result.status == "completed": + crawl_event = DatasourceCompletedEvent( + data=message.result.web_info_list, + total=message.result.total, + completed=message.result.completed, + time_consuming=round(end_time - start_time, 2) + ) + else: + crawl_event = DatasourceProcessingEvent( + total=message.result.total, + completed=message.result.completed, + ) + yield crawl_event.model_dump() case _: raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") From cf66d111ba9b4d9946607429a9198909cdcc1a53 Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Tue, 17 Jun 2025 18:29:02 +0800 Subject: [PATCH 122/155] feat(datasource): change datasource result type to event-stream --- api/core/rag/entities/event.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py index 8e644fcf85..4921c94557 100644 --- a/api/core/rag/entities/event.py +++ b/api/core/rag/entities/event.py @@ -9,8 +9,8 @@ class DatasourceStreamEvent(Enum): """ Datasource Stream event """ - PROCESSING = "processing" - COMPLETED = "completed" + PROCESSING = "datasource_processing" + COMPLETED = "datasource_completed" class BaseDatasourceEvent(BaseModel): From 7f7ea92a457f92225a3da66650b787ff82525b0c Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 17 Jun 2025 19:06:17 +0800 Subject: [PATCH 123/155] r2 --- api/controllers/console/datasets/datasets.py | 2 +- .../console/datasets/datasets_document.py | 30 +++++++ .../datasets/rag_pipeline/datasource_auth.py | 11 ++- .../rag_pipeline/rag_pipeline_workflow.py | 15 ++-- .../common/workflow_response_converter.py | 2 +- .../app/apps/pipeline/pipeline_generator.py | 12 ++- .../entities/datasource_entities.py | 5 +- .../website_crawl/website_crawl_provider.py | 1 - .../workflow/graph_engine/entities/event.py | 3 +- .../nodes/datasource/datasource_node.py | 87 ++++++++++--------- .../knowledge_index/knowledge_index_node.py | 12 ++- api/factories/file_factory.py | 1 + ..._15_1558-b35c3db83d09_add_pipeline_info.py | 2 +- ...7_1905-70a0fc0c013f_add_pipeline_info_7.py | 45 ++++++++++ api/models/dataset.py | 34 ++++++-- api/services/dataset_service.py | 14 +-- api/services/datasource_provider_service.py | 44 +++++----- .../customized/customized_retrieval.py | 2 +- api/services/rag_pipeline/rag_pipeline.py | 39 +++++---- 19 files changed, 243 insertions(+), 118 deletions(-) create mode 100644 api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index ceaa9ec4fa..644bcbddb1 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -283,7 +283,7 @@ class DatasetApi(Resource): location="json", help="Invalid external knowledge api id.", ) - + parser.add_argument( "icon_info", type=dict, diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 60fa1731ca..e5fde58d04 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -52,6 +52,7 @@ from fields.document_fields import ( ) from libs.login import login_required from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile +from models.dataset import 
DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from tasks.add_document_to_index_task import add_document_to_index_task @@ -1092,6 +1093,35 @@ class WebsiteDocumentSyncApi(DocumentResource): return {"result": "success"}, 200 +class DocumentPipelineExecutionLogApi(DocumentResource): + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id): + dataset_id = str(dataset_id) + document_id = str(document_id) + + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + document = DocumentService.get_document(dataset.id, document_id) + if not document: + raise NotFound("Document not found.") + log = ( + db.session.query(DocumentPipelineExecutionLog) + .filter_by(document_id=document_id) + .order_by(DocumentPipelineExecutionLog.created_at.desc()) + .first() + ) + if not log: + return {"datasource_info": None, "datasource_type": None, "input_data": None}, 200 + return { + "datasource_info": log.datasource_info, + "datasource_type": log.datasource_type, + "input_data": log.input_data, + }, 200 + + api.add_resource(GetProcessRuleApi, "/datasets/process-rule") api.add_resource(DatasetDocumentListApi, "/datasets//documents") api.add_resource(DatasetInitApi, "/datasets/init") diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index d2136f771b..21a7b998f0 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -41,8 +41,9 @@ class DatasourcePluginOauthApi(Resource): if not plugin_oauth_config: raise NotFound() oauth_handler = OAuthHandler() - redirect_url = (f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?" 
- f"provider={provider}&plugin_id={plugin_id}") + redirect_url = ( + f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}" + ) system_credentials = plugin_oauth_config.system_credentials if system_credentials: system_credentials["redirect_url"] = redirect_url @@ -123,9 +124,7 @@ class DatasourceAuth(Resource): args = parser.parse_args() datasource_provider_service = DatasourceProviderService() datasources = datasource_provider_service.get_datasource_credentials( - tenant_id=current_user.current_tenant_id, - provider=args["provider"], - plugin_id=args["plugin_id"] + tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"] ) return {"result": datasources}, 200 @@ -146,7 +145,7 @@ class DatasourceAuthUpdateDeleteApi(Resource): tenant_id=current_user.current_tenant_id, auth_id=auth_id, provider=args["provider"], - plugin_id=args["plugin_id"] + plugin_id=args["plugin_id"], ) return {"result": "success"}, 200 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 616803247c..00cd36b649 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -384,6 +384,7 @@ class PublishedRagPipelineRunApi(Resource): # return result # + class RagPipelinePublishedDatasourceNodeRunApi(Resource): @setup_required @login_required @@ -419,7 +420,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): user_inputs=inputs, account=current_user, datasource_type=datasource_type, - is_published=True + is_published=True, ) return result @@ -458,12 +459,12 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): return helper.compact_generate_response( PipelineGenerator.convert_to_event_stream( rag_pipeline_service.run_datasource_workflow_node( - pipeline=pipeline, - node_id=node_id, - user_inputs=inputs, - account=current_user, - datasource_type=datasource_type, - is_published=False + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=datasource_type, + is_published=False, ) ) ) diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index f170d0ee3f..aa74f8c318 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -188,7 +188,7 @@ class WorkflowResponseConverter: manager = PluginDatasourceManager() provider_entity = manager.fetch_datasource_provider( self._application_generate_entity.app_config.tenant_id, - f"{node_data.plugin_id}/{node_data.provider_name}" + f"{node_data.plugin_id}/{node_data.provider_name}", ) response.data.extras["icon"] = provider_entity.declaration.identity.icon diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index a2123fdc49..ec565fe2e5 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -33,7 +33,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom -from models.dataset import Document, Pipeline +from 
models.dataset import Document, DocumentPipelineExecutionLog, Pipeline from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode from services.dataset_service import DocumentService @@ -136,6 +136,16 @@ class PipelineGenerator(BaseAppGenerator): document_id = None if invoke_from == InvokeFrom.PUBLISHED: document_id = documents[i].id + document_pipeline_execution_log = DocumentPipelineExecutionLog( + document_id=document_id, + datasource_type=datasource_type, + datasource_info=datasource_info, + input_data=inputs, + pipeline_id=pipeline.id, + created_by=user.id, + ) + db.session.add(document_pipeline_execution_log) + db.session.commit() application_generate_entity = RagPipelineGenerateEntity( task_id=str(uuid.uuid4()), app_config=pipeline_config, diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 9b72966b50..d072b8541b 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -284,17 +284,20 @@ class WebSiteInfo(BaseModel): """ Website info """ + status: Optional[str] = Field(..., description="crawl job status") web_info_list: Optional[list[WebSiteInfoDetail]] = [] total: Optional[int] = Field(default=0, description="The total number of websites") completed: Optional[int] = Field(default=0, description="The number of completed websites") + class WebsiteCrawlMessage(BaseModel): """ Get website crawl response """ + result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0) + class DatasourceMessage(ToolInvokeMessage): pass - diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index 0567f1a480..8c0f20ce2d 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -43,7 +43,6 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon if not datasource_entity: raise ValueError(f"Datasource with name {datasource_name} not found") - return WebsiteCrawlDatasourcePlugin( entity=datasource_entity, runtime=DatasourceRuntime(tenant_id=self.tenant_id), diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index fbf591eb8f..061a69e009 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -277,8 +277,7 @@ InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | Bas class DatasourceRunEvent(BaseModel): status: str = Field(..., description="status") - data: Mapping[str,Any] | list = Field(..., description="result") + data: Mapping[str, Any] | list = Field(..., description="result") total: Optional[int] = Field(..., description="total") completed: Optional[int] = Field(..., description="completed") time_consuming: Optional[float] = Field(..., description="time consuming") - diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 0e3decc7b4..ab4477f538 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -74,12 +74,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): except DatasourceNodeError as e: yield RunCompletedEvent( run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - 
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, - error=f"Failed to get datasource runtime: {str(e)}", - error_type=type(e).__name__, - ) + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to get datasource runtime: {str(e)}", + error_type=type(e).__name__, + ) ) # get parameters @@ -114,16 +114,17 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): ) case DatasourceProviderType.WEBSITE_CRAWL: - - yield RunCompletedEvent(run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - **datasource_info, - "datasource_type": datasource_type, - }, - )) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + **datasource_info, + "datasource_type": datasource_type, + }, + ) + ) case DatasourceProviderType.LOCAL_FILE: related_id = datasource_info.get("related_id") if not related_id: @@ -155,33 +156,39 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): variable_key_list=new_key_list, variable_value=value, ) - yield RunCompletedEvent(run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - "file_info": datasource_info, - "datasource_type": datasource_type, - }, - )) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "file_info": datasource_info, + "datasource_type": datasource_type, + }, + ) + ) case _: raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}") except PluginDaemonClientSideError as e: - yield RunCompletedEvent(run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, - error=f"Failed to transform datasource message: {str(e)}", - error_type=type(e).__name__, - )) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to transform datasource message: {str(e)}", + error_type=type(e).__name__, + ) + ) except DatasourceNodeError as e: - yield RunCompletedEvent(run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, - error=f"Failed to invoke datasource: {str(e)}", - error_type=type(e).__name__, - )) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to invoke datasource: {str(e)}", + error_type=type(e).__name__, + ) + ) def _generate_parameters( self, @@ -286,8 +293,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): return result - - def _transform_message( self, messages: Generator[DatasourceMessage, None, None], diff --git 
a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 49c8ec1e69..2c45bf4073 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -123,10 +123,14 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): # update document status document.indexing_status = "completed" document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - document.word_count = db.session.query(func.sum(DocumentSegment.word_count)).filter( - DocumentSegment.document_id == document.id, - DocumentSegment.dataset_id == dataset.id, - ).scalar() + document.word_count = ( + db.session.query(func.sum(DocumentSegment.word_count)) + .filter( + DocumentSegment.document_id == document.id, + DocumentSegment.dataset_id == dataset.id, + ) + .scalar() + ) db.session.add(document) # update document segment status db.session.query(DocumentSegment).filter( diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 128041a27d..81606594e0 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -349,6 +349,7 @@ def _build_from_datasource_file( storage_key=datasource_file.key, ) + def _is_file_valid_with_config( *, input_file_type: str, diff --git a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py index 4d726cecb1..503842b797 100644 --- a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py +++ b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py @@ -12,7 +12,7 @@ from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = 'b35c3db83d09' -down_revision = '2adcbe1f5dfb' +down_revision = '4474872b0ee6' branch_labels = None depends_on = None diff --git a/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py b/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py new file mode 100644 index 0000000000..a695adc74a --- /dev/null +++ b/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py @@ -0,0 +1,45 @@ +"""add_pipeline_info_7 + +Revision ID: 70a0fc0c013f +Revises: 224fba149d48 +Create Date: 2025-06-17 19:05:39.920953 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '70a0fc0c013f' +down_revision = '224fba149d48' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('document_pipeline_execution_logs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('pipeline_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('datasource_type', sa.String(length=255), nullable=False), + sa.Column('datasource_info', sa.Text(), nullable=False), + sa.Column('input_data', sa.JSON(), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey') + ) + with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op: + batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op: + batch_op.drop_index('document_pipeline_execution_logs_document_id_idx') + + op.drop_table('document_pipeline_execution_logs') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 5d18eaff49..16d1865a83 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -75,12 +75,16 @@ class Dataset(Base): @property def total_available_documents(self): - return db.session.query(func.count(Document.id)).filter( - Document.dataset_id == self.id, - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, - ).scalar() + return ( + db.session.query(func.count(Document.id)) + .filter( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) + .scalar() + ) @property def dataset_keyword_table(self): @@ -325,6 +329,7 @@ class DatasetProcessRule(Base): except JSONDecodeError: return None + class Document(Base): __tablename__ = "documents" __table_args__ = ( @@ -1248,3 +1253,20 @@ class Pipeline(Base): # type: ignore[name-defined] @property def dataset(self): return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first() + + +class DocumentPipelineExecutionLog(Base): + __tablename__ = "document_pipeline_execution_logs" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"), + db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + pipeline_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + datasource_type = db.Column(db.String(255), nullable=False) + datasource_info = db.Column(db.Text, nullable=False) + input_data = db.Column(db.JSON, nullable=False) + created_by = db.Column(StringUUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 8719eb3be4..8c88a51ed7 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -334,11 +334,15 @@ class DatasetService: dataset.retrieval_model = external_retrieval_model dataset.name = data.get("name", dataset.name) # check if dataset name is exists - if db.session.query(Dataset).filter( - 
Dataset.id != dataset_id, - Dataset.name == dataset.name, - Dataset.tenant_id == dataset.tenant_id, - ).first(): + if ( + db.session.query(Dataset) + .filter( + Dataset.id != dataset_id, + Dataset.name == dataset.name, + Dataset.tenant_id == dataset.tenant_id, + ) + .first() + ): raise ValueError("Dataset name already exists") dataset.description = data.get("description", "") permission = data.get("permission") diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 80e903bd46..fa01fe0afe 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -36,7 +36,7 @@ class DatasourceProviderService: user_id=current_user.id, provider=provider, plugin_id=plugin_id, - credentials=credentials + credentials=credentials, ) if credential_valid: # Get all provider configurations of the current workspace @@ -47,9 +47,8 @@ class DatasourceProviderService: ) provider_credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, - provider_id=f"{plugin_id}/{provider}" - ) + tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" + ) for key, value in credentials.items(): if key in provider_credential_secret_variables: # if send [__HIDDEN__] in secret input, it will be same as original value @@ -73,9 +72,9 @@ class DatasourceProviderService: :param credential_form_schemas: :return: """ - datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id, - provider_id=provider_id - ) + datasource_provider = self.provider_manager.fetch_datasource_provider( + tenant_id=tenant_id, provider_id=provider_id + ) credential_form_schemas = datasource_provider.declaration.credentials_schema secret_input_form_variables = [] for credential_form_schema in credential_form_schemas: @@ -108,8 +107,9 @@ class DatasourceProviderService: for datasource_provider in datasource_providers: encrypted_credentials = datasource_provider.encrypted_credentials # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, - provider_id=f"{plugin_id}/{provider}") + credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" + ) # Obfuscate provider credentials copy_credentials = encrypted_credentials.copy() @@ -149,8 +149,9 @@ class DatasourceProviderService: for datasource_provider in datasource_providers: encrypted_credentials = datasource_provider.encrypted_credentials # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, - provider_id=f"{plugin_id}/{provider}") + credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" + ) # Obfuscate provider credentials copy_credentials = encrypted_credentials.copy() @@ -166,18 +167,18 @@ class DatasourceProviderService: return copy_credentials_list - - def update_datasource_credentials(self, - tenant_id: str, - auth_id: str, - provider: str, - plugin_id: str, - credentials: dict) -> None: + def update_datasource_credentials( + self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict + ) -> None: """ update datasource credentials. 
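        The submitted credentials are validated through the plugin daemon before the
        stored provider record is updated; secret-type fields are kept encrypted at rest.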
""" credential_valid = self.provider_manager.validate_provider_credentials( - tenant_id=tenant_id, user_id=current_user.id, provider=provider,plugin_id=plugin_id, credentials=credentials + tenant_id=tenant_id, + user_id=current_user.id, + provider=provider, + plugin_id=plugin_id, + credentials=credentials, ) if credential_valid: # Get all provider configurations of the current workspace @@ -188,9 +189,8 @@ class DatasourceProviderService: ) provider_credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, - provider_id=f"{plugin_id}/{provider}" - ) + tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" + ) if not datasource_provider: raise ValueError("Datasource provider not found") else: diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index ca94f7f47a..7280408889 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -66,7 +66,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): ) if not pipeline_template: return None - + dsl_data = yaml.safe_load(pipeline_template.yaml_content) graph_data = dsl_data.get("workflow", {}).get("graph", {}) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index a5f2135100..87b13ba98d 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -484,8 +484,13 @@ class RagPipelineService: # raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") def run_datasource_workflow_node( - self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, - is_published: bool + self, + pipeline: Pipeline, + node_id: str, + user_inputs: dict, + account: Account, + datasource_type: str, + is_published: bool, ) -> Generator[str, None, None]: """ Run published workflow datasource @@ -525,27 +530,26 @@ class RagPipelineService: datasource_provider_service = DatasourceProviderService() credentials = datasource_provider_service.get_real_datasource_credentials( tenant_id=pipeline.tenant_id, - provider=datasource_node_data.get('provider_name'), - plugin_id=datasource_node_data.get('plugin_id'), + provider=datasource_node_data.get("provider_name"), + plugin_id=datasource_node_data.get("plugin_id"), ) if credentials: datasource_runtime.runtime.credentials = credentials[0].get("credentials") match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: Generator[OnlineDocumentPagesMessage, None, None] =\ + online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = ( datasource_runtime.get_online_document_pages( user_id=account.id, datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), ) + ) start_time = time.time() for message in online_document_result: end_time = time.time() online_document_event = DatasourceRunEvent( - status="completed", - data=message.result, - time_consuming=round(end_time - start_time, 2) + status="completed", data=message.result, time_consuming=round(end_time - start_time, 2) ) yield json.dumps(online_document_event.model_dump()) @@ -564,7 +568,7 @@ class RagPipelineService: data=message.result.web_info_list, 
total=message.result.total, completed=message.result.completed, - time_consuming = round(end_time - start_time, 2) + time_consuming=round(end_time - start_time, 2), ) yield json.dumps(crawl_event.model_dump()) case _: @@ -781,9 +785,7 @@ class RagPipelineService: raise ValueError("Datasource node data not found") variables = datasource_node_data.get("variables", {}) if variables: - variables_map = { - item["variable"]: item for item in variables - } + variables_map = {item["variable"]: item for item in variables} else: return [] datasource_parameters = datasource_node_data.get("datasource_parameters", {}) @@ -813,9 +815,7 @@ class RagPipelineService: raise ValueError("Datasource node data not found") variables = datasource_node_data.get("variables", {}) if variables: - variables_map = { - item["variable"]: item for item in variables - } + variables_map = {item["variable"]: item for item in variables} else: return [] datasource_parameters = datasource_node_data.get("datasource_parameters", {}) @@ -967,11 +967,14 @@ class RagPipelineService: if not dataset: raise ValueError("Dataset not found") - max_position = db.session.query( - func.max(PipelineCustomizedTemplate.position)).filter( - PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar() + max_position = ( + db.session.query(func.max(PipelineCustomizedTemplate.position)) + .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id) + .scalar() + ) from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService + dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True) pipeline_customized_template = PipelineCustomizedTemplate( From ac917bb56d929492fadd32f69ada0c61f6ea9e07 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 18 Jun 2025 11:05:52 +0800 Subject: [PATCH 124/155] r2 --- .../app/apps/pipeline/pipeline_generator.py | 90 ++++++------------- 1 file changed, 29 insertions(+), 61 deletions(-) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index ec565fe2e5..e726ad4841 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -32,6 +32,7 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db +from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline from models.enums import WorkflowRunTriggeredFrom @@ -209,25 +210,22 @@ class PipelineGenerator(BaseAppGenerator): # run in child thread context = contextvars.copy_context() - @copy_current_request_context - def worker_with_context(): - # Run the worker within the copied context - return context.run( - self._generate, - flask_app=current_app._get_current_object(), # type: ignore - context=context, - pipeline=pipeline, - workflow_id=workflow.id, - user=user, - application_generate_entity=application_generate_entity, - invoke_from=invoke_from, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - streaming=streaming, - workflow_thread_pool_id=workflow_thread_pool_id, - ) - - 
worker_thread = threading.Thread(target=worker_with_context) + worker_thread = threading.Thread( + target=self._generate, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "context": context, + "pipeline": pipeline, + "workflow_id": workflow.id, + "user": user, + "application_generate_entity": application_generate_entity, + "invoke_from": invoke_from, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "streaming": streaming, + "workflow_thread_pool_id": workflow_thread_pool_id, + }, + ) worker_thread.start() # return batch, dataset, documents @@ -282,23 +280,7 @@ class PipelineGenerator(BaseAppGenerator): :param streaming: is stream :param workflow_thread_pool_id: workflow thread pool id """ - print("jin ru la 1") - for var, val in context.items(): - var.set(val) - - # FIXME(-LAN-): Save current user before entering new app context - from flask import g - - saved_user = None - if has_request_context() and hasattr(g, "_login_user"): - saved_user = g._login_user - with flask_app.app_context(): - # Restore user in new app context - print("jin ru la 2") - if saved_user is not None: - from flask import g - - g._login_user = saved_user + with preserve_flask_contexts(flask_app, context_vars=context): # init queue manager workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first() if not workflow: @@ -311,20 +293,17 @@ class PipelineGenerator(BaseAppGenerator): ) context = contextvars.copy_context() - @copy_current_request_context - def worker_with_context(): - # Run the worker within the copied context - return context.run( - self._generate_worker, - flask_app=current_app._get_current_object(), # type: ignore - context=context, - queue_manager=queue_manager, - application_generate_entity=application_generate_entity, - workflow_thread_pool_id=workflow_thread_pool_id, - ) - # new thread - worker_thread = threading.Thread(target=worker_with_context) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "context": context, + "queue_manager": queue_manager, + "application_generate_entity": application_generate_entity, + "workflow_thread_pool_id": workflow_thread_pool_id, + }, + ) worker_thread.start() @@ -521,20 +500,9 @@ class PipelineGenerator(BaseAppGenerator): :param workflow_thread_pool_id: workflow thread pool id :return: """ - print("jin ru la 3") - for var, val in context.items(): - var.set(val) - from flask import g - saved_user = None - if has_request_context() and hasattr(g, "_login_user"): - saved_user = g._login_user - with flask_app.app_context(): + with preserve_flask_contexts(flask_app, context_vars=context): try: - if saved_user is not None: - from flask import g - - g._login_user = saved_user # workflow app runner = PipelineRunner( application_generate_entity=application_generate_entity, From 6f67a34349749f9d6269a3600e34371c2f0195a2 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 18 Jun 2025 14:37:18 +0800 Subject: [PATCH 125/155] r2 qa index --- .../processor/qa_index_processor.py | 34 ++++++++++++++++--- api/core/rag/models/document.py | 14 ++++++++ 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 5fed36c9b0..8b1bc181d5 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py 
+++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -15,13 +15,15 @@ from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document +from core.rag.models.document import Document, QAStructureChunk from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.dataset import Dataset +from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule @@ -162,11 +164,35 @@ class QAIndexProcessor(BaseIndexProcessor): docs.append(doc) return docs - def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): - pass + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): + qa_chunks = QAStructureChunk(**chunks) + documents = [] + for qa_chunk in qa_chunks.qa_chunks: + metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(qa_chunk.question), + "answer": qa_chunk.answer, + } + doc = Document(page_content=qa_chunk.question, metadata=metadata) + documents.append(doc) + if documents: + # save node to document segment + doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) + doc_store.add_documents(docs=documents, save_child=False) + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + vector.create(documents) + else: + raise ValueError("Indexing technique must be high quality.") def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: - return {"preview": chunks} + qa_chunks = QAStructureChunk(**chunks) + preview = [] + for qa_chunk in qa_chunks.qa_chunks: + preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) + return {"qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks)} def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): format_documents = [] diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 97d53123b6..3f82bda2c6 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -60,6 +60,20 @@ class ParentChildStructureChunk(BaseModel): parent_child_chunks: list[ParentChildChunk] +class QAChunk(BaseModel): + """ + QA Chunk. + """ + + question: str + answer: str + +class QAStructureChunk(BaseModel): + """ + QAStructureChunk. + """ + qa_chunks: list[QAChunk] + class BaseDocumentTransformer(ABC): """Abstract base class for document transformation systems. 
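For reference while reading the QA index changes above: the `chunks` mapping handed to the new `QAIndexProcessor.index()` and `format_preview()` methods is validated through `QAStructureChunk`. A minimal sketch of the expected payload shape and the derived preview, with the two models mirrored from the diff and a made-up question/answer pair standing in for generated content:

    from pydantic import BaseModel


    class QAChunk(BaseModel):
        question: str
        answer: str


    class QAStructureChunk(BaseModel):
        qa_chunks: list[QAChunk]


    # Hypothetical payload; real chunks come from the pipeline's QA generation step.
    chunks = {
        "qa_chunks": [
            {"question": "What gets indexed?", "answer": "The question text; the answer lives in segment metadata."},
        ]
    }

    parsed = QAStructureChunk(**chunks)
    preview = {
        "qa_preview": [{"question": c.question, "answer": c.answer} for c in parsed.qa_chunks],
        "total_segments": len(parsed.qa_chunks),
    }
    print(preview)  # {'qa_preview': [{'question': ..., 'answer': ...}], 'total_segments': 1}
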
From 224111081bb2217162418162b1940715023eaa74 Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Wed, 18 Jun 2025 16:04:40 +0800 Subject: [PATCH 126/155] feat(datasource): change datasource result type to event-stream --- .../rag_pipeline/rag_pipeline_workflow.py | 42 +++++++++---------- .../app/apps/pipeline/pipeline_generator.py | 6 +-- api/core/rag/entities/event.py | 6 +-- api/services/rag_pipeline/rag_pipeline.py | 33 ++++++++------- 4 files changed, 45 insertions(+), 42 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 28ab4b1635..b040e70b92 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -414,16 +414,19 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): raise ValueError("missing datasource_type") rag_pipeline_service = RagPipelineService() - result = rag_pipeline_service.run_datasource_workflow_node( - pipeline=pipeline, - node_id=node_id, - user_inputs=inputs, - account=current_user, - datasource_type=datasource_type, - is_published=True, + return helper.compact_generate_response( + PipelineGenerator.convert_to_event_stream( + rag_pipeline_service.run_datasource_workflow_node( + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=datasource_type, + is_published=False, + ) + ) ) - return result class RagPipelineDraftDatasourceNodeRunApi(Resource): @@ -455,21 +458,18 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): raise ValueError("missing datasource_type") rag_pipeline_service = RagPipelineService() - try: - return helper.compact_generate_response( - PipelineGenerator.convert_to_event_stream( - rag_pipeline_service.run_datasource_workflow_node( - pipeline=pipeline, - node_id=node_id, - user_inputs=inputs, - account=current_user, - datasource_type=datasource_type, - is_published=False, - ) + return helper.compact_generate_response( + PipelineGenerator.convert_to_event_stream( + rag_pipeline_service.run_datasource_workflow_node( + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=datasource_type, + is_published=False, ) ) - except Exception as e: - print(e) + ) class RagPipelinePublishedNodeRunApi(Resource): diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index e726ad4841..f0921c3442 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -2,14 +2,14 @@ import contextvars import datetime import json import logging -import random +import secrets import threading import time import uuid from collections.abc import Generator, Mapping from typing import Any, Literal, Optional, Union, overload -from flask import Flask, copy_current_request_context, current_app, has_request_context +from flask import Flask, current_app from pydantic import ValidationError from sqlalchemy.orm import sessionmaker @@ -110,7 +110,7 @@ class PipelineGenerator(BaseAppGenerator): start_node_id: str = args["start_node_id"] datasource_type: str = args["datasource_type"] datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] - batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000) documents = [] 
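        # Published runs pre-create one knowledge-base Document per entry in
        # datasource_info_list; draft runs leave `documents` empty.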
if invoke_from == InvokeFrom.PUBLISHED: for datasource_info in datasource_info_list: diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py index 4921c94557..59a470c35c 100644 --- a/api/core/rag/entities/event.py +++ b/api/core/rag/entities/event.py @@ -19,9 +19,9 @@ class BaseDatasourceEvent(BaseModel): class DatasourceCompletedEvent(BaseDatasourceEvent): event: str = DatasourceStreamEvent.COMPLETED.value data: Mapping[str,Any] | list = Field(..., description="result") - total: Optional[int] = Field(..., description="total") - completed: Optional[int] = Field(..., description="completed") - time_consuming: Optional[float] = Field(..., description="time consuming") + total: Optional[int] = Field(default=0, description="total") + completed: Optional[int] = Field(default=0, description="completed") + time_consuming: Optional[float] = Field(default=0.0, description="time consuming") class DatasourceProcessingEvent(BaseDatasourceEvent): event: str = DatasourceStreamEvent.PROCESSING.value diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 26036dc2c5..9ddbb7c083 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -558,21 +558,24 @@ class RagPipelineService: provider_type=datasource_runtime.datasource_provider_type(), ) start_time = time.time() - for message in website_crawl_result: - end_time = time.time() - if message.result.status == "completed": - crawl_event = DatasourceCompletedEvent( - data=message.result.web_info_list, - total=message.result.total, - completed=message.result.completed, - time_consuming=round(end_time - start_time, 2) - ) - else: - crawl_event = DatasourceProcessingEvent( - total=message.result.total, - completed=message.result.completed, - ) - yield crawl_event.model_dump() + try: + for message in website_crawl_result: + end_time = time.time() + if message.result.status == "completed": + crawl_event = DatasourceCompletedEvent( + data=message.result.web_info_list, + total=message.result.total, + completed=message.result.completed, + time_consuming=round(end_time - start_time, 2) + ) + else: + crawl_event = DatasourceProcessingEvent( + total=message.result.total, + completed=message.result.completed, + ) + yield crawl_event.model_dump() + except Exception as e: + print(str(e)) case _: raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") From 2cf980026ee72ca488e1fa0b5ff5ca30449bb1be Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Wed, 18 Jun 2025 16:04:47 +0800 Subject: [PATCH 127/155] feat(datasource): change datasource result type to event-stream --- api/core/app/apps/pipeline/pipeline_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index f0921c3442..e3710448bf 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -2,7 +2,7 @@ import contextvars import datetime import json import logging -import secrets +import secrets import threading import time import uuid From 43e5798e13a2ec763774d871778749b449667e6b Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Wed, 18 Jun 2025 16:27:10 +0800 Subject: [PATCH 128/155] feat(datasource): change datasource result type to event-stream --- api/services/rag_pipeline/rag_pipeline.py | 61 ----------------------- 1 file changed, 61 
deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 9ddbb7c083..909df456d4 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -423,67 +423,6 @@ class RagPipelineService: return workflow_node_execution - # def run_datasource_workflow_node_status( - # self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, - # datasource_type: str, is_published: bool - # ) -> dict: - # """ - # Run published workflow datasource - # """ - # if is_published: - # # fetch published workflow by app_model - # workflow = self.get_published_workflow(pipeline=pipeline) - # else: - # workflow = self.get_draft_workflow(pipeline=pipeline) - # if not workflow: - # raise ValueError("Workflow not initialized") - # - # # run draft workflow node - # datasource_node_data = None - # start_at = time.perf_counter() - # datasource_nodes = workflow.graph_dict.get("nodes", []) - # for datasource_node in datasource_nodes: - # if datasource_node.get("id") == node_id: - # datasource_node_data = datasource_node.get("data", {}) - # break - # if not datasource_node_data: - # raise ValueError("Datasource node data not found") - # - # from core.datasource.datasource_manager import DatasourceManager - # - # datasource_runtime = DatasourceManager.get_datasource_runtime( - # provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", - # datasource_name=datasource_node_data.get("datasource_name"), - # tenant_id=pipeline.tenant_id, - # datasource_type=DatasourceProviderType(datasource_type), - # ) - # datasource_provider_service = DatasourceProviderService() - # credentials = datasource_provider_service.get_real_datasource_credentials( - # tenant_id=pipeline.tenant_id, - # provider=datasource_node_data.get('provider_name'), - # plugin_id=datasource_node_data.get('plugin_id'), - # ) - # if credentials: - # datasource_runtime.runtime.credentials = credentials[0].get("credentials") - # match datasource_type: - # - # case DatasourceProviderType.WEBSITE_CRAWL: - # datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - # website_crawl_results: list[WebsiteCrawlMessage] = [] - # for website_message in datasource_runtime.get_website_crawl( - # user_id=account.id, - # datasource_parameters={"job_id": job_id}, - # provider_type=datasource_runtime.datasource_provider_type(), - # ): - # website_crawl_results.append(website_message) - # return { - # "result": [result for result in website_crawl_results.result], - # "status": website_crawl_results.result.status, - # "provider_type": datasource_node_data.get("provider_type"), - # } - # case _: - # raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") - def run_datasource_workflow_node( self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool From 09e0a540705c4d85a4214977718dba5e467fef2a Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 19 Jun 2025 10:38:10 +0800 Subject: [PATCH 129/155] r2 --- api/controllers/console/datasets/datasets_document.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index e5fde58d04..ccd19a546b 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1144,3 +1144,5 @@ 
api.add_resource(DocumentRetryApi, "/datasets//retry") api.add_resource(DocumentRenameApi, "/datasets//documents//rename") api.add_resource(WebsiteDocumentSyncApi, "/datasets//documents//website-sync") +api.add_resource(DocumentPipelineExecutionLogApi, + "/datasets//documents//pipeline-execution-log") From 6ec742539a50eb599781cefbacbd318ee35f3a0b Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 19 Jun 2025 10:45:59 +0800 Subject: [PATCH 130/155] r2 --- api/controllers/console/datasets/datasets_document.py | 3 ++- api/core/app/apps/pipeline/pipeline_generator.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index ccd19a546b..16a00bbd42 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,3 +1,4 @@ +import json import logging from argparse import ArgumentTypeError from datetime import UTC, datetime @@ -1116,7 +1117,7 @@ class DocumentPipelineExecutionLogApi(DocumentResource): if not log: return {"datasource_info": None, "datasource_type": None, "input_data": None}, 200 return { - "datasource_info": log.datasource_info, + "datasource_info": json.loads(log.datasource_info), "datasource_type": log.datasource_type, "input_data": log.input_data, }, 200 diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index e726ad4841..49831007ee 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -140,7 +140,7 @@ class PipelineGenerator(BaseAppGenerator): document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document_id, datasource_type=datasource_type, - datasource_info=datasource_info, + datasource_info=json.dumps(datasource_info), input_data=inputs, pipeline_id=pipeline.id, created_by=user.id, From 82d0a70cb4ae01a085b2cb66a2cb777baf1e89ae Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Thu, 19 Jun 2025 11:10:24 +0800 Subject: [PATCH 131/155] feat(datasource): change datasource result type to event-stream --- .../entities/datasource_entities.py | 3 +- api/core/plugin/impl/datasource.py | 3 +- api/core/rag/entities/event.py | 5 + api/services/rag_pipeline/rag_pipeline.py | 188 ++++++++++-------- 4 files changed, 115 insertions(+), 84 deletions(-) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index d072b8541b..af67eaf761 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -213,10 +213,11 @@ class OnlineDocumentPage(BaseModel): """ page_id: str = Field(..., description="The page id") - page_title: str = Field(..., description="The page title") + page_name: str = Field(..., description="The page title") page_icon: Optional[dict] = Field(None, description="The page icon") type: str = Field(..., description="The type of the page") last_edited_time: str = Field(..., description="The last edited time") + parent_id: Optional[str] = Field(None, description="The parent page id") class OnlineDocumentInfo(BaseModel): diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 66469b43b4..f2539de8f5 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -135,7 +135,7 @@ class PluginDatasourceManager(BasePluginClient): 
datasource_provider_id = GenericProviderID(datasource_provider) - response = self._request_with_plugin_daemon_response_stream( + return self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages", OnlineDocumentPagesMessage, @@ -153,7 +153,6 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - yield from response def get_online_document_page_content( self, diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py index 59a470c35c..4acb558531 100644 --- a/api/core/rag/entities/event.py +++ b/api/core/rag/entities/event.py @@ -11,11 +11,16 @@ class DatasourceStreamEvent(Enum): """ PROCESSING = "datasource_processing" COMPLETED = "datasource_completed" + ERROR = "datasource_error" class BaseDatasourceEvent(BaseModel): pass +class DatasourceErrorEvent(BaseDatasourceEvent): + event: str = DatasourceStreamEvent.ERROR.value + error: str = Field(..., description="error message") + class DatasourceCompletedEvent(BaseDatasourceEvent): event: str = DatasourceStreamEvent.COMPLETED.value data: Mapping[str,Any] | list = Field(..., description="result") diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 909df456d4..1f9337665b 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1,4 +1,5 @@ import json +import logging import re import threading import time @@ -21,7 +22,12 @@ from core.datasource.entities.datasource_entities import ( ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin -from core.rag.entities.event import BaseDatasourceEvent, DatasourceCompletedEvent, DatasourceProcessingEvent +from core.rag.entities.event import ( + BaseDatasourceEvent, + DatasourceCompletedEvent, + DatasourceErrorEvent, + DatasourceProcessingEvent, +) from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable from core.workflow.entities.node_entities import NodeRunResult @@ -61,6 +67,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import ( from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory +logger = logging.getLogger(__name__) class RagPipelineService: @classmethod @@ -430,93 +437,112 @@ class RagPipelineService: """ Run published workflow datasource """ - if is_published: - # fetch published workflow by app_model - workflow = self.get_published_workflow(pipeline=pipeline) - else: - workflow = self.get_draft_workflow(pipeline=pipeline) - if not workflow: - raise ValueError("Workflow not initialized") + try: + if is_published: + # fetch published workflow by app_model + workflow = self.get_published_workflow(pipeline=pipeline) + else: + workflow = self.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Workflow not initialized") - # run draft workflow node - datasource_node_data = None - start_at = time.perf_counter() - datasource_nodes = workflow.graph_dict.get("nodes", []) - for datasource_node in datasource_nodes: - if datasource_node.get("id") == node_id: - datasource_node_data = datasource_node.get("data", {}) - break - if not datasource_node_data: - raise ValueError("Datasource node 
data not found") + # run draft workflow node + datasource_node_data = None + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if not datasource_node_data: + raise ValueError("Datasource node data not found") - datasource_parameters = datasource_node_data.get("datasource_parameters", {}) - for key, value in datasource_parameters.items(): - if not user_inputs.get(key): - user_inputs[key] = value["value"] + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + for key, value in datasource_parameters.items(): + if not user_inputs.get(key): + user_inputs[key] = value["value"] - from core.datasource.datasource_manager import DatasourceManager + from core.datasource.datasource_manager import DatasourceManager - datasource_runtime = DatasourceManager.get_datasource_runtime( - provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", - datasource_name=datasource_node_data.get("datasource_name"), - tenant_id=pipeline.tenant_id, - datasource_type=DatasourceProviderType(datasource_type), - ) - datasource_provider_service = DatasourceProviderService() - credentials = datasource_provider_service.get_real_datasource_credentials( - tenant_id=pipeline.tenant_id, - provider=datasource_node_data.get("provider_name"), - plugin_id=datasource_node_data.get("plugin_id"), - ) - if credentials: - datasource_runtime.runtime.credentials = credentials[0].get("credentials") - match datasource_type: - case DatasourceProviderType.ONLINE_DOCUMENT: - datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = ( - datasource_runtime.get_online_document_pages( + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", + datasource_name=datasource_node_data.get("datasource_name"), + tenant_id=pipeline.tenant_id, + datasource_type=DatasourceProviderType(datasource_type), + ) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=pipeline.tenant_id, + provider=datasource_node_data.get("provider_name"), + plugin_id=datasource_node_data.get("plugin_id"), + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = ( + datasource_runtime.get_online_document_pages( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), + ) + ) + start_time = time.time() + start_event = DatasourceProcessingEvent( + total=0, + completed=0, + ) + yield start_event.model_dump() + try: + for message in online_document_result: + end_time = time.time() + online_document_event = DatasourceCompletedEvent( + data=message.result, + time_consuming=round(end_time - start_time, 2) + ) + yield online_document_event.model_dump() + except Exception as e: + logger.exception("Error during online document.") + yield DatasourceErrorEvent( + error=str(e) + ).model_dump() + case DatasourceProviderType.WEBSITE_CRAWL: + 
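                    # The crawl runs as a long-lived job: interim messages carry only
                    # progress counters and are surfaced as DatasourceProcessingEvent;
                    # the final message (status == "completed") also carries web_info_list
                    # and is emitted as DatasourceCompletedEvent.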
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) + website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = ( + datasource_runtime.get_website_crawl( user_id=account.id, datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), - ) - ) - start_time = time.time() - for message in online_document_result: - end_time = time.time() - online_document_event = DatasourceCompletedEvent( - data=message.result, - time_consuming=round(end_time - start_time, 2) - ) - yield online_document_event.model_dump() - - case DatasourceProviderType.WEBSITE_CRAWL: - datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = datasource_runtime.get_website_crawl( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - ) - start_time = time.time() - try: - for message in website_crawl_result: - end_time = time.time() - if message.result.status == "completed": - crawl_event = DatasourceCompletedEvent( - data=message.result.web_info_list, - total=message.result.total, - completed=message.result.completed, - time_consuming=round(end_time - start_time, 2) - ) - else: - crawl_event = DatasourceProcessingEvent( - total=message.result.total, - completed=message.result.completed, - ) - yield crawl_event.model_dump() - except Exception as e: - print(str(e)) - case _: - raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + )) + start_time = time.time() + try: + for message in website_crawl_result: + end_time = time.time() + if message.result.status == "completed": + crawl_event = DatasourceCompletedEvent( + data=message.result.web_info_list, + total=message.result.total, + completed=message.result.completed, + time_consuming=round(end_time - start_time, 2) + ) + else: + crawl_event = DatasourceProcessingEvent( + total=message.result.total, + completed=message.result.completed, + ) + yield crawl_event.model_dump() + except Exception as e: + logger.exception("Error during website crawl.") + yield DatasourceErrorEvent( + error=str(e) + ).model_dump() + case _: + raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + except Exception as e: + logger.exception("Error in run_datasource_workflow_node.") + yield DatasourceErrorEvent( + error=str(e) + ).model_dump() def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] From 3d0e288e85108ef038f3cdaaee56b6d1f1946212 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 19 Jun 2025 14:29:39 +0800 Subject: [PATCH 132/155] r2 --- api/core/datasource/entities/datasource_entities.py | 4 +++- api/core/plugin/impl/datasource.py | 2 +- api/core/workflow/nodes/datasource/datasource_node.py | 6 +++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index d072b8541b..75b4dd807f 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -244,7 +244,9 @@ class GetOnlineDocumentPageContentRequest(BaseModel): Get online document page content request """ - online_document_info: OnlineDocumentInfo + workspace_id: str = Field(..., description="The workspace id") + page_id: str = Field(..., description="The page id") + type: 
str = Field(..., description="The type of the page") class OnlineDocumentPageContent(BaseModel): diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 66469b43b4..5e966d2599 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -181,7 +181,7 @@ class PluginDatasourceManager(BasePluginClient): "provider": datasource_provider_id.provider_name, "datasource": datasource_name, "credentials": credentials, - "datasource_parameters": datasource_parameters, + "page": datasource_parameters, }, }, headers={ diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index ab4477f538..17108e0a57 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -103,7 +103,11 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): online_document_result: Generator[DatasourceMessage, None, None] = ( datasource_runtime.get_online_document_page_content( user_id=self.user_id, - datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), + datasource_parameters=GetOnlineDocumentPageContentRequest( + workspace_id=datasource_info.get("workspace_id"), + page_id=datasource_info.get("page").get("page_id"), + type=datasource_info.get("type"), + ), provider_type=datasource_type, ) ) From b618f3bd9e5f5ebe39da9fdab063832346d3765e Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 19 Jun 2025 15:30:46 +0800 Subject: [PATCH 133/155] r2 --- .../app/apps/pipeline/pipeline_generator.py | 1 + api/core/plugin/impl/datasource.py | 2 +- .../nodes/datasource/datasource_node.py | 11 ++++++- ...9_1525-a1025f709c06_add_pipeline_info_8.py | 33 +++++++++++++++++++ api/models/dataset.py | 1 + 5 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 api/migrations/versions/2025_06_19_1525-a1025f709c06_add_pipeline_info_8.py diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 49831007ee..de7156129a 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -141,6 +141,7 @@ class PipelineGenerator(BaseAppGenerator): document_id=document_id, datasource_type=datasource_type, datasource_info=json.dumps(datasource_info), + datasource_node_id=start_node_id, input_data=inputs, pipeline_id=pipeline.id, created_by=user.id, diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 5e966d2599..a8e98d2c1a 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -181,7 +181,7 @@ class PluginDatasourceManager(BasePluginClient): "provider": datasource_provider_id.provider_name, "datasource": datasource_name, "credentials": credentials, - "page": datasource_parameters, + "page": datasource_parameters.model_dump(), }, }, headers={ diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 17108e0a57..5c1d8523ff 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -29,6 +29,7 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from factories import file_factory from models.model import UploadFile +from services.datasource_provider_service import DatasourceProviderService from 
...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from .entities import DatasourceNodeData @@ -100,13 +101,21 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=self.tenant_id, + provider=node_data.provider_name, + plugin_id=node_data.plugin_id, + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") online_document_result: Generator[DatasourceMessage, None, None] = ( datasource_runtime.get_online_document_page_content( user_id=self.user_id, datasource_parameters=GetOnlineDocumentPageContentRequest( workspace_id=datasource_info.get("workspace_id"), page_id=datasource_info.get("page").get("page_id"), - type=datasource_info.get("type"), + type=datasource_info.get("page").get("type"), ), provider_type=datasource_type, ) diff --git a/api/migrations/versions/2025_06_19_1525-a1025f709c06_add_pipeline_info_8.py b/api/migrations/versions/2025_06_19_1525-a1025f709c06_add_pipeline_info_8.py new file mode 100644 index 0000000000..387aff54b0 --- /dev/null +++ b/api/migrations/versions/2025_06_19_1525-a1025f709c06_add_pipeline_info_8.py @@ -0,0 +1,33 @@ +"""add_pipeline_info_8 + +Revision ID: a1025f709c06 +Revises: 70a0fc0c013f +Create Date: 2025-06-19 15:25:41.263120 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'a1025f709c06' +down_revision = '70a0fc0c013f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op: + batch_op.add_column(sa.Column('datasource_node_id', sa.String(length=255), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op: + batch_op.drop_column('datasource_node_id') + + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 16d1865a83..e8da241d26 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1267,6 +1267,7 @@ class DocumentPipelineExecutionLog(Base): document_id = db.Column(StringUUID, nullable=False) datasource_type = db.Column(db.String(255), nullable=False) datasource_info = db.Column(db.Text, nullable=False) + datasource_node_id = db.Column(db.String(255), nullable=False) input_data = db.Column(db.JSON, nullable=False) created_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) From b12a8eeb90d08aa7e9f38f0a74446815a405bf90 Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Fri, 20 Jun 2025 10:09:47 +0800 Subject: [PATCH 134/155] feat(datasource): change datasource result type to event-stream --- api/core/app/apps/pipeline/pipeline_generator.py | 2 +- api/core/plugin/impl/datasource.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index e7ac4ad883..96f8f01032 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -589,7 +589,7 @@ class PipelineGenerator(BaseAppGenerator): if datasource_type == "local_file": name = datasource_info["name"] elif datasource_type == "online_document": - name = datasource_info["page_title"] + name = datasource_info['page']["page_name"] elif datasource_type == "website_crawl": name = datasource_info["title"] else: diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index a501a03332..6fad564dd4 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -170,7 +170,7 @@ class PluginDatasourceManager(BasePluginClient): datasource_provider_id = GenericProviderID(datasource_provider) - response = self._request_with_plugin_daemon_response_stream( + return self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content", DatasourceMessage, @@ -188,7 +188,6 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - yield from response def validate_provider_credentials( self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any] From ca0979dd431649454acae6976c2de87ff9ebe90b Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Mon, 23 Jun 2025 15:18:15 +0800 Subject: [PATCH 135/155] feat(datasource): update fetch_datasource_provider --- api/core/plugin/impl/datasource.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 6fad564dd4..a2351a19da 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -7,12 +7,13 @@ from core.datasource.entities.datasource_entities import ( OnlineDocumentPagesMessage, WebsiteCrawlMessage, ) -from core.plugin.entities.plugin import GenericProviderID, ToolProviderID +from core.plugin.entities.plugin import GenericProviderID, ToolProviderID, DatasourceProviderID from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDatasourceProviderEntity, ) from 
core.plugin.impl.base import BasePluginClient +from services.tools.tools_transform_service import ToolTransformService class PluginDatasourceManager(BasePluginClient): @@ -40,6 +41,8 @@ class PluginDatasourceManager(BasePluginClient): ) local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) + # for provider in response: + # ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider) all_response = [local_file_datasource_provider] + response for provider in all_response: @@ -58,7 +61,7 @@ class PluginDatasourceManager(BasePluginClient): if provider_id == "langgenius/file/file": return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) - tool_provider_id = ToolProviderID(provider_id) + tool_provider_id = DatasourceProviderID(provider_id) def transformer(json_response: dict[str, Any]) -> dict: data = json_response.get("data") From b77081a19edb4a6ba3081b6b514a72ee051f3180 Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Mon, 23 Jun 2025 15:57:37 +0800 Subject: [PATCH 136/155] feat(datasource): update datasource icon --- .../rag_pipeline/rag_pipeline_workflow.py | 1 - .../app/apps/pipeline/pipeline_generator.py | 2 +- api/core/plugin/impl/datasource.py | 6 ++-- api/core/rag/entities/event.py | 7 ++-- api/core/rag/models/document.py | 3 ++ .../workflow/graph_engine/entities/event.py | 2 -- api/services/rag_pipeline/rag_pipeline.py | 36 +++++++++---------- api/services/tools/tools_transform_service.py | 14 +++++++- 8 files changed, 43 insertions(+), 28 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index b040e70b92..da6db303cd 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -428,7 +428,6 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): ) - class RagPipelineDraftDatasourceNodeRunApi(Resource): @setup_required @login_required diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 96f8f01032..8769bcea0d 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -589,7 +589,7 @@ class PipelineGenerator(BaseAppGenerator): if datasource_type == "local_file": name = datasource_info["name"] elif datasource_type == "online_document": - name = datasource_info['page']["page_name"] + name = datasource_info["page"]["page_name"] elif datasource_type == "website_crawl": name = datasource_info["title"] else: diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index a2351a19da..53104b0061 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -7,7 +7,7 @@ from core.datasource.entities.datasource_entities import ( OnlineDocumentPagesMessage, WebsiteCrawlMessage, ) -from core.plugin.entities.plugin import GenericProviderID, ToolProviderID, DatasourceProviderID +from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDatasourceProviderEntity, @@ -41,8 +41,8 @@ class PluginDatasourceManager(BasePluginClient): ) local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) - # for provider in response: - # 
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider) + for provider in response: + ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider) all_response = [local_file_datasource_provider] + response for provider in all_response: diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py index 4acb558531..a36e32fc9c 100644 --- a/api/core/rag/entities/event.py +++ b/api/core/rag/entities/event.py @@ -9,6 +9,7 @@ class DatasourceStreamEvent(Enum): """ Datasource Stream event """ + PROCESSING = "datasource_processing" COMPLETED = "datasource_completed" ERROR = "datasource_error" @@ -17,19 +18,21 @@ class DatasourceStreamEvent(Enum): class BaseDatasourceEvent(BaseModel): pass + class DatasourceErrorEvent(BaseDatasourceEvent): event: str = DatasourceStreamEvent.ERROR.value error: str = Field(..., description="error message") + class DatasourceCompletedEvent(BaseDatasourceEvent): event: str = DatasourceStreamEvent.COMPLETED.value - data: Mapping[str,Any] | list = Field(..., description="result") + data: Mapping[str, Any] | list = Field(..., description="result") total: Optional[int] = Field(default=0, description="total") completed: Optional[int] = Field(default=0, description="completed") time_consuming: Optional[float] = Field(default=0.0, description="time consuming") + class DatasourceProcessingEvent(BaseDatasourceEvent): event: str = DatasourceStreamEvent.PROCESSING.value total: Optional[int] = Field(..., description="total") completed: Optional[int] = Field(..., description="completed") - diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 3f82bda2c6..e382ff6b54 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -68,12 +68,15 @@ class QAChunk(BaseModel): question: str answer: str + class QAStructureChunk(BaseModel): """ QAStructureChunk. """ + qa_chunks: list[QAChunk] + class BaseDocumentTransformer(ABC): """Abstract base class for document transformation systems. 
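The datasource run is surfaced to clients as an event stream: each event model above is serialized with model_dump() and emitted as a plain dict whose "event" field is "datasource_processing", "datasource_completed", or "datasource_error". As a rough, illustrative sketch only (the handle_datasource_events helper and the sample payloads below are hypothetical and not part of these patches), a consumer might dispatch on that field like this:

from collections.abc import Iterable
from typing import Any


def handle_datasource_events(events: Iterable[dict[str, Any]]) -> list[Any]:
    # Collect completed results, print progress, and surface stream errors.
    results: list[Any] = []
    for event in events:
        match event.get("event"):
            case "datasource_processing":
                print(f"processing {event.get('completed', 0)}/{event.get('total', 0)}")
            case "datasource_completed":
                print(f"completed in {event.get('time_consuming', 0.0)}s")
                results.append(event["data"])
            case "datasource_error":
                raise RuntimeError(event.get("error", "unknown datasource error"))
    return results


if __name__ == "__main__":
    # Hypothetical payloads shaped like the model_dump() output of the event classes.
    sample = [
        {"event": "datasource_processing", "total": 2, "completed": 1},
        {
            "event": "datasource_completed",
            "data": [{"title": "page"}],
            "total": 2,
            "completed": 2,
            "time_consuming": 0.42,
        },
    ]
    print(handle_datasource_events(sample))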
diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 063216dd49..9a4939502e 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -273,5 +273,3 @@ class AgentLogEvent(BaseAgentEvent): InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent - - diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 1f9337665b..333d559bf5 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -69,6 +69,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi logger = logging.getLogger(__name__) + class RagPipelineService: @classmethod def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict: @@ -431,8 +432,13 @@ class RagPipelineService: return workflow_node_execution def run_datasource_workflow_node( - self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, - is_published: bool + self, + pipeline: Pipeline, + node_id: str, + user_inputs: dict, + account: Account, + datasource_type: str, + is_published: bool, ) -> Generator[BaseDatasourceEvent, None, None]: """ Run published workflow datasource @@ -497,23 +503,21 @@ class RagPipelineService: for message in online_document_result: end_time = time.time() online_document_event = DatasourceCompletedEvent( - data=message.result, - time_consuming=round(end_time - start_time, 2) + data=message.result, time_consuming=round(end_time - start_time, 2) ) yield online_document_event.model_dump() except Exception as e: logger.exception("Error during online document.") - yield DatasourceErrorEvent( - error=str(e) - ).model_dump() + yield DatasourceErrorEvent(error=str(e)).model_dump() case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = ( datasource_runtime.get_website_crawl( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - )) + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), + ) + ) start_time = time.time() try: for message in website_crawl_result: @@ -523,7 +527,7 @@ class RagPipelineService: data=message.result.web_info_list, total=message.result.total, completed=message.result.completed, - time_consuming=round(end_time - start_time, 2) + time_consuming=round(end_time - start_time, 2), ) else: crawl_event = DatasourceProcessingEvent( @@ -533,16 +537,12 @@ class RagPipelineService: yield crawl_event.model_dump() except Exception as e: logger.exception("Error during website crawl.") - yield DatasourceErrorEvent( - error=str(e) - ).model_dump() + yield DatasourceErrorEvent(error=str(e)).model_dump() case _: raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") except Exception as e: logger.exception("Error in run_datasource_workflow_node.") - yield DatasourceErrorEvent( - error=str(e) - ).model_dump() + yield DatasourceErrorEvent(error=str(e)).model_dump() def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] diff --git a/api/services/tools/tools_transform_service.py 
b/api/services/tools/tools_transform_service.py index 367121125b..8a73c73a1b 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,6 +5,7 @@ from typing import Optional, Union, cast from yarl import URL from configs import dify_config +from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -56,7 +57,7 @@ class ToolTransformService: return "" @staticmethod - def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]): + def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]): """ repack provider @@ -77,6 +78,17 @@ class ToolTransformService: provider.icon = ToolTransformService.get_tool_provider_icon_url( provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon ) + elif isinstance(provider, PluginDatasourceProviderEntity): + if provider.plugin_id: + if isinstance(provider.declaration.identity.icon, str): + provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url( + tenant_id=tenant_id, filename=provider.declaration.identity.icon + ) + else: + provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url( + provider_type=provider.type.value, provider_name=provider.name, + icon=provider.declaration.identity.icon + ) @classmethod def builtin_provider_to_user_provider( From 1ff9c07a92cda9bf01a1daf6300d0d451846e567 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 23 Jun 2025 17:12:08 +0800 Subject: [PATCH 137/155] fix notion dataset rule not found --- api/controllers/console/datasets/datasets_document.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 16a00bbd42..8d302ca05e 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1115,11 +1115,16 @@ class DocumentPipelineExecutionLogApi(DocumentResource): .first() ) if not log: - return {"datasource_info": None, "datasource_type": None, "input_data": None}, 200 + return {"datasource_info": None, + "datasource_type": None, + "input_data": None, + "datasource_node_id": None, + }, 200 return { "datasource_info": json.loads(log.datasource_info), "datasource_type": log.datasource_type, "input_data": log.input_data, + "datasource_node_id": log.datasource_node_id, }, 200 @@ -1145,5 +1150,5 @@ api.add_resource(DocumentRetryApi, "/datasets//retry") api.add_resource(DocumentRenameApi, "/datasets//documents//rename") api.add_resource(WebsiteDocumentSyncApi, "/datasets//documents//website-sync") -api.add_resource(DocumentPipelineExecutionLogApi, +api.add_resource(DocumentPipelineExecutionLogApi, "/datasets//documents//pipeline-execution-log") From e165f4a1021617b7556d857949aceb5734d478e7 Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Tue, 24 Jun 2025 17:14:16 +0800 Subject: [PATCH 138/155] feat(datasource): add datasource content preview api --- api/controllers/console/__init__.py | 1 + .../datasource_content_preview.py | 52 ++++++++++ api/services/rag_pipeline/rag_pipeline.py | 96 ++++++++++++++++++- 3 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 
api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index f17c28dcd4..9d9023f59c 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -89,6 +89,7 @@ from .datasets.rag_pipeline import ( rag_pipeline_datasets, rag_pipeline_import, rag_pipeline_workflow, + datasource_content_preview ) # Import explore controllers diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py new file mode 100644 index 0000000000..30836b3da1 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -0,0 +1,52 @@ +from flask_restful import ( # type: ignore + Resource, # type: ignore + reqparse, +) +from werkzeug.exceptions import Forbidden +from controllers.console.datasets.wraps import get_rag_pipeline +from controllers.console.wraps import setup_required, account_initialization_required +from libs.login import login_required, current_user +from models import Account +from models.dataset import Pipeline +from controllers.console import api +from services.rag_pipeline.rag_pipeline import RagPipelineService + + +class DataSourceContentPreviewApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run datasource content preview + """ + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + if inputs is None: + raise ValueError("missing inputs") + datasource_type = args.get("datasource_type") + if datasource_type is None: + raise ValueError("missing datasource_type") + + rag_pipeline_service = RagPipelineService() + return rag_pipeline_service.run_datasource_node_preview( + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=datasource_type, + is_published=True, + ) + +api.add_resource( + DataSourceContentPreviewApi, + "/rag/pipelines//workflows/published/datasource/nodes//preview" +) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 333d559bf5..842676e29a 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -5,7 +5,7 @@ import threading import time from collections.abc import Callable, Generator, Sequence from datetime import UTC, datetime -from typing import Any, Optional, cast +from typing import Any, Optional, cast, Mapping from uuid import uuid4 from flask_login import current_user @@ -18,7 +18,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( DatasourceProviderType, OnlineDocumentPagesMessage, - WebsiteCrawlMessage, + WebsiteCrawlMessage, DatasourceMessage, GetOnlineDocumentPageContentRequest, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin @@ -544,6 +544,98 @@ class RagPipelineService: logger.exception("Error in run_datasource_workflow_node.") yield 
DatasourceErrorEvent(error=str(e)).model_dump() + def run_datasource_node_preview( + self, + pipeline: Pipeline, + node_id: str, + user_inputs: dict, + account: Account, + datasource_type: str, + is_published: bool, + ) -> Mapping[str, Any]: + """ + Run published workflow datasource + """ + try: + if is_published: + # fetch published workflow by app_model + workflow = self.get_published_workflow(pipeline=pipeline) + else: + workflow = self.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + datasource_node_data = None + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if not datasource_node_data: + raise ValueError("Datasource node data not found") + + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + for key, value in datasource_parameters.items(): + if not user_inputs.get(key): + user_inputs[key] = value["value"] + + from core.datasource.datasource_manager import DatasourceManager + + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", + datasource_name=datasource_node_data.get("datasource_name"), + tenant_id=pipeline.tenant_id, + datasource_type=DatasourceProviderType(datasource_type), + ) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=pipeline.tenant_id, + provider=datasource_node_data.get("provider_name"), + plugin_id=datasource_node_data.get("plugin_id"), + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: Generator[DatasourceMessage, None, None] = ( + datasource_runtime.get_online_document_page_content( + user_id=account.id, + datasource_parameters=GetOnlineDocumentPageContentRequest( + workspace_id=user_inputs.get("workspace_id"), + page_id=user_inputs.get("page_id"), + type=user_inputs.get("type"), + ), + provider_type=datasource_type, + ) + ) + try: + variables: dict[str, Any] = {} + for message in online_document_result: + if message.type == DatasourceMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + else: + variables[variable_name] = variable_value + return variables + except Exception as e: + logger.exception("Error during get online document content.") + raise RuntimeError(str(e)) + #TODO Online Drive + case _: + raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + except Exception as e: + logger.exception("Error in run_datasource_node_preview.") + raise RuntimeError(str(e)) + def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> 
WorkflowNodeExecution: From 49bb15fae15981dbfbc43b1503a9f0767f1705ed Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Tue, 24 Jun 2025 17:14:31 +0800 Subject: [PATCH 139/155] feat(datasource): add datasource content preview api --- api/controllers/console/__init__.py | 2 +- .../datasets/rag_pipeline/datasource_content_preview.py | 7 ++++--- api/services/rag_pipeline/rag_pipeline.py | 8 +++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 9d9023f59c..22ac835904 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -85,11 +85,11 @@ from .datasets import ( ) from .datasets.rag_pipeline import ( datasource_auth, + datasource_content_preview, rag_pipeline, rag_pipeline_datasets, rag_pipeline_import, rag_pipeline_workflow, - datasource_content_preview ) # Import explore controllers diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index 30836b3da1..18885f0dc4 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -3,12 +3,13 @@ from flask_restful import ( # type: ignore reqparse, ) from werkzeug.exceptions import Forbidden + +from controllers.console import api from controllers.console.datasets.wraps import get_rag_pipeline -from controllers.console.wraps import setup_required, account_initialization_required -from libs.login import login_required, current_user +from controllers.console.wraps import account_initialization_required, setup_required +from libs.login import current_user, login_required from models import Account from models.dataset import Pipeline -from controllers.console import api from services.rag_pipeline.rag_pipeline import RagPipelineService diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 842676e29a..6427c526d6 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -3,9 +3,9 @@ import logging import re import threading import time -from collections.abc import Callable, Generator, Sequence +from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Optional, cast, Mapping +from typing import Any, Optional, cast from uuid import uuid4 from flask_login import current_user @@ -16,9 +16,11 @@ import contexts from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( + DatasourceMessage, DatasourceProviderType, + GetOnlineDocumentPageContentRequest, OnlineDocumentPagesMessage, - WebsiteCrawlMessage, DatasourceMessage, GetOnlineDocumentPageContentRequest, + WebsiteCrawlMessage, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin From 6aba39a2ddf73393bbb559b713ea956268937ae0 Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Tue, 24 Jun 2025 17:43:25 +0800 Subject: [PATCH 140/155] feat(datasource): add datasource content preview api --- .../datasets/rag_pipeline/datasource_content_preview.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index 18885f0dc4..32b5f68364 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -38,7 +38,7 @@ class DataSourceContentPreviewApi(Resource): raise ValueError("missing datasource_type") rag_pipeline_service = RagPipelineService() - return rag_pipeline_service.run_datasource_node_preview( + preview_content = rag_pipeline_service.run_datasource_node_preview( pipeline=pipeline, node_id=node_id, user_inputs=inputs, @@ -46,6 +46,7 @@ class DataSourceContentPreviewApi(Resource): datasource_type=datasource_type, is_published=True, ) + return preview_content, 200 api.add_resource( DataSourceContentPreviewApi, From 7b7cdad1d8a1e0fc93393f1926c9a0b98590fba6 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 25 Jun 2025 13:28:08 +0800 Subject: [PATCH 141/155] r2 --- api/core/plugin/impl/oauth.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index 91774984c8..7b984da922 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -1,3 +1,4 @@ +import binascii from collections.abc import Mapping from typing import Any @@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient): provider: str, system_credentials: Mapping[str, Any], ) -> PluginOAuthAuthorizationUrlResponse: - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", PluginOAuthAuthorizationUrlResponse, @@ -32,6 +33,10 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + def get_credentials( self, @@ -49,7 +54,7 @@ class OAuthHandler(BasePluginClient): # encode request to raw http request raw_request_bytes = self._convert_request_to_raw_data(request) - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_credentials", PluginOAuthCredentialsResponse, @@ -58,7 +63,8 @@ class OAuthHandler(BasePluginClient): "data": { "provider": provider, "system_credentials": system_credentials, - "raw_request_bytes": raw_request_bytes, + # for json serialization + "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), }, }, headers={ @@ -66,6 +72,10 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + def _convert_request_to_raw_data(self, request: Request) -> bytes: """ @@ -79,7 +89,7 @@ class OAuthHandler(BasePluginClient): """ # Start with the request line method = request.method - path = request.path + path = request.full_path protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") raw_data = f"{method} {path} {protocol}\r\n".encode() @@ -95,4 +105,4 @@ class OAuthHandler(BasePluginClient): if body: raw_data += body - return raw_data + return raw_data \ No newline at end of file From efccbe4039dc8c5f8b66cd51e40e6972c9417ca7 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 
25 Jun 2025 17:32:26 +0800 Subject: [PATCH 142/155] r2 --- api/core/app/app_config/entities.py | 14 ++++++++++--- .../variables/manager.py | 20 ++++++++++++++++++- .../apps/pipeline/pipeline_config_manager.py | 6 +++--- .../app/apps/pipeline/pipeline_generator.py | 2 +- api/core/app/apps/pipeline/pipeline_runner.py | 15 ++++++++++---- api/core/variables/variables.py | 5 +++++ api/core/workflow/entities/variable_pool.py | 12 ++++++----- .../nodes/datasource/datasource_node.py | 3 +++ .../knowledge_index/knowledge_index_node.py | 4 ++++ api/models/workflow.py | 8 ++++++++ api/services/dataset_service.py | 5 ++--- 11 files changed, 74 insertions(+), 20 deletions(-) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 5308339871..fe7c75ce96 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,3 +1,4 @@ +from ast import Str from collections.abc import Sequence from enum import Enum, StrEnum from typing import Any, Literal, Optional @@ -113,9 +114,9 @@ class VariableEntity(BaseModel): hide: bool = False max_length: Optional[int] = None options: Sequence[str] = Field(default_factory=list) - allowed_file_types: Sequence[FileType] = Field(default_factory=list) - allowed_file_extensions: Sequence[str] = Field(default_factory=list) - allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list) + allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list) + allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list) @field_validator("description", mode="before") @classmethod @@ -127,6 +128,13 @@ class VariableEntity(BaseModel): def convert_none_options(cls, v: Any) -> Sequence[str]: return v or [] +class RagPipelineVariableEntity(VariableEntity): + """ + Rag Pipeline Variable Entity. 
+ """ + tooltips: Optional[str] = None + placeholder: Optional[str] = None + belong_to_node_id: str class ExternalDataVariableEntity(BaseModel): """ diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index 2f1da38082..b2530ec422 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,4 +1,6 @@ -from core.app.app_config.entities import VariableEntity +from typing import Any + +from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity from models.workflow import Workflow @@ -20,3 +22,19 @@ class WorkflowVariablesConfigManager: variables.append(VariableEntity.model_validate(variable)) return variables + + @classmethod + def convert_rag_pipeline_variable(cls, workflow: Workflow) -> list[RagPipelineVariableEntity]: + """ + Convert workflow start variables to variables + + :param workflow: workflow instance + """ + variables = [] + + user_input_form = workflow.rag_pipeline_user_input_form() + # variables + for variable in user_input_form: + variables.append(RagPipelineVariableEntity.model_validate(variable)) + + return variables diff --git a/api/core/app/apps/pipeline/pipeline_config_manager.py b/api/core/app/apps/pipeline/pipeline_config_manager.py index ddf87eacbb..f410457bc6 100644 --- a/api/core/app/apps/pipeline/pipeline_config_manager.py +++ b/api/core/app/apps/pipeline/pipeline_config_manager.py @@ -1,6 +1,6 @@ from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager -from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager @@ -13,7 +13,7 @@ class PipelineConfig(WorkflowUIBasedAppConfig): """ Pipeline Config Entity. 
""" - + rag_pipeline_variables: list[RagPipelineVariableEntity] = [] pass @@ -25,7 +25,7 @@ class PipelineConfigManager(BaseAppConfigManager): app_id=pipeline.id, app_mode=AppMode.RAG_PIPELINE, workflow_id=workflow.id, - variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow), ) return pipeline_config diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index de7156129a..7c0bbc46d9 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -160,7 +160,7 @@ class PipelineGenerator(BaseAppGenerator): document_id=document_id, inputs=self._prepare_user_inputs( user_inputs=inputs, - variables=pipeline_config.variables, + variables=pipeline_config.rag_pipeline_variables, tenant_id=pipeline.tenant_id, strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, ), diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 50dc8d8fad..402fd92358 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -10,7 +10,7 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, RagPipelineGenerateEntity, ) -from core.variables.variables import RAGPipelineVariable +from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey @@ -45,6 +45,8 @@ class PipelineRunner(WorkflowBasedAppRunner): self.queue_manager = queue_manager self.workflow_thread_pool_id = workflow_thread_pool_id + def _get_app_id(self) -> str: + return self.application_generate_entity.app_config.app_id def run(self) -> None: """ Run application @@ -107,15 +109,20 @@ class PipelineRunner(WorkflowBasedAppRunner): SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info, SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value, } - rag_pipeline_variables = {} + rag_pipeline_variables = [] if workflow.rag_pipeline_variables: for v in workflow.rag_pipeline_variables: rag_pipeline_variable = RAGPipelineVariable(**v) if ( - rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id + (rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id or rag_pipeline_variable.belong_to_node_id == "shared") and rag_pipeline_variable.variable in inputs ): - rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable] + rag_pipeline_variables.append( + RAGPipelineVariableInput( + variable=rag_pipeline_variable, + value=inputs[rag_pipeline_variable.variable], + ) + ) variable_pool = VariablePool( system_variables=system_inputs, diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 54aeec61e9..e5dc226571 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -117,3 +117,8 @@ class RAGPipelineVariable(BaseModel): ) required: bool = Field(description="optional, default false", default=False) options: list[str] | None = Field(default_factory=list) + + +class RAGPipelineVariableInput(BaseModel): + variable: RAGPipelineVariable + value: Any diff --git 
a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index e6196f48fe..37f194e0af 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -9,7 +9,9 @@ from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.segments import FileSegment, NoneSegment -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.variables.variables import RAGPipelineVariableInput +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, \ + SYSTEM_VARIABLE_NODE_ID, RAG_PIPELINE_VARIABLE_NODE_ID from core.workflow.enums import SystemVariableKey from factories import variable_factory @@ -44,9 +46,9 @@ class VariablePool(BaseModel): description="Conversation variables.", default_factory=list, ) - rag_pipeline_variables: Mapping[str, Any] = Field( + rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( description="RAG pipeline variables.", - default_factory=dict, + default_factory=list, ) def model_post_init(self, context: Any, /) -> None: @@ -59,8 +61,8 @@ class VariablePool(BaseModel): for var in self.conversation_variables: self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) # Add rag pipeline variables to the variable pool - for var, value in self.rag_pipeline_variables.items(): - self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var), value) + for var in self.rag_pipeline_variables: + self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var.variable.belong_to_node_id, var.variable.variable), var.value) def add(self, selector: Sequence[str], value: Any, /) -> None: """ diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 5c1d8523ff..1ba9cc2645 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -436,3 +436,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): inputs=parameters_for_log, ) ) + @classmethod + def version(cls) -> str: + return "1" diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 2c45bf4073..ad89a7ad08 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -159,3 +159,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]: index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() return index_processor.format_preview(chunks) + + @classmethod + def version(cls) -> str: + return "1" diff --git a/api/models/workflow.py b/api/models/workflow.py index 645089ae7f..3c87903bb3 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -322,6 +322,14 @@ class Workflow(Base): return variables + def rag_pipeline_user_input_form(self) -> list: + + # get user_input_form from start node + variables: list[Any] = self.rag_pipeline_variables + + return variables + + @property def unique_hash(self) -> str: """ diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index f83359d456..4a7620bd15 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ 
-344,8 +344,7 @@ class DatasetService: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise ValueError("Dataset not found") - - # check if dataset name is exists + # check if dataset name is exists if ( db.session.query(Dataset) .filter( @@ -471,7 +470,7 @@ class DatasetService: filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) # update Retrieval model filtered_data["retrieval_model"] = data["retrieval_model"] - # update icon info + # update icon info if data.get("icon_info"): filtered_data["icon_info"] = data.get("icon_info") From eee72101f47b4fe1db8a07ac6db27b5144362b1a Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 27 Jun 2025 16:41:39 +0800 Subject: [PATCH 143/155] feat(online_driver): add online driver plugin, support browsing and downloading --- .../entities/datasource_entities.py | 55 ++++++++++++++ .../online_driver/online_driver_plugin.py | 73 ++++++++++++++++++ .../online_driver/online_driver_provider.py | 48 ++++++++++++ api/core/plugin/impl/datasource.py | 75 +++++++++++++++++++ 4 files changed, 251 insertions(+) create mode 100644 api/core/datasource/online_driver/online_driver_plugin.py create mode 100644 api/core/datasource/online_driver/online_driver_provider.py diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 75b4dd807f..a345b0e18f 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -26,6 +26,7 @@ class DatasourceProviderType(enum.StrEnum): ONLINE_DOCUMENT = "online_document" LOCAL_FILE = "local_file" WEBSITE_CRAWL = "website_crawl" + ONLINE_DRIVER = "online_driver" @classmethod def value_of(cls, value: str) -> "DatasourceProviderType": @@ -303,3 +304,57 @@ class WebsiteCrawlMessage(BaseModel): class DatasourceMessage(ToolInvokeMessage): pass + + +######################### +# Online driver file +######################### + + +class OnlineDriverFile(BaseModel): + """ + Online driver file + """ + + key: str = Field(..., description="The key of the file") + size: int = Field(..., description="The size of the file") + + +class OnlineDriverFileBucket(BaseModel): + """ + Online driver file bucket + """ + + bucket: Optional[str] = Field(None, description="The bucket of the file") + files: list[OnlineDriverFile] = Field(..., description="The files of the bucket") + is_truncated: bool = Field(False, description="Whether the bucket has more files") + + +class OnlineDriverBrowseFilesRequest(BaseModel): + """ + Get online driver file list request + """ + + prefix: Optional[str] = Field(None, description="File path prefix for filtering eg: 'docs/dify/'") + bucket: Optional[str] = Field(None, description="Storage bucket name") + max_keys: int = Field(20, description="Maximum number of files to return") + start_after: Optional[str] = Field( + None, description="Pagination token for continuing from a specific file eg: 'docs/dify/1.txt'" + ) + + +class OnlineDriverBrowseFilesResponse(BaseModel): + """ + Get online driver file list response + """ + + result: list[OnlineDriverFileBucket] = Field(..., description="The bucket of the files") + + +class OnlineDriverDownloadFileRequest(BaseModel): + """ + Get online driver file + """ + + key: str = Field(..., description="The name of the file") + bucket: Optional[str] = Field(None, description="The name of the bucket") diff --git a/api/core/datasource/online_driver/online_driver_plugin.py 
b/api/core/datasource/online_driver/online_driver_plugin.py new file mode 100644 index 0000000000..60322457ac --- /dev/null +++ b/api/core/datasource/online_driver/online_driver_plugin.py @@ -0,0 +1,73 @@ +from collections.abc import Generator + +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceMessage, + DatasourceProviderType, + OnlineDriverBrowseFilesRequest, + OnlineDriverBrowseFilesResponse, + OnlineDriverDownloadFileRequest, +) +from core.plugin.impl.datasource import PluginDatasourceManager + + +class OnlineDriverDatasourcePlugin(DatasourcePlugin): + tenant_id: str + icon: str + plugin_unique_identifier: str + entity: DatasourceEntity + runtime: DatasourceRuntime + + def __init__( + self, + entity: DatasourceEntity, + runtime: DatasourceRuntime, + tenant_id: str, + icon: str, + plugin_unique_identifier: str, + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier + + def online_driver_browse_files( + self, + user_id: str, + request: OnlineDriverBrowseFilesRequest, + provider_type: str, + ) -> Generator[OnlineDriverBrowseFilesResponse, None, None]: + manager = PluginDatasourceManager() + + return manager.online_driver_browse_files( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + request=request, + provider_type=provider_type, + ) + + def online_driver_download_file( + self, + user_id: str, + request: OnlineDriverDownloadFileRequest, + provider_type: str, + ) -> Generator[DatasourceMessage, None, None]: + manager = PluginDatasourceManager() + + return manager.online_driver_download_file( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + request=request, + provider_type=provider_type, + ) + + def datasource_provider_type(self) -> str: + return DatasourceProviderType.ONLINE_DRIVER diff --git a/api/core/datasource/online_driver/online_driver_provider.py b/api/core/datasource/online_driver/online_driver_provider.py new file mode 100644 index 0000000000..edceeecd00 --- /dev/null +++ b/api/core/datasource/online_driver/online_driver_provider.py @@ -0,0 +1,48 @@ +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.online_driver.online_driver_plugin import OnlineDriverDatasourcePlugin + + +class OnlineDriverDatasourcePluginProviderController(DatasourcePluginProviderController): + entity: DatasourceProviderEntityWithPlugin + plugin_id: str + plugin_unique_identifier: str + + def __init__( + self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + ) -> None: + super().__init__(entity, tenant_id) + self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.ONLINE_DRIVER + + 
def get_datasource(self, datasource_name: str) -> OnlineDriverDatasourcePlugin: # type: ignore + """ + return datasource with given name + """ + datasource_entity = next( + ( + datasource_entity + for datasource_entity in self.entity.datasources + if datasource_entity.identity.name == datasource_name + ), + None, + ) + + if not datasource_entity: + raise ValueError(f"Datasource with name {datasource_name} not found") + + return OnlineDriverDatasourcePlugin( + entity=datasource_entity, + runtime=DatasourceRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index a8e98d2c1a..f38ea0555f 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -5,6 +5,9 @@ from core.datasource.entities.datasource_entities import ( DatasourceMessage, GetOnlineDocumentPageContentRequest, OnlineDocumentPagesMessage, + OnlineDriverBrowseFilesRequest, + OnlineDriverBrowseFilesResponse, + OnlineDriverDownloadFileRequest, WebsiteCrawlMessage, ) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID @@ -191,6 +194,78 @@ class PluginDatasourceManager(BasePluginClient): ) yield from response + def online_driver_browse_files( + self, + tenant_id: str, + user_id: str, + datasource_provider: str, + datasource_name: str, + credentials: dict[str, Any], + request: OnlineDriverBrowseFilesRequest, + provider_type: str, + ) -> Generator[OnlineDriverBrowseFilesResponse, None, None]: + """ + Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. + """ + + datasource_provider_id = GenericProviderID(datasource_provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/online_driver_browse_files", + OnlineDriverBrowseFilesResponse, + data={ + "user_id": user_id, + "data": { + "provider": datasource_provider_id.provider_name, + "datasource": datasource_name, + "credentials": credentials, + "request": request.model_dump(), + }, + }, + headers={ + "X-Plugin-ID": datasource_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + yield from response + + def online_driver_download_file( + self, + tenant_id: str, + user_id: str, + datasource_provider: str, + datasource_name: str, + credentials: dict[str, Any], + request: OnlineDriverDownloadFileRequest, + provider_type: str, + ) -> Generator[DatasourceMessage, None, None]: + """ + Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
+ """ + + datasource_provider_id = GenericProviderID(datasource_provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/online_driver_download_file", + DatasourceMessage, + data={ + "user_id": user_id, + "data": { + "provider": datasource_provider_id.provider_name, + "datasource": datasource_name, + "credentials": credentials, + "request": request.model_dump(), + }, + }, + headers={ + "X-Plugin-ID": datasource_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + yield from response + def validate_provider_credentials( self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any] ) -> bool: From 1449ed86c4ab0a264ea49410b62071ed6d2f28c9 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 27 Jun 2025 20:11:28 +0800 Subject: [PATCH 144/155] feat: rename online driver to online drive and update related classes and methods :) --- .../entities/datasource_entities.py | 16 ++++++------- .../online_drive_plugin.py} | 24 +++++++++---------- .../online_drive_provider.py} | 10 ++++---- api/core/plugin/impl/datasource.py | 22 ++++++++--------- 4 files changed, 36 insertions(+), 36 deletions(-) rename api/core/datasource/{online_driver/online_driver_plugin.py => online_drive/online_drive_plugin.py} (76%) rename api/core/datasource/{online_driver/online_driver_provider.py => online_drive/online_drive_provider.py} (83%) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index a345b0e18f..41be2dcc3d 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -26,7 +26,7 @@ class DatasourceProviderType(enum.StrEnum): ONLINE_DOCUMENT = "online_document" LOCAL_FILE = "local_file" WEBSITE_CRAWL = "website_crawl" - ONLINE_DRIVER = "online_driver" + ONLINE_DRIVE = "online_drive" @classmethod def value_of(cls, value: str) -> "DatasourceProviderType": @@ -311,7 +311,7 @@ class DatasourceMessage(ToolInvokeMessage): ######################### -class OnlineDriverFile(BaseModel): +class OnlineDriveFile(BaseModel): """ Online driver file """ @@ -320,17 +320,17 @@ class OnlineDriverFile(BaseModel): size: int = Field(..., description="The size of the file") -class OnlineDriverFileBucket(BaseModel): +class OnlineDriveFileBucket(BaseModel): """ Online driver file bucket """ bucket: Optional[str] = Field(None, description="The bucket of the file") - files: list[OnlineDriverFile] = Field(..., description="The files of the bucket") + files: list[OnlineDriveFile] = Field(..., description="The files of the bucket") is_truncated: bool = Field(False, description="Whether the bucket has more files") -class OnlineDriverBrowseFilesRequest(BaseModel): +class OnlineDriveBrowseFilesRequest(BaseModel): """ Get online driver file list request """ @@ -343,15 +343,15 @@ class OnlineDriverBrowseFilesRequest(BaseModel): ) -class OnlineDriverBrowseFilesResponse(BaseModel): +class OnlineDriveBrowseFilesResponse(BaseModel): """ Get online driver file list response """ - result: list[OnlineDriverFileBucket] = Field(..., description="The bucket of the files") + result: list[OnlineDriveFileBucket] = Field(..., description="The bucket of the files") -class OnlineDriverDownloadFileRequest(BaseModel): +class OnlineDriveDownloadFileRequest(BaseModel): """ Get online driver file """ diff --git a/api/core/datasource/online_driver/online_driver_plugin.py 
b/api/core/datasource/online_drive/online_drive_plugin.py similarity index 76% rename from api/core/datasource/online_driver/online_driver_plugin.py rename to api/core/datasource/online_drive/online_drive_plugin.py index 60322457ac..f0e3cb38f9 100644 --- a/api/core/datasource/online_driver/online_driver_plugin.py +++ b/api/core/datasource/online_drive/online_drive_plugin.py @@ -6,14 +6,14 @@ from core.datasource.entities.datasource_entities import ( DatasourceEntity, DatasourceMessage, DatasourceProviderType, - OnlineDriverBrowseFilesRequest, - OnlineDriverBrowseFilesResponse, - OnlineDriverDownloadFileRequest, + OnlineDriveBrowseFilesRequest, + OnlineDriveBrowseFilesResponse, + OnlineDriveDownloadFileRequest, ) from core.plugin.impl.datasource import PluginDatasourceManager -class OnlineDriverDatasourcePlugin(DatasourcePlugin): +class OnlineDriveDatasourcePlugin(DatasourcePlugin): tenant_id: str icon: str plugin_unique_identifier: str @@ -33,15 +33,15 @@ class OnlineDriverDatasourcePlugin(DatasourcePlugin): self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - def online_driver_browse_files( + def online_drive_browse_files( self, user_id: str, - request: OnlineDriverBrowseFilesRequest, + request: OnlineDriveBrowseFilesRequest, provider_type: str, - ) -> Generator[OnlineDriverBrowseFilesResponse, None, None]: + ) -> Generator[OnlineDriveBrowseFilesResponse, None, None]: manager = PluginDatasourceManager() - return manager.online_driver_browse_files( + return manager.online_drive_browse_files( tenant_id=self.tenant_id, user_id=user_id, datasource_provider=self.entity.identity.provider, @@ -51,15 +51,15 @@ class OnlineDriverDatasourcePlugin(DatasourcePlugin): provider_type=provider_type, ) - def online_driver_download_file( + def online_drive_download_file( self, user_id: str, - request: OnlineDriverDownloadFileRequest, + request: OnlineDriveDownloadFileRequest, provider_type: str, ) -> Generator[DatasourceMessage, None, None]: manager = PluginDatasourceManager() - return manager.online_driver_download_file( + return manager.online_drive_download_file( tenant_id=self.tenant_id, user_id=user_id, datasource_provider=self.entity.identity.provider, @@ -70,4 +70,4 @@ class OnlineDriverDatasourcePlugin(DatasourcePlugin): ) def datasource_provider_type(self) -> str: - return DatasourceProviderType.ONLINE_DRIVER + return DatasourceProviderType.ONLINE_DRIVE diff --git a/api/core/datasource/online_driver/online_driver_provider.py b/api/core/datasource/online_drive/online_drive_provider.py similarity index 83% rename from api/core/datasource/online_driver/online_driver_provider.py rename to api/core/datasource/online_drive/online_drive_provider.py index edceeecd00..d0923ed807 100644 --- a/api/core/datasource/online_driver/online_driver_provider.py +++ b/api/core/datasource/online_drive/online_drive_provider.py @@ -1,10 +1,10 @@ from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType -from core.datasource.online_driver.online_driver_plugin import OnlineDriverDatasourcePlugin +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin -class OnlineDriverDatasourcePluginProviderController(DatasourcePluginProviderController): +class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController): entity: 
DatasourceProviderEntityWithPlugin plugin_id: str plugin_unique_identifier: str @@ -21,9 +21,9 @@ class OnlineDriverDatasourcePluginProviderController(DatasourcePluginProviderCon """ returns the type of the provider """ - return DatasourceProviderType.ONLINE_DRIVER + return DatasourceProviderType.ONLINE_DRIVE - def get_datasource(self, datasource_name: str) -> OnlineDriverDatasourcePlugin: # type: ignore + def get_datasource(self, datasource_name: str) -> OnlineDriveDatasourcePlugin: # type: ignore """ return datasource with given name """ @@ -39,7 +39,7 @@ class OnlineDriverDatasourcePluginProviderController(DatasourcePluginProviderCon if not datasource_entity: raise ValueError(f"Datasource with name {datasource_name} not found") - return OnlineDriverDatasourcePlugin( + return OnlineDriveDatasourcePlugin( entity=datasource_entity, runtime=DatasourceRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index f38ea0555f..4d90685d24 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -5,9 +5,9 @@ from core.datasource.entities.datasource_entities import ( DatasourceMessage, GetOnlineDocumentPageContentRequest, OnlineDocumentPagesMessage, - OnlineDriverBrowseFilesRequest, - OnlineDriverBrowseFilesResponse, - OnlineDriverDownloadFileRequest, + OnlineDriveBrowseFilesRequest, + OnlineDriveBrowseFilesResponse, + OnlineDriveDownloadFileRequest, WebsiteCrawlMessage, ) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID @@ -194,16 +194,16 @@ class PluginDatasourceManager(BasePluginClient): ) yield from response - def online_driver_browse_files( + def online_drive_browse_files( self, tenant_id: str, user_id: str, datasource_provider: str, datasource_name: str, credentials: dict[str, Any], - request: OnlineDriverBrowseFilesRequest, + request: OnlineDriveBrowseFilesRequest, provider_type: str, - ) -> Generator[OnlineDriverBrowseFilesResponse, None, None]: + ) -> Generator[OnlineDriveBrowseFilesResponse, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
""" @@ -212,8 +212,8 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/online_driver_browse_files", - OnlineDriverBrowseFilesResponse, + f"plugin/{tenant_id}/dispatch/datasource/online_drive_browse_files", + OnlineDriveBrowseFilesResponse, data={ "user_id": user_id, "data": { @@ -230,14 +230,14 @@ class PluginDatasourceManager(BasePluginClient): ) yield from response - def online_driver_download_file( + def online_drive_download_file( self, tenant_id: str, user_id: str, datasource_provider: str, datasource_name: str, credentials: dict[str, Any], - request: OnlineDriverDownloadFileRequest, + request: OnlineDriveDownloadFileRequest, provider_type: str, ) -> Generator[DatasourceMessage, None, None]: """ @@ -248,7 +248,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/online_driver_download_file", + f"plugin/{tenant_id}/dispatch/datasource/online_drive_download_file", DatasourceMessage, data={ "user_id": user_id, From 618ad4c2913fa3c296fc5337dbb2a971c0ae5d54 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 30 Jun 2025 15:36:20 +0800 Subject: [PATCH 145/155] r2 --- api/fields/dataset_fields.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index aa147331d4..79a4f1c6de 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -56,7 +56,12 @@ external_knowledge_info_fields = { doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String} -icon_info_fields = {"icon_type": fields.String, "icon": fields.String, "icon_background": fields.String} +icon_info_fields = { + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": fields.String, +} dataset_detail_fields = { "id": fields.String, From cdbba1400c25c521aa9ba2ffcec1ef09b1eda408 Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Tue, 1 Jul 2025 11:57:06 +0800 Subject: [PATCH 146/155] feat(datasource): update fetch_datasource_provider --- api/core/plugin/impl/datasource.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 4d90685d24..97c9428bde 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -10,7 +10,7 @@ from core.datasource.entities.datasource_entities import ( OnlineDriveDownloadFileRequest, WebsiteCrawlMessage, ) -from core.plugin.entities.plugin import GenericProviderID, ToolProviderID +from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDatasourceProviderEntity, @@ -61,7 +61,7 @@ class PluginDatasourceManager(BasePluginClient): if provider_id == "langgenius/file/file": return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) - tool_provider_id = ToolProviderID(provider_id) + tool_provider_id = DatasourceProviderID(provider_id) def transformer(json_response: dict[str, Any]) -> dict: data = json_response.get("data") From bfcf09b684a41819961ea3956bbce382a17ec117 Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Tue, 1 Jul 2025 14:04:09 +0800 Subject: [PATCH 147/155] feat(datasource): fix datasource icon --- 
api/core/plugin/impl/datasource.py | 3 +++ api/services/tools/tools_transform_service.py | 14 +++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 97c9428bde..6dc5ebca6b 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -16,6 +16,7 @@ from core.plugin.entities.plugin_daemon import ( PluginDatasourceProviderEntity, ) from core.plugin.impl.base import BasePluginClient +from services.tools.tools_transform_service import ToolTransformService class PluginDatasourceManager(BasePluginClient): @@ -43,6 +44,8 @@ class PluginDatasourceManager(BasePluginClient): ) local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) + for provider in response: + ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider) all_response = [local_file_datasource_provider] + response for provider in all_response: diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 367121125b..8a73c73a1b 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,6 +5,7 @@ from typing import Optional, Union, cast from yarl import URL from configs import dify_config +from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -56,7 +57,7 @@ class ToolTransformService: return "" @staticmethod - def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]): + def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]): """ repack provider @@ -77,6 +78,17 @@ class ToolTransformService: provider.icon = ToolTransformService.get_tool_provider_icon_url( provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon ) + elif isinstance(provider, PluginDatasourceProviderEntity): + if provider.plugin_id: + if isinstance(provider.declaration.identity.icon, str): + provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url( + tenant_id=tenant_id, filename=provider.declaration.identity.icon + ) + else: + provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url( + provider_type=provider.type.value, provider_name=provider.name, + icon=provider.declaration.identity.icon + ) @classmethod def builtin_provider_to_user_provider( From f44f0fa34cf61a8e8cee8efe7eb1f350567a924b Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 1 Jul 2025 14:23:46 +0800 Subject: [PATCH 148/155] r2 --- .../rag_pipeline/rag_pipeline_workflow.py | 8 +- api/services/rag_pipeline/rag_pipeline.py | 78 ++++++++++++++----- 2 files changed, 64 insertions(+), 22 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 28ab4b1635..3ef0c42d0f 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -804,7 +804,7 @@ class PublishedRagPipelineSecondStepApi(Resource): if not node_id: raise ValueError("Node ID is required") rag_pipeline_service = RagPipelineService() - variables = 
rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id) + variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False) return { "variables": variables, } @@ -829,7 +829,7 @@ class PublishedRagPipelineFirstStepApi(Resource): if not node_id: raise ValueError("Node ID is required") rag_pipeline_service = RagPipelineService() - variables = rag_pipeline_service.get_published_first_step_parameters(pipeline=pipeline, node_id=node_id) + variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False) return { "variables": variables, } @@ -854,7 +854,7 @@ class DraftRagPipelineFirstStepApi(Resource): if not node_id: raise ValueError("Node ID is required") rag_pipeline_service = RagPipelineService() - variables = rag_pipeline_service.get_draft_first_step_parameters(pipeline=pipeline, node_id=node_id) + variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True) return { "variables": variables, } @@ -880,7 +880,7 @@ class DraftRagPipelineSecondStepApi(Resource): raise ValueError("Node ID is required") rag_pipeline_service = RagPipelineService() - variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id) + variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True) return { "variables": variables, } diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 26036dc2c5..f379a4b930 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -113,6 +113,14 @@ class RagPipelineService: ) if not customized_template: raise ValueError("Customized pipeline template not found.") + # check template name is exist + template_name = template_info.name + if template_name: + template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.name == template_name, + PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, + PipelineCustomizedTemplate.id != template_id).first() + if template: + raise ValueError("Template name is already exists") customized_template.name = template_info.name customized_template.description = template_info.description customized_template.icon = template_info.icon_info.model_dump() @@ -785,7 +793,7 @@ class RagPipelineService: break if not datasource_node_data: raise ValueError("Datasource node data not found") - variables = datasource_node_data.get("variables", {}) + variables = published_workflow.rag_pipeline_variables if variables: variables_map = {item["variable"]: item for item in variables} else: @@ -793,29 +801,29 @@ class RagPipelineService: datasource_parameters = datasource_node_data.get("datasource_parameters", {}) user_input_variables = [] for key, value in datasource_parameters.items(): - if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): - user_input_variables.append(variables_map.get(key, {})) + if value.get("value") and isinstance(value.get("value"), str): + if re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): + user_input_variables.append(variables_map.get(key, {})) return user_input_variables - def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: + def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: 
bool = False) -> list[dict]: """ Get first step parameters of rag pipeline """ - draft_workflow = self.get_draft_workflow(pipeline=pipeline) - if not draft_workflow: + workflow = self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline) + if not workflow: raise ValueError("Workflow not initialized") - # get second step node datasource_node_data = None - datasource_nodes = draft_workflow.graph_dict.get("nodes", []) + datasource_nodes = workflow.graph_dict.get("nodes", []) for datasource_node in datasource_nodes: if datasource_node.get("id") == node_id: datasource_node_data = datasource_node.get("data", {}) break if not datasource_node_data: raise ValueError("Datasource node data not found") - variables = datasource_node_data.get("variables", {}) + variables = workflow.rag_pipeline_variables if variables: variables_map = {item["variable"]: item for item in variables} else: @@ -824,16 +832,21 @@ class RagPipelineService: user_input_variables = [] for key, value in datasource_parameters.items(): - if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): - user_input_variables.append(variables_map.get(key, {})) + if value.get("value") and isinstance(value.get("value"), str): + pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" + match = re.match(pattern, value["value"]) + if match: + full_path = match.group(1) + last_part = full_path.split('.')[-1] + user_input_variables.append(variables_map.get(last_part, {})) return user_input_variables - def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: + def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]: """ Get second step parameters of rag pipeline """ - workflow = self.get_draft_workflow(pipeline=pipeline) + workflow = self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline) if not workflow: raise ValueError("Workflow not initialized") @@ -841,13 +854,32 @@ class RagPipelineService: rag_pipeline_variables = workflow.rag_pipeline_variables if not rag_pipeline_variables: return [] + variables_map = {item["variable"]: item for item in rag_pipeline_variables} - # get datasource provider + # get datasource node data + datasource_node_data = None + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if datasource_node_data: + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + + for key, value in datasource_parameters.items(): + if value.get("value") and isinstance(value.get("value"), str): + pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" + match = re.match(pattern, value["value"]) + if match: + full_path = match.group(1) + last_part = full_path.split('.')[-1] + variables_map.pop(last_part) + all_second_step_variables = list(variables_map.values()) datasource_provider_variables = [ - item - for item in rag_pipeline_variables - if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" - ] + item + for item in all_second_step_variables + if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" + ] return datasource_provider_variables def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, 
args: dict) -> InfiniteScrollPagination: @@ -968,6 +1000,16 @@ class RagPipelineService: dataset = pipeline.dataset if not dataset: raise ValueError("Dataset not found") + + # check template name is exist + template_name = args.get("name") + if template_name: + template = db.session.query(PipelineCustomizedTemplate).filter( + PipelineCustomizedTemplate.name == template_name, + PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id, + ).first() + if template: + raise ValueError("Template name is already exists") max_position = ( db.session.query(func.max(PipelineCustomizedTemplate.position)) From a4eddd7dc2cbc4b7711e368143871ec067156cfb Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 1 Jul 2025 15:16:33 +0800 Subject: [PATCH 149/155] r2 --- api/services/rag_pipeline/rag_pipeline.py | 57 +---------------------- 1 file changed, 2 insertions(+), 55 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f379a4b930..a2551f043d 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -753,59 +753,6 @@ class RagPipelineService: return workflow - def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: - """ - Get second step parameters of rag pipeline - """ - - workflow = self.get_published_workflow(pipeline=pipeline) - if not workflow: - raise ValueError("Workflow not initialized") - - # get second step node - rag_pipeline_variables = workflow.rag_pipeline_variables - if not rag_pipeline_variables: - return [] - - # get datasource provider - datasource_provider_variables = [ - item - for item in rag_pipeline_variables - if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" - ] - return datasource_provider_variables - - def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: - """ - Get first step parameters of rag pipeline - """ - - published_workflow = self.get_published_workflow(pipeline=pipeline) - if not published_workflow: - raise ValueError("Workflow not initialized") - - # get second step node - datasource_node_data = None - datasource_nodes = published_workflow.graph_dict.get("nodes", []) - for datasource_node in datasource_nodes: - if datasource_node.get("id") == node_id: - datasource_node_data = datasource_node.get("data", {}) - break - if not datasource_node_data: - raise ValueError("Datasource node data not found") - variables = published_workflow.rag_pipeline_variables - if variables: - variables_map = {item["variable"]: item for item in variables} - else: - return [] - datasource_parameters = datasource_node_data.get("datasource_parameters", {}) - user_input_variables = [] - for key, value in datasource_parameters.items(): - if value.get("value") and isinstance(value.get("value"), str): - if re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): - user_input_variables.append(variables_map.get(key, {})) - return user_input_variables - def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]: """ Get first step parameters of rag pipeline @@ -833,7 +780,7 @@ class RagPipelineService: user_input_variables = [] for key, value in datasource_parameters.items(): if value.get("value") and isinstance(value.get("value"), str): - pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" + pattern = 
r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" match = re.match(pattern, value["value"]) if match: full_path = match.group(1) @@ -868,7 +815,7 @@ class RagPipelineService: for key, value in datasource_parameters.items(): if value.get("value") and isinstance(value.get("value"), str): - pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" + pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" match = re.match(pattern, value["value"]) if match: full_path = match.group(1) From f33b6c0c73729d3d0a275a4079451b80891b1a3b Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 1 Jul 2025 16:08:54 +0800 Subject: [PATCH 150/155] add online drive --- .../nodes/datasource/datasource_node.py | 28 ++++++++++++++++- api/services/rag_pipeline/rag_pipeline.py | 30 +++++++++++++++++-- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 1ba9cc2645..77eba1a7ce 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -9,8 +9,10 @@ from core.datasource.entities.datasource_entities import ( DatasourceParameter, DatasourceProviderType, GetOnlineDocumentPageContentRequest, + OnlineDriveDownloadFileRequest, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer from core.file import File from core.file.enums import FileTransferMethod, FileType @@ -125,7 +127,31 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): parameters_for_log=parameters_for_log, datasource_info=datasource_info, ) - + case DatasourceProviderType.ONLINE_DRIVE: + datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=self.tenant_id, + provider=node_data.provider_name, + plugin_id=node_data.plugin_id, + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + online_drive_result: Generator[DatasourceMessage, None, None] = ( + datasource_runtime.online_drive_download_file( + user_id=self.user_id, + request=OnlineDriveDownloadFileRequest( + key=datasource_info.get("key"), + bucket=datasource_info.get("bucket"), + ), + provider_type=datasource_type, + ) + ) + yield from self._transform_message( + messages=online_drive_result, + parameters_for_log=parameters_for_log, + datasource_info=datasource_info, + ) case DatasourceProviderType.WEBSITE_CRAWL: yield RunCompletedEvent( run_result=NodeRunResult( diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index a2551f043d..9fe09744a8 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -17,9 +17,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( DatasourceProviderType, OnlineDocumentPagesMessage, + OnlineDriveBrowseFilesRequest, + OnlineDriveBrowseFilesResponse, WebsiteCrawlMessage, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from 
core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.rag.entities.event import BaseDatasourceEvent, DatasourceCompletedEvent, DatasourceProcessingEvent from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository @@ -554,10 +557,33 @@ class RagPipelineService: end_time = time.time() online_document_event = DatasourceCompletedEvent( data=message.result, - time_consuming=round(end_time - start_time, 2) + time_consuming=round(end_time - start_time, 2), + total=None, + completed=None, ) yield online_document_event.model_dump() - + case DatasourceProviderType.ONLINE_DRIVE: + datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime) + online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = datasource_runtime.online_drive_browse_files( + user_id=account.id, + request=OnlineDriveBrowseFilesRequest( + bucket=user_inputs.get("bucket"), + prefix=user_inputs.get("prefix"), + max_keys=user_inputs.get("max_keys", 20), + start_after=user_inputs.get("start_after"), + ), + provider_type=datasource_runtime.datasource_provider_type(), + ) + start_time = time.time() + for message in online_drive_result: + end_time = time.time() + online_drive_event = DatasourceCompletedEvent( + data=message.result, + time_consuming=round(end_time - start_time, 2), + total=None, + completed=None, + ) + yield online_drive_event.model_dump() case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = datasource_runtime.get_website_crawl( From 39d3f58082f977eb4ba9c632fde18d9d824e1d94 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 2 Jul 2025 11:33:00 +0800 Subject: [PATCH 151/155] r2 --- .../datasets/rag_pipeline/datasource_auth.py | 2 ++ ..._15_1558-b35c3db83d09_add_pipeline_info.py | 2 +- ...2_1132-15e40b74a6d2_add_pipeline_info_9.py | 33 +++++++++++++++++++ api/models/oauth.py | 1 + api/services/datasource_provider_service.py | 13 +++++++- .../customized/customized_retrieval.py | 1 + 6 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 api/migrations/versions/2025_07_02_1132-15e40b74a6d2_add_pipeline_info_9.py diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 21a7b998f0..7f7b6a7867 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -96,6 +96,7 @@ class DatasourceAuth(Resource): parser = reqparse.RequestParser() parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() @@ -108,6 +109,7 @@ class DatasourceAuth(Resource): provider=args["provider"], plugin_id=args["plugin_id"], credentials=args["credentials"], + name=args["name"], ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) diff --git a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py 
b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py index 503842b797..961589a87e 100644 --- a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py +++ b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py @@ -12,7 +12,7 @@ from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = 'b35c3db83d09' -down_revision = '4474872b0ee6' +down_revision = '0ab65e1cc7fa' branch_labels = None depends_on = None diff --git a/api/migrations/versions/2025_07_02_1132-15e40b74a6d2_add_pipeline_info_9.py b/api/migrations/versions/2025_07_02_1132-15e40b74a6d2_add_pipeline_info_9.py new file mode 100644 index 0000000000..82c5991775 --- /dev/null +++ b/api/migrations/versions/2025_07_02_1132-15e40b74a6d2_add_pipeline_info_9.py @@ -0,0 +1,33 @@ +"""add_pipeline_info_9 + +Revision ID: 15e40b74a6d2 +Revises: a1025f709c06 +Create Date: 2025-07-02 11:32:44.125790 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '15e40b74a6d2' +down_revision = 'a1025f709c06' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('name', sa.String(length=255), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: + batch_op.drop_column('name') + + # ### end Alembic commands ### diff --git a/api/models/oauth.py b/api/models/oauth.py index b1b09e5d45..84bc29931e 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -29,6 +29,7 @@ class DatasourceProvider(Base): ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) + name: Mapped[str] = db.Column(db.String(255), nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False) auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index fa01fe0afe..bca0081417 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -22,7 +22,7 @@ class DatasourceProviderService: self.provider_manager = PluginDatasourceManager() def datasource_provider_credentials_validate( - self, tenant_id: str, provider: str, plugin_id: str, credentials: dict + self, tenant_id: str, provider: str, plugin_id: str, credentials: dict, name: str ) -> None: """ validate datasource provider credentials. 
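# Illustrative sketch, not part of this diff: the overall flow the hunks below
# implement -- reject a duplicate authorization name for the tenant, encrypt the
# secret credential values (the real service delegates to encrypter.encrypt_token),
# then persist the provider record. The in-memory store and the "enc::" prefix are
# hypothetical stand-ins, not Dify APIs.
from dataclasses import dataclass, field


@dataclass
class _StoredProvider:
    tenant_id: str
    name: str
    provider: str
    credentials: dict[str, str] = field(default_factory=dict)


class _InMemoryStore:
    def __init__(self) -> None:
        self._rows: list[_StoredProvider] = []

    def name_taken(self, tenant_id: str, name: str) -> bool:
        return any(r.tenant_id == tenant_id and r.name == name for r in self._rows)

    def add(self, row: _StoredProvider) -> None:
        self._rows.append(row)


def save_datasource_credentials(
    store: _InMemoryStore, tenant_id: str, name: str, provider: str, credentials: dict[str, str]
) -> None:
    # Uniqueness guard mirrors the "Authorization name" check added in the next hunk.
    if store.name_taken(tenant_id, name):
        raise ValueError("Authorization name already exists")
    # Placeholder for per-tenant token encryption of secret credential fields.
    encrypted = {key: f"enc::{value}" for key, value in credentials.items()}
    store.add(_StoredProvider(tenant_id=tenant_id, name=name, provider=provider, credentials=encrypted))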
@@ -31,6 +31,15 @@ class DatasourceProviderService: :param provider: :param credentials: """ + # check name is exist + datasource_provider = ( + db.session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, name=name) + .first() + ) + if datasource_provider: + raise ValueError("Authorization name is already exists") + credential_valid = self.provider_manager.validate_provider_credentials( tenant_id=tenant_id, user_id=current_user.id, @@ -55,6 +64,7 @@ class DatasourceProviderService: credentials[key] = encrypter.encrypt_token(tenant_id, value) datasource_provider = DatasourceProvider( tenant_id=tenant_id, + name=name, provider=provider, plugin_id=plugin_id, auth_type="api_key", @@ -120,6 +130,7 @@ class DatasourceProviderService: { "credentials": copy_credentials, "type": datasource_provider.auth_type, + "name": datasource_provider.name, } ) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 7280408889..3380d23ec4 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -38,6 +38,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): pipeline_customized_templates = ( db.session.query(PipelineCustomizedTemplate) .filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) + .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc()) .all() ) recommended_pipelines_results = [] From 9f14b5db9af04dd02942065f538bdb679171c86b Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 2 Jul 2025 11:55:21 +0800 Subject: [PATCH 152/155] r2 --- api/services/rag_pipeline/rag_pipeline.py | 28 ++--------------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 9fe09744a8..da6ba0fba5 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -119,7 +119,7 @@ class RagPipelineService: # check template name is exist template_name = template_info.name if template_name: - template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.name == template_name, + template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.name == template_name, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, PipelineCustomizedTemplate.id != template_id).first() if template: @@ -558,32 +558,8 @@ class RagPipelineService: online_document_event = DatasourceCompletedEvent( data=message.result, time_consuming=round(end_time - start_time, 2), - total=None, - completed=None, ) yield online_document_event.model_dump() - case DatasourceProviderType.ONLINE_DRIVE: - datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime) - online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = datasource_runtime.online_drive_browse_files( - user_id=account.id, - request=OnlineDriveBrowseFilesRequest( - bucket=user_inputs.get("bucket"), - prefix=user_inputs.get("prefix"), - max_keys=user_inputs.get("max_keys", 20), - start_after=user_inputs.get("start_after"), - ), - provider_type=datasource_runtime.datasource_provider_type(), - ) - start_time = time.time() - for message in online_drive_result: - end_time = 
time.time() - online_drive_event = DatasourceCompletedEvent( - data=message.result, - time_consuming=round(end_time - start_time, 2), - total=None, - completed=None, - ) - yield online_drive_event.model_dump() case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = datasource_runtime.get_website_crawl( @@ -973,7 +949,7 @@ class RagPipelineService: dataset = pipeline.dataset if not dataset: raise ValueError("Dataset not found") - + # check template name is exist template_name = args.get("name") if template_name: From 81b07dc3be8e6a5fe89a86cef0d2246feacb5744 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 2 Jul 2025 18:15:23 +0800 Subject: [PATCH 153/155] r2 --- .../console/datasets/datasets_document.py | 16 +++-- .../datasets/rag_pipeline/datasource_auth.py | 2 +- .../datasource_content_preview.py | 3 +- api/core/app/app_config/entities.py | 4 +- .../variables/manager.py | 2 - .../apps/pipeline/pipeline_config_manager.py | 1 + api/core/app/apps/pipeline/pipeline_runner.py | 7 +- api/core/workflow/entities/variable_pool.py | 8 ++- .../nodes/datasource/datasource_node.py | 1 + api/models/workflow.py | 2 - api/services/dataset_service.py | 6 +- api/services/datasource_provider_service.py | 8 +-- api/services/rag_pipeline/rag_pipeline.py | 69 ++++++++++++++++++- api/services/tools/tools_transform_service.py | 5 +- 14 files changed, 102 insertions(+), 32 deletions(-) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 42592c6c9a..35d912bfcc 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1051,11 +1051,12 @@ class DocumentPipelineExecutionLogApi(DocumentResource): .first() ) if not log: - return {"datasource_info": None, - "datasource_type": None, - "input_data": None, - "datasource_node_id": None, - }, 200 + return { + "datasource_info": None, + "datasource_type": None, + "input_data": None, + "datasource_node_id": None, + }, 200 return { "datasource_info": json.loads(log.datasource_info), "datasource_type": log.datasource_type, @@ -1086,5 +1087,6 @@ api.add_resource(DocumentRetryApi, "/datasets//retry") api.add_resource(DocumentRenameApi, "/datasets//documents//rename") api.add_resource(WebsiteDocumentSyncApi, "/datasets//documents//website-sync") -api.add_resource(DocumentPipelineExecutionLogApi, - "/datasets//documents//pipeline-execution-log") +api.add_resource( + DocumentPipelineExecutionLogApi, "/datasets//documents//pipeline-execution-log" +) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 7f7b6a7867..124d45f513 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -96,7 +96,7 @@ class DatasourceAuth(Resource): parser = reqparse.RequestParser() parser.add_argument("provider", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test") parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, 
location="json") args = parser.parse_args() diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index 32b5f68364..bb02c659b8 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -48,7 +48,8 @@ class DataSourceContentPreviewApi(Resource): ) return preview_content, 200 + api.add_resource( DataSourceContentPreviewApi, - "/rag/pipelines//workflows/published/datasource/nodes//preview" + "/rag/pipelines//workflows/published/datasource/nodes//preview", ) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index fe7c75ce96..cbb382beb3 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,4 +1,3 @@ -from ast import Str from collections.abc import Sequence from enum import Enum, StrEnum from typing import Any, Literal, Optional @@ -128,14 +127,17 @@ class VariableEntity(BaseModel): def convert_none_options(cls, v: Any) -> Sequence[str]: return v or [] + class RagPipelineVariableEntity(VariableEntity): """ Rag Pipeline Variable Entity. """ + tooltips: Optional[str] = None placeholder: Optional[str] = None belong_to_node_id: str + class ExternalDataVariableEntity(BaseModel): """ External Data Variable Entity. diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index b2530ec422..1c63874ee3 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,5 +1,3 @@ -from typing import Any - from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/pipeline/pipeline_config_manager.py b/api/core/app/apps/pipeline/pipeline_config_manager.py index f410457bc6..b83fc1800f 100644 --- a/api/core/app/apps/pipeline/pipeline_config_manager.py +++ b/api/core/app/apps/pipeline/pipeline_config_manager.py @@ -13,6 +13,7 @@ class PipelineConfig(WorkflowUIBasedAppConfig): """ Pipeline Config Entity. 
""" + rag_pipeline_variables: list[RagPipelineVariableEntity] = [] pass diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 402fd92358..52afb78ee5 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -47,6 +47,7 @@ class PipelineRunner(WorkflowBasedAppRunner): def _get_app_id(self) -> str: return self.application_generate_entity.app_config.app_id + def run(self) -> None: """ Run application @@ -114,9 +115,9 @@ class PipelineRunner(WorkflowBasedAppRunner): for v in workflow.rag_pipeline_variables: rag_pipeline_variable = RAGPipelineVariable(**v) if ( - (rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id or rag_pipeline_variable.belong_to_node_id == "shared") - and rag_pipeline_variable.variable in inputs - ): + rag_pipeline_variable.belong_to_node_id + in (self.application_generate_entity.start_node_id, "shared") + ) and rag_pipeline_variable.variable in inputs: rag_pipeline_variables.append( RAGPipelineVariableInput( variable=rag_pipeline_variable, diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 37f194e0af..3a68f45f61 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -10,8 +10,12 @@ from core.variables import Segment, SegmentGroup, Variable from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.segments import FileSegment, NoneSegment from core.variables.variables import RAGPipelineVariableInput -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, \ - SYSTEM_VARIABLE_NODE_ID, RAG_PIPELINE_VARIABLE_NODE_ID +from core.workflow.constants import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) from core.workflow.enums import SystemVariableKey from factories import variable_factory diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 77eba1a7ce..01f6f51648 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -462,6 +462,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): inputs=parameters_for_log, ) ) + @classmethod def version(cls) -> str: return "1" diff --git a/api/models/workflow.py b/api/models/workflow.py index 3c87903bb3..638885be8d 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -323,13 +323,11 @@ class Workflow(Base): return variables def rag_pipeline_user_input_form(self) -> list: - # get user_input_form from start node variables: list[Any] = self.rag_pipeline_variables return variables - @property def unique_hash(self) -> str: """ diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4a7620bd15..f7941fa49f 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -344,10 +344,10 @@ class DatasetService: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise ValueError("Dataset not found") - # check if dataset name is exists + # check if dataset name is exists if ( db.session.query(Dataset) - .filter( + .filter( Dataset.id != dataset_id, Dataset.name == data.get("name", dataset.name), Dataset.tenant_id == dataset.tenant_id, @@ -470,7 +470,7 @@ class DatasetService: filtered_data["updated_at"] = 
datetime.datetime.now(datetime.UTC).replace(tzinfo=None) # update Retrieval model filtered_data["retrieval_model"] = data["retrieval_model"] - # update icon info + # update icon info if data.get("icon_info"): filtered_data["icon_info"] = data.get("icon_info") diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index bca0081417..228c18b7c2 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -32,14 +32,10 @@ class DatasourceProviderService: :param credentials: """ # check name is exist - datasource_provider = ( - db.session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, name=name) - .first() - ) + datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=name).first() if datasource_provider: raise ValueError("Authorization name is already exists") - + credential_valid = self.provider_manager.validate_provider_credentials( tenant_id=tenant_id, user_id=current_user.id, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 6427c526d6..0370826c12 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -20,9 +20,12 @@ from core.datasource.entities.datasource_entities import ( DatasourceProviderType, GetOnlineDocumentPageContentRequest, OnlineDocumentPagesMessage, + OnlineDriveBrowseFilesRequest, + OnlineDriveBrowseFilesResponse, WebsiteCrawlMessage, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.rag.entities.event import ( BaseDatasourceEvent, @@ -31,8 +34,9 @@ from core.rag.entities.event import ( DatasourceProcessingEvent, ) from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.variables.variables import Variable +from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput, Variable from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, @@ -381,6 +385,17 @@ class RagPipelineService: # run draft workflow node start_at = time.perf_counter() + rag_pipeline_variables = [] + if draft_workflow.rag_pipeline_variables: + for v in draft_workflow.rag_pipeline_variables: + rag_pipeline_variable = RAGPipelineVariable(**v) + if rag_pipeline_variable.variable in user_inputs: + rag_pipeline_variables.append( + RAGPipelineVariableInput( + variable=rag_pipeline_variable, + value=user_inputs[rag_pipeline_variable.variable], + ) + ) workflow_node_execution = self._handle_node_run_result( getter=lambda: WorkflowEntry.single_step_run( @@ -388,6 +403,12 @@ class RagPipelineService: node_id=node_id, user_inputs=user_inputs, user_id=account.id, + variable_pool=VariablePool( + user_inputs=user_inputs, + environment_variables=draft_workflow.environment_variables, + conversation_variables=draft_workflow.conversation_variables, + rag_pipeline_variables=rag_pipeline_variables, + ), ), start_at=start_at, tenant_id=pipeline.tenant_id, @@ -413,6 +434,17 @@ class RagPipelineService: # run draft workflow node start_at = time.perf_counter() + 
rag_pipeline_variables = [] + if published_workflow.rag_pipeline_variables: + for v in published_workflow.rag_pipeline_variables: + rag_pipeline_variable = RAGPipelineVariable(**v) + if rag_pipeline_variable.variable in user_inputs: + rag_pipeline_variables.append( + RAGPipelineVariableInput( + variable=rag_pipeline_variable, + value=user_inputs[rag_pipeline_variable.variable], + ) + ) workflow_node_execution = self._handle_node_run_result( getter=lambda: WorkflowEntry.single_step_run( @@ -420,6 +452,12 @@ class RagPipelineService: node_id=node_id, user_inputs=user_inputs, user_id=account.id, + variable_pool=VariablePool( + user_inputs=user_inputs, + environment_variables=published_workflow.environment_variables, + conversation_variables=published_workflow.conversation_variables, + rag_pipeline_variables=rag_pipeline_variables, + ), ), start_at=start_at, tenant_id=pipeline.tenant_id, @@ -511,6 +549,33 @@ class RagPipelineService: except Exception as e: logger.exception("Error during online document.") yield DatasourceErrorEvent(error=str(e)).model_dump() + case DatasourceProviderType.ONLINE_DRIVE: + datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime) + online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = datasource_runtime.online_drive_browse_files( + user_id=account.id, + request=OnlineDriveBrowseFilesRequest( + bucket=user_inputs.get("bucket"), + prefix=user_inputs.get("prefix"), + max_keys=user_inputs.get("max_keys", 20), + start_after=user_inputs.get("start_after"), + ), + provider_type=datasource_runtime.datasource_provider_type(), + ) + start_time = time.time() + start_event = DatasourceProcessingEvent( + total=0, + completed=0, + ) + yield start_event.model_dump() + for message in online_drive_result: + end_time = time.time() + online_drive_event = DatasourceCompletedEvent( + data=message.result, + time_consuming=round(end_time - start_time, 2), + total=None, + completed=None, + ) + yield online_drive_event.model_dump() case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = ( @@ -631,7 +696,7 @@ class RagPipelineService: except Exception as e: logger.exception("Error during get online document content.") raise RuntimeError(str(e)) - #TODO Online Drive + # TODO Online Drive case _: raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") except Exception as e: diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 8a73c73a1b..282728153a 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -86,8 +86,9 @@ class ToolTransformService: ) else: provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider.type.value, provider_name=provider.name, - icon=provider.declaration.identity.icon + provider_type=provider.type.value, + provider_name=provider.name, + icon=provider.declaration.identity.icon, ) @classmethod From a6ff9b224b801f1521d518574f095aec669dc2a3 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 2 Jul 2025 18:20:41 +0800 Subject: [PATCH 154/155] r2 --- api/services/rag_pipeline/rag_pipeline.py | 36 ++++++++++++----------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 
From a6ff9b224b801f1521d518574f095aec669dc2a3 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Wed, 2 Jul 2025 18:20:41 +0800
Subject: [PATCH 154/155] r2

---
 api/services/rag_pipeline/rag_pipeline.py | 36 ++++++++++++-----------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 0370826c12..1f6f2308c0 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -391,7 +391,7 @@ class RagPipelineService:
                 rag_pipeline_variable = RAGPipelineVariable(**v)
                 if rag_pipeline_variable.variable in user_inputs:
                     rag_pipeline_variables.append(
-                        RAGPipelineVariableInput(
+                        RAGPipelineVariableInput(
                             variable=rag_pipeline_variable,
                             value=user_inputs[rag_pipeline_variable.variable],
                         )
@@ -437,14 +437,14 @@ class RagPipelineService:
         rag_pipeline_variables = []
         if published_workflow.rag_pipeline_variables:
             for v in published_workflow.rag_pipeline_variables:
-                rag_pipeline_variable = RAGPipelineVariable(**v)
-                if rag_pipeline_variable.variable in user_inputs:
-                    rag_pipeline_variables.append(
-                        RAGPipelineVariableInput(
-                            variable=rag_pipeline_variable,
-                            value=user_inputs[rag_pipeline_variable.variable],
-                        )
+                rag_pipeline_variable = RAGPipelineVariable(**v)
+                if rag_pipeline_variable.variable in user_inputs:
+                    rag_pipeline_variables.append(
+                        RAGPipelineVariableInput(
+                            variable=rag_pipeline_variable,
+                            value=user_inputs[rag_pipeline_variable.variable],
+                        )
+                    )
 
         workflow_node_execution = self._handle_node_run_result(
             getter=lambda: WorkflowEntry.single_step_run(
@@ -551,15 +551,17 @@ class RagPipelineService:
                     yield DatasourceErrorEvent(error=str(e)).model_dump()
             case DatasourceProviderType.ONLINE_DRIVE:
                 datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
-                online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = datasource_runtime.online_drive_browse_files(
-                    user_id=account.id,
-                    request=OnlineDriveBrowseFilesRequest(
-                        bucket=user_inputs.get("bucket"),
-                        prefix=user_inputs.get("prefix"),
-                        max_keys=user_inputs.get("max_keys", 20),
-                        start_after=user_inputs.get("start_after"),
-                    ),
-                    provider_type=datasource_runtime.datasource_provider_type(),
+                online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = (
+                    datasource_runtime.online_drive_browse_files(
+                        user_id=account.id,
+                        request=OnlineDriveBrowseFilesRequest(
+                            bucket=user_inputs.get("bucket"),
+                            prefix=user_inputs.get("prefix"),
+                            max_keys=user_inputs.get("max_keys", 20),
+                            start_after=user_inputs.get("start_after"),
+                        ),
+                        provider_type=datasource_runtime.datasource_provider_type(),
+                    )
                 )
                 start_time = time.time()
                 start_event = DatasourceProcessingEvent(
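The ONLINE_DRIVE branch reformatted above drives an S3-style listing through bucket / prefix / max_keys / start_after and streams each page back as a completed event. The generator below is a rough sketch of that paging pattern only; list_page stands in for the plugin call (the real entry point is OnlineDriveDatasourcePlugin.online_drive_browse_files, whose exact signature is defined elsewhere in the codebase), so it should be read as an assumption-laden illustration, not the shipped API.

from collections.abc import Callable, Iterator


def browse_all_files(
    list_page: Callable[..., tuple[list[dict], str | None]],  # stand-in for the plugin call
    bucket: str,
    prefix: str = "",
    max_keys: int = 20,
) -> Iterator[dict]:
    # Page through an S3-style listing: each call returns one page of files
    # plus the key to resume from, or None once the listing is exhausted.
    start_after = None
    while True:
        files, start_after = list_page(bucket=bucket, prefix=prefix, max_keys=max_keys, start_after=start_after)
        yield from files
        if start_after is None:
            break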
From 38d895ab5f13f5b3a45e06c1bf80c9a8ebd63748 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Wed, 2 Jul 2025 18:46:36 +0800
Subject: [PATCH 155/155] r2

---
 api/services/rag_pipeline/rag_pipeline.py | 200 ++++++++--------------
 1 file changed, 68 insertions(+), 132 deletions(-)

diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 1f6f2308c0..0e1fad600f 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -20,12 +20,9 @@ from core.datasource.entities.datasource_entities import (
     DatasourceProviderType,
     GetOnlineDocumentPageContentRequest,
     OnlineDocumentPagesMessage,
-    OnlineDriveBrowseFilesRequest,
-    OnlineDriveBrowseFilesResponse,
     WebsiteCrawlMessage,
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
-from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
 from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
 from core.rag.entities.event import (
     BaseDatasourceEvent,
@@ -34,9 +31,8 @@
     DatasourceProcessingEvent,
 )
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
-from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput, Variable
+from core.variables.variables import Variable
 from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import (
     WorkflowNodeExecution,
     WorkflowNodeExecutionStatus,
@@ -127,6 +123,20 @@
         )
         if not customized_template:
             raise ValueError("Customized pipeline template not found.")
+        # check template name is exist
+        template_name = template_info.name
+        if template_name:
+            template = (
+                db.session.query(PipelineCustomizedTemplate)
+                .filter(
+                    PipelineCustomizedTemplate.name == template_name,
+                    PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
+                    PipelineCustomizedTemplate.id != template_id,
+                )
+                .first()
+            )
+            if template:
+                raise ValueError("Template name is already exists")
         customized_template.name = template_info.name
         customized_template.description = template_info.description
         customized_template.icon = template_info.icon_info.model_dump()
@@ -385,17 +395,6 @@ class RagPipelineService:
         # run draft workflow node
         start_at = time.perf_counter()
 
-        rag_pipeline_variables = []
-        if draft_workflow.rag_pipeline_variables:
-            for v in draft_workflow.rag_pipeline_variables:
-                rag_pipeline_variable = RAGPipelineVariable(**v)
-                if rag_pipeline_variable.variable in user_inputs:
-                    rag_pipeline_variables.append(
-                        RAGPipelineVariableInput(
-                            variable=rag_pipeline_variable,
-                            value=user_inputs[rag_pipeline_variable.variable],
-                        )
-                    )
 
         workflow_node_execution = self._handle_node_run_result(
             getter=lambda: WorkflowEntry.single_step_run(
@@ -403,12 +402,6 @@ class RagPipelineService:
                 node_id=node_id,
                 user_inputs=user_inputs,
                 user_id=account.id,
-                variable_pool=VariablePool(
-                    user_inputs=user_inputs,
-                    environment_variables=draft_workflow.environment_variables,
-                    conversation_variables=draft_workflow.conversation_variables,
-                    rag_pipeline_variables=rag_pipeline_variables,
-                ),
             ),
             start_at=start_at,
             tenant_id=pipeline.tenant_id,
@@ -434,17 +427,6 @@ class RagPipelineService:
         # run draft workflow node
         start_at = time.perf_counter()
 
-        rag_pipeline_variables = []
-        if published_workflow.rag_pipeline_variables:
-            for v in published_workflow.rag_pipeline_variables:
-                rag_pipeline_variable = RAGPipelineVariable(**v)
-                if rag_pipeline_variable.variable in user_inputs:
-                    rag_pipeline_variables.append(
-                        RAGPipelineVariableInput(
-                            variable=rag_pipeline_variable,
-                            value=user_inputs[rag_pipeline_variable.variable],
-                        )
-                    )
 
         workflow_node_execution = self._handle_node_run_result(
             getter=lambda: WorkflowEntry.single_step_run(
@@ -452,12 +434,6 @@ class RagPipelineService:
                 node_id=node_id,
                 user_inputs=user_inputs,
                 user_id=account.id,
-                variable_pool=VariablePool(
-                    user_inputs=user_inputs,
-                    environment_variables=published_workflow.environment_variables,
-                    conversation_variables=published_workflow.conversation_variables,
-                    rag_pipeline_variables=rag_pipeline_variables,
-                ),
             ),
             start_at=start_at,
             tenant_id=pipeline.tenant_id,
@@ -549,35 +525,6 @@ class RagPipelineService:
                 except Exception as e:
                     logger.exception("Error during online document.")
                     yield DatasourceErrorEvent(error=str(e)).model_dump()
-            case DatasourceProviderType.ONLINE_DRIVE:
-                datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
-                online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = (
-                    datasource_runtime.online_drive_browse_files(
-                        user_id=account.id,
-                        request=OnlineDriveBrowseFilesRequest(
-                            bucket=user_inputs.get("bucket"),
-                            prefix=user_inputs.get("prefix"),
-                            max_keys=user_inputs.get("max_keys", 20),
-                            start_after=user_inputs.get("start_after"),
-                        ),
-                        provider_type=datasource_runtime.datasource_provider_type(),
-                    )
-                )
-                start_time = time.time()
-                start_event = DatasourceProcessingEvent(
-                    total=0,
-                    completed=0,
-                )
-                yield start_event.model_dump()
-                for message in online_drive_result:
-                    end_time = time.time()
-                    online_drive_event = DatasourceCompletedEvent(
-                        data=message.result,
-                        time_consuming=round(end_time - start_time, 2),
-                        total=None,
-                        completed=None,
-                    )
-                    yield online_drive_event.model_dump()
             case DatasourceProviderType.WEBSITE_CRAWL:
                 datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                 website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
@@ -874,77 +821,26 @@ class RagPipelineService:
         return workflow
 
-    def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
+    def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
         """
-        Get second step parameters of rag pipeline
+        Get first step parameters of rag pipeline
         """
-        workflow = self.get_published_workflow(pipeline=pipeline)
+        workflow = (
+            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
+        )
         if not workflow:
             raise ValueError("Workflow not initialized")
-        # get second step node
-        rag_pipeline_variables = workflow.rag_pipeline_variables
-        if not rag_pipeline_variables:
-            return []
-
-        # get datasource provider
-        datasource_provider_variables = [
-            item
-            for item in rag_pipeline_variables
-            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
-        ]
-        return datasource_provider_variables
-
-    def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
-        """
-        Get first step parameters of rag pipeline
-        """
-
-        published_workflow = self.get_published_workflow(pipeline=pipeline)
-        if not published_workflow:
-            raise ValueError("Workflow not initialized")
-
-        # get second step node
         datasource_node_data = None
-        datasource_nodes = published_workflow.graph_dict.get("nodes", [])
         for datasource_node in datasource_nodes:
             if datasource_node.get("id") == node_id:
                 datasource_node_data = datasource_node.get("data", {})
                 break
         if not datasource_node_data:
             raise ValueError("Datasource node data not found")
-        variables = datasource_node_data.get("variables", {})
-        if variables:
-            variables_map = {item["variable"]: item for item in variables}
-        else:
-            return []
-        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
-        user_input_variables = []
-        for key, value in datasource_parameters.items():
-            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
-                user_input_variables.append(variables_map.get(key, {}))
-        return user_input_variables
-
-    def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
-        """
-        Get first step parameters of rag pipeline
-        """
-
-        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
-        if not draft_workflow:
-            raise ValueError("Workflow not initialized")
-
-        # get second step node
-        datasource_node_data = None
-        datasource_nodes = draft_workflow.graph_dict.get("nodes", [])
         for datasource_node in datasource_nodes:
             if datasource_node.get("id") == node_id:
                 datasource_node_data = datasource_node.get("data", {})
                 break
         if not datasource_node_data:
             raise ValueError("Datasource node data not found")
-        variables = datasource_node_data.get("variables", {})
+        variables = workflow.rag_pipeline_variables
         if variables:
             variables_map = {item["variable"]: item for item in variables}
         else:
             return []
@@ -953,16 +849,23 @@ class RagPipelineService:
         datasource_parameters = datasource_node_data.get("datasource_parameters", {})
 
         user_input_variables = []
         for key, value in datasource_parameters.items():
-            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
-                user_input_variables.append(variables_map.get(key, {}))
+            if value.get("value") and isinstance(value.get("value"), str):
+                pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
+                match = re.match(pattern, value["value"])
+                if match:
+                    full_path = match.group(1)
+                    last_part = full_path.split(".")[-1]
+                    user_input_variables.append(variables_map.get(last_part, {}))
         return user_input_variables
 
-    def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
+    def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
         """
         Get second step parameters of rag pipeline
         """
-        workflow = self.get_draft_workflow(pipeline=pipeline)
+        workflow = (
+            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
+        )
         if not workflow:
             raise ValueError("Workflow not initialized")
@@ -970,11 +873,30 @@ class RagPipelineService:
         rag_pipeline_variables = workflow.rag_pipeline_variables
         if not rag_pipeline_variables:
             return []
+        variables_map = {item["variable"]: item for item in rag_pipeline_variables}
 
-        # get datasource provider
+        # get datasource node data
+        datasource_node_data = None
+        datasource_nodes = workflow.graph_dict.get("nodes", [])
+        for datasource_node in datasource_nodes:
+            if datasource_node.get("id") == node_id:
+                datasource_node_data = datasource_node.get("data", {})
+                break
+        if datasource_node_data:
+            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+
+            for key, value in datasource_parameters.items():
+                if value.get("value") and isinstance(value.get("value"), str):
+                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
+                    match = re.match(pattern, value["value"])
+                    if match:
+                        full_path = match.group(1)
+                        last_part = full_path.split(".")[-1]
+                        variables_map.pop(last_part)
+        all_second_step_variables = list(variables_map.values())
         datasource_provider_variables = [
             item
-            for item in rag_pipeline_variables
+            for item in all_second_step_variables
             if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
         ]
         return datasource_provider_variables
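The two hunks above both resolve datasource parameter values of the form {{#<node_id>.<variable>#}} down to the trailing variable name, either to surface it as a first-step input or to drop it from the second-step list. The snippet below isolates that matching step; the pattern is copied from the hunks, while the node id in the example is invented purely for illustration.

import re

# Pattern as used in the patch: {{#<node_id>.<name>[.<name>...]#}}
PATTERN = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"


def referenced_variable(value: str) -> str | None:
    # Return the trailing variable name if `value` is a variable reference,
    # otherwise None (meaning the value is a literal the user typed in).
    match = re.match(PATTERN, value)
    if match is None:
        return None
    return match.group(1).split(".")[-1]


print(referenced_variable("{{#1719829602537.bucket#}}"))  # -> "bucket" (node id made up for the example)
print(referenced_variable("my-documents"))                # -> None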
@@ -1098,6 +1020,20 @@ class RagPipelineService:
         if not dataset:
             raise ValueError("Dataset not found")
 
+        # check template name is exist
+        template_name = args.get("name")
+        if template_name:
+            template = (
+                db.session.query(PipelineCustomizedTemplate)
+                .filter(
+                    PipelineCustomizedTemplate.name == template_name,
+                    PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
+                )
+                .first()
+            )
+            if template:
+                raise ValueError("Template name is already exists")
+
         max_position = (
             db.session.query(func.max(PipelineCustomizedTemplate.position))
             .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)