From 0486aa3445fdd494a8e077651a91130c954e0e76 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Tue, 3 Jun 2025 13:30:51 +0800
Subject: [PATCH] r2

---
 .../console/datasets/datasets_document.py     |  2 +-
 .../rag_pipeline/rag_pipeline_workflow.py     |  4 +-
 .../knowledge_index/knowledge_index_node.py   | 18 +++++--
 api/services/dataset_service.py               | 51 ++++++++++---------
 .../rag_pipeline_entities.py                  | 20 ++++----
 api/services/rag_pipeline/rag_pipeline.py     | 22 ++++----
 .../rag_pipeline/rag_pipeline_dsl_service.py  | 50 +++++++++---------
 7 files changed, 85 insertions(+), 82 deletions(-)

diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index f7c04102a9..60fa1731ca 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -664,7 +664,7 @@ class DocumentDetailApi(DocumentResource):
             response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
         elif metadata == "without":
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index a8d2becb4c..fe91f01af6 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -39,8 +39,6 @@ from libs.helper import TimestampField, uuid_value
 from libs.login import current_user, login_required
 from models.account import Account
 from models.dataset import Pipeline
-from models.model import EndUser
-from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration
 from services.errors.app import WorkflowHashNotEqualError
 from services.errors.llm import InvokeRateLimitError
 from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
@@ -542,7 +540,7 @@ class RagPipelineConfigApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, pipeline_id):
-        
+
         return {
             "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
         }
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index c0db13418f..41a6c6141e 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -12,7 +12,7 @@ from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
-from models.dataset import Dataset, Document
+from models.dataset import Dataset, Document, DocumentSegment
 from models.workflow import WorkflowNodeExecutionStatus
 
 from ..base import BaseNode
@@ -61,11 +61,11 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
             )
 
-        outputs = self._get_preview_output(node_data.chunk_structure, chunks)
-        # retrieve knowledge
+        # index knowledge
        try:
             if is_preview:
+                outputs = self._get_preview_output(node_data.chunk_structure, chunks)
                 return NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
                     inputs=variables,
@@ -116,6 +116,18 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
         document.indexing_status = "completed"
         document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.add(document)
+        # update document segment status
+        db.session.query(DocumentSegment).filter(
+            DocumentSegment.document_id == document.id,
+            DocumentSegment.dataset_id == dataset.id,
+        ).update(
+            {
+                DocumentSegment.status: "completed",
+                DocumentSegment.enabled: True,
+                DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+            }
+        )
+        db.session.commit()
 
         return {
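
The hunk above completes the document's segments in bulk rather than loading and saving each row. A minimal standalone sketch of the same bulk-UPDATE pattern, assuming the `DocumentSegment` model and Flask-SQLAlchemy `db.session` shown in this patch; the helper name is hypothetical:

```python
import datetime

from extensions.ext_database import db  # assumed Flask-SQLAlchemy session, as in the patch
from models.dataset import DocumentSegment


def mark_segments_completed(document_id: str, dataset_id: str) -> int:
    """Flip all segments of one document to 'completed' with a single UPDATE statement."""
    now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    affected = (
        db.session.query(DocumentSegment)
        .filter(
            DocumentSegment.document_id == document_id,
            DocumentSegment.dataset_id == dataset_id,
        )
        .update(
            {
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: now,
            },
            synchronize_session=False,  # server-side update only; skip matching in-session objects
        )
    )
    db.session.commit()
    return affected  # number of segment rows updated
```

Compared with iterating over segments and setting attributes one by one, this emits a single `UPDATE ... WHERE` statement, which matters for documents with thousands of segments.
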
updated.") action = None - if dataset.indexing_technique != knowledge_configuration.index_method.indexing_technique: + if dataset.indexing_technique != knowledge_configuration.indexing_technique: # if update indexing_technique - if knowledge_configuration.index_method.indexing_technique == "economy": + if knowledge_configuration.indexing_technique == "economy": raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") - elif knowledge_configuration.index_method.indexing_technique == "high_quality": + elif knowledge_configuration.indexing_technique == "high_quality": action = "add" # get embedding model setting try: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=knowledge_configuration.index_method.embedding_setting.embedding_provider_name, + provider=knowledge_configuration.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=knowledge_configuration.index_method.embedding_setting.embedding_model_name, + model=knowledge_configuration.embedding_model, ) dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider @@ -567,7 +567,7 @@ class DatasetService: plugin_model_provider_str = str(ModelProviderID(plugin_model_provider)) # Handle new model provider from request - new_plugin_model_provider = knowledge_base_setting.index_method.embedding_setting.embedding_provider_name + new_plugin_model_provider = knowledge_configuration.embedding_model_provider new_plugin_model_provider_str = None if new_plugin_model_provider: new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider)) @@ -575,16 +575,16 @@ class DatasetService: # Only update embedding model if both values are provided and different from current if ( plugin_model_provider_str != new_plugin_model_provider_str - or knowledge_base_setting.index_method.embedding_setting.embedding_model_name != dataset.embedding_model + or knowledge_configuration.embedding_model != dataset.embedding_model ): action = "update" model_manager = ModelManager() try: embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name, + provider=knowledge_configuration.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name, + model=knowledge_configuration.embedding_model, ) except ProviderTokenNotInitError: # If we can't get the embedding model, skip updating it @@ -608,14 +608,14 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) elif dataset.indexing_technique == "economy": - if dataset.keyword_number != knowledge_configuration.index_method.economy_setting.keyword_number: - dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number - dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() - session.add(dataset) + if dataset.keyword_number != knowledge_configuration.keyword_number: + dataset.keyword_number = knowledge_configuration.keyword_number + dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + session.add(dataset) session.commit() if action: deal_dataset_index_update_task.delay(dataset.id, action) - + @staticmethod def delete_dataset(dataset_id, user): diff --git 
diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py
index 17416d51fd..778c394d5b 100644
--- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py
+++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py
@@ -105,18 +105,12 @@ class IndexMethod(BaseModel):
 
 class KnowledgeConfiguration(BaseModel):
     """
-    Knowledge Configuration.
+    Knowledge Base Configuration.
     """
 
-    chunk_structure: str
-    index_method: IndexMethod
-    retrieval_setting: RetrievalSetting
-
-
-class KnowledgeBaseUpdateConfiguration(BaseModel):
-    """
-    Knowledge Base Update Configuration.
-    """
-    index_method: IndexMethod
-    chunk_structure: str
-    retrieval_setting: RetrievalSetting
\ No newline at end of file
+    chunk_structure: str
+    indexing_technique: Literal["high_quality", "economy"]
+    embedding_model_provider: Optional[str] = ""
+    embedding_model: Optional[str] = ""
+    keyword_number: Optional[int] = 10
+    retrieval_model: RetrievalSetting
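
Note that `chunk_structure` must survive the flattening: both `DatasetService.update_rag_pipeline_dataset_settings` and the DSL service below read `knowledge_configuration.chunk_structure`. With the flat model, a node's configuration validates in one step. A minimal sketch, assuming pydantic v2; `RetrievalSetting.model_construct()` merely stands in for a fully populated retrieval setting, and the field values are illustrative:

```python
from pydantic import ValidationError

from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeConfiguration,
    RetrievalSetting,
)

# an "economy" configuration only needs keyword_number besides the common fields
config = KnowledgeConfiguration(
    chunk_structure="text_model",       # illustrative value
    indexing_technique="economy",
    keyword_number=10,
    retrieval_model=RetrievalSetting.model_construct(),
)

# anything outside the Literal is rejected up front
try:
    KnowledgeConfiguration(
        chunk_structure="text_model",
        indexing_technique="semantic",  # neither "high_quality" nor "economy"
        retrieval_model=RetrievalSetting.model_construct(),
    )
except ValidationError as exc:
    print(exc.error_count(), "validation error")
```
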
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index b3c32a7c78..43451528db 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -74,7 +74,7 @@ class RagPipelineService:
         result = retrieval_instance.get_pipeline_templates(language)
         return result
 
-    @classmethod    
+    @classmethod
     def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
         """
         Get pipeline template detail.
@@ -284,7 +284,7 @@ class RagPipelineService:
             graph=draft_workflow.graph,
             features=draft_workflow.features,
             created_by=account.id,
-            environment_variables=draft_workflow.environment_variables,    
+            environment_variables=draft_workflow.environment_variables,
             conversation_variables=draft_workflow.conversation_variables,
             rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
             marked_name="",
@@ -296,8 +296,8 @@ class RagPipelineService:
         graph = workflow.graph_dict
         nodes = graph.get("nodes", [])
         for node in nodes:
-            if node.get("data", {}).get("type") == "knowledge_index":
-                knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
+            if node.get("data", {}).get("type") == "knowledge-index":
+                knowledge_configuration = node.get("data", {})
                 knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
 
                 # update dataset
@@ -306,8 +306,8 @@ class RagPipelineService:
                     raise ValueError("Dataset not found")
                 DatasetService.update_rag_pipeline_dataset_settings(
                     session=session,
-                    dataset=dataset, 
-                    knowledge_configuration=knowledge_configuration, 
+                    dataset=dataset,
+                    knowledge_configuration=knowledge_configuration,
                     has_published=pipeline.is_published
                 )
         # return new workflow
@@ -771,14 +771,14 @@ class RagPipelineService:
 
         # Use the repository to get the node executions with ordering
         order_config = OrderConfig(order_by=["index"], order_direction="desc")
-        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, 
-                                                         order_config=order_config, 
+        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
+                                                         order_config=order_config,
                                                          triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
 
         # Convert domain models to database models
         workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
         return workflow_node_executions
-    
+
     @classmethod
     def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
         """
@@ -792,5 +792,5 @@ class RagPipelineService:
         workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
         if not workflow:
             raise ValueError("Workflow not found")
-        
-        db.session.commit()
\ No newline at end of file
+
+        db.session.commit()
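
The publish path above scans the workflow graph for the knowledge-index node and feeds its `data` dict straight into `KnowledgeConfiguration`. A hypothetical standalone equivalent of that scan; it relies on pydantic v2's default behavior of ignoring extra fields (the node data also carries editor fields such as `type` and `title`):

```python
from typing import Optional

from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration


def extract_knowledge_configuration(graph: dict) -> Optional[KnowledgeConfiguration]:
    """Return the configuration of the first knowledge-index node in the graph, if any."""
    for node in graph.get("nodes", []):
        data = node.get("data", {})
        if data.get("type") == "knowledge-index":
            # the node data carries the flattened configuration fields directly
            return KnowledgeConfiguration(**data)
    return None
```
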
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index 57e81e6f75..189ba0973f 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -1,10 +1,10 @@
 import base64
-from datetime import UTC, datetime
 import hashlib
 import json
 import logging
 import uuid
 from collections.abc import Mapping
+from datetime import UTC, datetime
 from enum import StrEnum
 from typing import Optional, cast
 from urllib.parse import urlparse
@@ -292,20 +292,20 @@ class RagPipelineDslService:
                     "background": icon_background,
                     "url": icon_url,
                 },
-                indexing_technique=knowledge_configuration.index_method.indexing_technique,
+                indexing_technique=knowledge_configuration.indexing_technique,
                 created_by=account.id,
-                retrieval_model=knowledge_configuration.retrieval_setting.model_dump(),
+                retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
                 runtime_mode="rag_pipeline",
                 chunk_structure=knowledge_configuration.chunk_structure,
             )
-            if knowledge_configuration.index_method.indexing_technique == "high_quality":
+            if knowledge_configuration.indexing_technique == "high_quality":
                 dataset_collection_binding = (
                     db.session.query(DatasetCollectionBinding)
                     .filter(
                         DatasetCollectionBinding.provider_name
-                        == knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
+                        == knowledge_configuration.embedding_model_provider,
                         DatasetCollectionBinding.model_name
-                        == knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                        == knowledge_configuration.embedding_model,
                         DatasetCollectionBinding.type == "dataset",
                     )
                     .order_by(DatasetCollectionBinding.created_at)
@@ -314,8 +314,8 @@ class RagPipelineDslService:
 
                 if not dataset_collection_binding:
                     dataset_collection_binding = DatasetCollectionBinding(
-                        provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
-                        model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                        provider_name=knowledge_configuration.embedding_model_provider,
+                        model_name=knowledge_configuration.embedding_model,
                         collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
                         type="dataset",
                     )
@@ -324,13 +324,13 @@ class RagPipelineDslService:
                 dataset_collection_binding_id = dataset_collection_binding.id
                 dataset.collection_binding_id = dataset_collection_binding_id
                 dataset.embedding_model = (
-                    knowledge_configuration.index_method.embedding_setting.embedding_model_name
+                    knowledge_configuration.embedding_model
                 )
                 dataset.embedding_model_provider = (
-                    knowledge_configuration.index_method.embedding_setting.embedding_provider_name
+                    knowledge_configuration.embedding_model_provider
                 )
-            elif knowledge_configuration.index_method.indexing_technique == "economy":
-                dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
+            elif knowledge_configuration.indexing_technique == "economy":
+                dataset.keyword_number = knowledge_configuration.keyword_number
             dataset.pipeline_id = pipeline.id
             self._session.add(dataset)
             self._session.commit()
@@ -426,25 +426,25 @@ class RagPipelineDslService:
                     "background": icon_background,
                     "url": icon_url,
                 },
-                indexing_technique=knowledge_configuration.index_method.indexing_technique,
+                indexing_technique=knowledge_configuration.indexing_technique,
                 created_by=account.id,
-                retrieval_model=knowledge_configuration.retrieval_setting.model_dump(),
+                retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
                 runtime_mode="rag_pipeline",
                 chunk_structure=knowledge_configuration.chunk_structure,
             )
         else:
-            dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique
-            dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
+            dataset.indexing_technique = knowledge_configuration.indexing_technique
+            dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
             dataset.runtime_mode = "rag_pipeline"
             dataset.chunk_structure = knowledge_configuration.chunk_structure
-        if knowledge_configuration.index_method.indexing_technique == "high_quality":
+        if knowledge_configuration.indexing_technique == "high_quality":
             dataset_collection_binding = (
                 db.session.query(DatasetCollectionBinding)
                 .filter(
                     DatasetCollectionBinding.provider_name
-                    == knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
+                    == knowledge_configuration.embedding_model_provider,
                     DatasetCollectionBinding.model_name
-                    == knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                    == knowledge_configuration.embedding_model,
                     DatasetCollectionBinding.type == "dataset",
                 )
                 .order_by(DatasetCollectionBinding.created_at)
@@ -453,8 +453,8 @@ class RagPipelineDslService:
 
             if not dataset_collection_binding:
                 dataset_collection_binding = DatasetCollectionBinding(
-                    provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
-                    model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                    provider_name=knowledge_configuration.embedding_model_provider,
+                    model_name=knowledge_configuration.embedding_model,
                     collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
                     type="dataset",
                 )
@@ -463,13 +463,13 @@ class RagPipelineDslService:
             dataset_collection_binding_id = dataset_collection_binding.id
             dataset.collection_binding_id = dataset_collection_binding_id
             dataset.embedding_model = (
-                knowledge_configuration.index_method.embedding_setting.embedding_model_name
+                knowledge_configuration.embedding_model
             )
             dataset.embedding_model_provider = (
-                knowledge_configuration.index_method.embedding_setting.embedding_provider_name
+                knowledge_configuration.embedding_model_provider
            )
-        elif knowledge_configuration.index_method.indexing_technique == "economy":
-            dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
+        elif knowledge_configuration.indexing_technique == "economy":
+            dataset.keyword_number = knowledge_configuration.keyword_number
         dataset.pipeline_id = pipeline.id
         self._session.add(dataset)
         self._session.commit()
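
Both the create and the update path above repeat the same get-or-create lookup for the collection binding. Factored out, the shared pattern reads as follows; this helper is hypothetical (it does not exist in the patch), but it uses only the model, fields, and calls shown in the diffs:

```python
import uuid

from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding


def get_or_create_collection_binding(provider_name: str, model_name: str) -> DatasetCollectionBinding:
    """Reuse the oldest matching binding for this embedding provider/model pair, or create one."""
    binding = (
        db.session.query(DatasetCollectionBinding)
        .filter(
            DatasetCollectionBinding.provider_name == provider_name,
            DatasetCollectionBinding.model_name == model_name,
            DatasetCollectionBinding.type == "dataset",
        )
        .order_by(DatasetCollectionBinding.created_at)
        .first()
    )
    if not binding:
        binding = DatasetCollectionBinding(
            provider_name=provider_name,
            model_name=model_name,
            collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
            type="dataset",
        )
        db.session.add(binding)
        db.session.commit()
    return binding
```

Ordering by `created_at` and taking the first row makes the lookup deterministic when several bindings exist for the same provider/model pair.
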