diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
index ebbd7d317b..34faa4ec85 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
@@ -1,5 +1,6 @@
 from flask_login import current_user  # type: ignore  # type: ignore
 from flask_restx import Resource, marshal, reqparse  # type: ignore
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
 
 import services
@@ -10,6 +11,7 @@ from controllers.console.wraps import (
     cloud_edition_billing_rate_limit_check,
     setup_required,
 )
+from extensions.ext_database import db
 from fields.dataset_fields import dataset_detail_fields
 from libs.login import login_required
 from models.dataset import DatasetPermissionEnum
@@ -64,10 +66,12 @@ class CreateRagPipelineDatasetApi(Resource):
             yaml_content=args["yaml_content"],
         )
         try:
-            import_info = RagPipelineDslService.create_rag_pipeline_dataset(
-                tenant_id=current_user.current_tenant_id,
-                rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
-            )
+            with Session(db.engine) as session:
+                rag_pipeline_dsl_service = RagPipelineDslService(session)
+                import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
+                    tenant_id=current_user.current_tenant_id,
+                    rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
+                )
             if rag_pipeline_dataset_create_entity.permission == "partial_members":
                 DatasetPermissionService.update_partial_member_list(
                     current_user.current_tenant_id,
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index 99c261bd89..5a9921aca7 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -110,9 +110,11 @@ class PipelineGenerator(BaseAppGenerator):
         workflow_thread_pool_id: Optional[str] = None,
     ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
         # Add null check for dataset
-        dataset = pipeline.dataset
-        if not dataset:
-            raise ValueError("Pipeline dataset is required")
+
+        with Session(db.engine) as session:
+            dataset = pipeline.retrieve_dataset(session)
+            if not dataset:
+                raise ValueError("Pipeline dataset is required")
         inputs: Mapping[str, Any] = args["inputs"]
         start_node_id: str = args["start_node_id"]
         datasource_type: str = args["datasource_type"]
@@ -360,9 +362,10 @@ class PipelineGenerator(BaseAppGenerator):
             pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
         )
 
-        dataset = pipeline.dataset
-        if not dataset:
-            raise ValueError("Pipeline dataset is required")
+        with Session(db.engine) as session:
+            dataset = pipeline.retrieve_dataset(session)
+            if not dataset:
+                raise ValueError("Pipeline dataset is required")
 
         # init application generate entity - use RagPipelineGenerateEntity instead
         application_generate_entity = RagPipelineGenerateEntity(
@@ -446,9 +449,10 @@ class PipelineGenerator(BaseAppGenerator):
         if args.get("inputs") is None:
             raise ValueError("inputs is required")
 
-        dataset = pipeline.dataset
-        if not dataset:
-            raise ValueError("Pipeline dataset is required")
+        with Session(db.engine) as session:
+            dataset = pipeline.retrieve_dataset(session)
+            if not dataset:
+                raise ValueError("Pipeline dataset is required")
 
         # convert to app config
         pipeline_config = PipelineConfigManager.get_pipeline_config(
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 0f38e0ff0a..89da8ac8ce 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -15,7 +15,7 @@ from typing import Any, Optional, cast
 import sqlalchemy as sa
 from sqlalchemy import DateTime, String, func, select
 from sqlalchemy.dialects.postgresql import JSONB
-from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy.orm import Mapped, Session, mapped_column
 
 from configs import dify_config
 from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
@@ -1286,9 +1286,8 @@ class Pipeline(Base):  # type: ignore[name-defined]
     updated_by = db.Column(StringUUID, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
-    @property
-    def dataset(self):
-        return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
+    def retrieve_dataset(self, session: Session):
+        return session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
 
 
 class DocumentPipelineExecutionLog(Base):
@@ -1308,6 +1307,7 @@ class DocumentPipelineExecutionLog(Base):
     created_by = db.Column(StringUUID, nullable=True)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
+
 class PipelineRecommendedPlugin(Base):
     __tablename__ = "pipeline_recommended_plugins"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
@@ -1318,4 +1318,4 @@ class PipelineRecommendedPlugin(Base):
     position = db.Column(db.Integer, nullable=False, default=0)
     active = db.Column(db.Boolean, nullable=False, default=True)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
\ No newline at end of file
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index b9d1207dab..500bdafbcf 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -352,9 +352,10 @@ class RagPipelineService:
             knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
 
             # update dataset
-            dataset = pipeline.dataset
-            if not dataset:
-                raise ValueError("Dataset not found")
+            with Session(db.engine) as session:
+                dataset = pipeline.retrieve_dataset(session=session)
+                if not dataset:
+                    raise ValueError("Dataset not found")
             DatasetService.update_rag_pipeline_dataset_settings(
                 session=session,
                 dataset=dataset,
@@ -1110,9 +1111,10 @@ class RagPipelineService:
         workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
         if not workflow:
             raise ValueError("Workflow not found")
-        dataset = pipeline.dataset
-        if not dataset:
-            raise ValueError("Dataset not found")
+        with Session(db.engine) as session:
+            dataset = pipeline.retrieve_dataset(session=session)
+            if not dataset:
+                raise ValueError("Dataset not found")
 
         # check template name is exist
         template_name = args.get("name")
@@ -1136,7 +1138,9 @@ class RagPipelineService:
 
         from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
 
-        dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
+        with Session(db.engine) as session:
+            rag_pipeline_dsl_service = RagPipelineDslService(session)
+            dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
 
         pipeline_customized_template = PipelineCustomizedTemplate(
             name=args.get("name"),
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index ce4ccd96a7..bfbb92a160 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -30,7 +30,6 @@ from core.workflow.nodes.llm.entities import LLMNodeData
 from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
 from core.workflow.nodes.tool.entities import ToolNodeData
-from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from factories import variable_factory
 from models import Account
@@ -235,10 +234,7 @@ class RagPipelineDslService:
                     status=ImportStatus.FAILED,
                     error="Pipeline not found",
                 )
-            dataset = pipeline.dataset
-            if dataset:
-                self._session.merge(dataset)
-                dataset_name = dataset.name
+            dataset = pipeline.retrieve_dataset(session=self._session)
 
             # If major version mismatch, store import info in Redis
             if status == ImportStatus.PENDING:
@@ -300,7 +296,7 @@ class RagPipelineDslService:
             ):
                 raise ValueError("Chunk structure is not compatible with the published pipeline")
         if not dataset:
-            datasets = db.session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
+            datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
             names = [dataset.name for dataset in datasets]
             generate_name = generate_incremental_name(names, name)
             dataset = Dataset(
@@ -321,7 +317,7 @@ class RagPipelineDslService:
             )
             if knowledge_configuration.indexing_technique == "high_quality":
                 dataset_collection_binding = (
-                    db.session.query(DatasetCollectionBinding)
+                    self._session.query(DatasetCollectionBinding)
                     .filter(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
@@ -339,8 +335,8 @@ class RagPipelineDslService:
                         collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
                         type="dataset",
                     )
-                    db.session.add(dataset_collection_binding)
-                    db.session.commit()
+                    self._session.add(dataset_collection_binding)
+                    self._session.commit()
                 dataset_collection_binding_id = dataset_collection_binding.id
                 dataset.collection_binding_id = dataset_collection_binding_id
                 dataset.embedding_model = knowledge_configuration.embedding_model
@@ -454,7 +450,7 @@ class RagPipelineDslService:
                 dataset.chunk_structure = knowledge_configuration.chunk_structure
             if knowledge_configuration.indexing_technique == "high_quality":
                 dataset_collection_binding = (
-                    db.session.query(DatasetCollectionBinding)
+                    self._session.query(DatasetCollectionBinding)
                     .filter(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
@@ -472,8 +468,8 @@ class RagPipelineDslService:
                         collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
                         type="dataset",
                     )
-                    db.session.add(dataset_collection_binding)
-                    db.session.commit()
+                    self._session.add(dataset_collection_binding)
+                    self._session.commit()
                 dataset_collection_binding_id = dataset_collection_binding.id
                 dataset.collection_binding_id = dataset_collection_binding_id
                 dataset.embedding_model = knowledge_configuration.embedding_model
@@ -538,18 +534,10 @@ class RagPipelineDslService:
         account: Account,
         dependencies: Optional[list[PluginDependency]] = None,
     ) -> Pipeline:
+        """Create a new app or update an existing one."""
        if not account.current_tenant_id:
            raise ValueError("Tenant id is required")
required") - """Create a new app or update an existing one.""" pipeline_data = data.get("rag_pipeline", {}) - # Set icon type - icon_type_value = pipeline_data.get("icon_type") - if icon_type_value in ["emoji", "link"]: - icon_type = icon_type_value - else: - icon_type = "emoji" - icon = str(pipeline_data.get("icon", "")) - # Initialize pipeline based on mode workflow_data = data.get("workflow") if not workflow_data or not isinstance(workflow_data, dict): @@ -609,7 +597,7 @@ class RagPipelineDslService: CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(), ) workflow = ( - db.session.query(Workflow) + self._session.query(Workflow) .filter( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, @@ -632,8 +620,8 @@ class RagPipelineDslService: conversation_variables=conversation_variables, rag_pipeline_variables=rag_pipeline_variables_list, ) - db.session.add(workflow) - db.session.flush() + self._session.add(workflow) + self._session.flush() pipeline.workflow_id = workflow.id else: workflow.graph = json.dumps(graph) @@ -643,19 +631,18 @@ class RagPipelineDslService: workflow.conversation_variables = conversation_variables workflow.rag_pipeline_variables = rag_pipeline_variables_list # commit db session changes - db.session.commit() + self._session.commit() return pipeline - @classmethod - def export_rag_pipeline_dsl(cls, pipeline: Pipeline, include_secret: bool = False) -> str: + def export_rag_pipeline_dsl(self, pipeline: Pipeline, include_secret: bool = False) -> str: """ Export pipeline :param pipeline: Pipeline instance :param include_secret: Whether include secret variable :return: """ - dataset = pipeline.dataset + dataset = pipeline.retrieve_dataset(session=self._session) if not dataset: raise ValueError("Missing dataset for rag pipeline") icon_info = dataset.icon_info @@ -672,12 +659,11 @@ class RagPipelineDslService: }, } - cls._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret) + self._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret) return yaml.dump(export_data, allow_unicode=True) # type: ignore - @classmethod - def _append_workflow_export_data(cls, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None: + def _append_workflow_export_data(self, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None: """ Append workflow export data :param export_data: export data @@ -685,7 +671,7 @@ class RagPipelineDslService: """ workflow = ( - db.session.query(Workflow) + self._session.query(Workflow) .filter( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, @@ -701,11 +687,11 @@ class RagPipelineDslService: if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: dataset_ids = node["data"].get("dataset_ids", []) node["data"]["dataset_ids"] = [ - cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id) + self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id) for dataset_id in dataset_ids ] export_data["workflow"] = workflow_dict - dependencies = cls._extract_dependencies_from_workflow(workflow) + dependencies = self._extract_dependencies_from_workflow(workflow) export_data["dependencies"] = [ jsonable_encoder(d.model_dump()) for d in DependenciesAnalysisService.generate_dependencies( @@ -713,19 +699,17 @@ class RagPipelineDslService: ) ] - @classmethod - def _extract_dependencies_from_workflow(cls, 
+    def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]:
         """
         Extract dependencies from workflow
         :param workflow: Workflow instance
         :return: dependencies list format like ["langgenius/google"]
         """
         graph = workflow.graph_dict
-        dependencies = cls._extract_dependencies_from_workflow_graph(graph)
+        dependencies = self._extract_dependencies_from_workflow_graph(graph)
         return dependencies
 
-    @classmethod
-    def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]:
+    def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]:
         """
         Extract dependencies from workflow graph
         :param graph: Workflow graph
@@ -882,25 +866,22 @@ class RagPipelineDslService:
 
         return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
 
-    @staticmethod
-    def _generate_aes_key(tenant_id: str) -> bytes:
+    def _generate_aes_key(self, tenant_id: str) -> bytes:
         """Generate AES key based on tenant_id"""
         return hashlib.sha256(tenant_id.encode()).digest()
 
-    @classmethod
-    def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
+    def encrypt_dataset_id(self, dataset_id: str, tenant_id: str) -> str:
         """Encrypt dataset_id using AES-CBC mode"""
-        key = cls._generate_aes_key(tenant_id)
+        key = self._generate_aes_key(tenant_id)
         iv = key[:16]
         cipher = AES.new(key, AES.MODE_CBC, iv)
         ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
         return base64.b64encode(ct_bytes).decode()
 
-    @classmethod
-    def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
+    def decrypt_dataset_id(self, encrypted_data: str, tenant_id: str) -> str | None:
         """AES decryption"""
         try:
-            key = cls._generate_aes_key(tenant_id)
+            key = self._generate_aes_key(tenant_id)
             iv = key[:16]
             cipher = AES.new(key, AES.MODE_CBC, iv)
             pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
@@ -908,39 +889,37 @@ class RagPipelineDslService:
         except Exception:
             return None
 
-    @staticmethod
     def create_rag_pipeline_dataset(
+        self,
         tenant_id: str,
         rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
     ):
         if rag_pipeline_dataset_create_entity.name:
             # check if dataset name already exists
             if (
-                db.session.query(Dataset)
+                self._session.query(Dataset)
                 .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
                 .first()
             ):
                 raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
         else:
             # generate a random name as Untitled 1 2 3 ...
-            datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
+            datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all()
             names = [dataset.name for dataset in datasets]
             rag_pipeline_dataset_create_entity.name = generate_incremental_name(
                 names,
                 "Untitled",
             )
 
-        with Session(db.engine) as session:
-            rag_pipeline_dsl_service = RagPipelineDslService(session)
-            account = cast(Account, current_user)
-            rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
-                account=account,
-                import_mode=ImportMode.YAML_CONTENT.value,
-                yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
-                dataset=None,
-                dataset_name=rag_pipeline_dataset_create_entity.name,
-                icon_info=rag_pipeline_dataset_create_entity.icon_info,
-            )
+        account = cast(Account, current_user)
+        rag_pipeline_import_info: RagPipelineImportInfo = self.import_rag_pipeline(
+            account=account,
+            import_mode=ImportMode.YAML_CONTENT.value,
+            yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
+            dataset=None,
+            dataset_name=rag_pipeline_dataset_create_entity.name,
+            icon_info=rag_pipeline_dataset_create_entity.icon_info,
+        )
         return {
             "id": rag_pipeline_import_info.id,
             "dataset_id": rag_pipeline_import_info.dataset_id,