diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index f0157db0f9..5ab71d2c20 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -9,6 +9,7 @@ from collections import Counter
 from typing import Any, Literal, Optional
 
 import sqlalchemy as sa
+import yaml
 from sqlalchemy import exists, func, select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
@@ -46,6 +47,7 @@ from models.dataset import (
 )
 from models.model import UploadFile
 from models.provider_ids import ModelProviderID
+from models.workflow import Workflow
 from services.entities.knowledge_entities.knowledge_entities import (
     ChildChunkUpdateArgs,
     KnowledgeConfig,
@@ -56,6 +58,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
 from services.entities.knowledge_entities.rag_pipeline_entities import (
     KnowledgeConfiguration,
     RagPipelineDatasetCreateEntity,
+    RetrievalSetting,
 )
 from services.errors.account import NoPermissionError
 from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
@@ -64,6 +67,7 @@ from services.errors.document import DocumentIndexingError
 from services.errors.file import FileNotExistsError
 from services.external_knowledge_service import ExternalDatasetService
 from services.feature_service import FeatureModel, FeatureService
+from services.rag_pipeline.rag_pipeline import RagPipelineService
 from services.tag_service import TagService
 from services.vector_service import VectorService
 from tasks.add_document_to_index_task import add_document_to_index_task
@@ -523,12 +527,97 @@ class DatasetService:
         db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
         db.session.commit()
 
+        # update pipeline knowledge base node data
+        DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id)
+
         # Trigger vector index task if indexing technique changed
         if action:
             deal_dataset_vector_index_task.delay(dataset.id, action)
 
         return dataset
 
+    @staticmethod
+    def _update_pipeline_knowledge_base_node_data(dataset: Dataset, update_user_id: str):
+        """
+        Update pipeline knowledge base node data to match the dataset's current settings.
+ """ + if dataset.runtime_mode != "rag_pipeline": + return + + pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first() + if not pipeline: + return + + try: + rag_pipeline_service = RagPipelineService() + published_workflow = rag_pipeline_service.get_published_workflow(pipeline) + draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline) + + # update knowledge nodes + def update_knowledge_nodes(workflow_graph: str) -> str: + """Update knowledge-index nodes in workflow graph.""" + data: dict[str, Any] = json.loads(workflow_graph) + + nodes = data.get("nodes", []) + updated = False + + for node in nodes: + if node.get("data", {}).get("type") == "knowledge-index": + try: + knowledge_index_node_data = node.get("data", {}) + knowledge_index_node_data["embedding_model"] = dataset.embedding_model + knowledge_index_node_data["embedding_model_provider"] = dataset.embedding_model_provider + knowledge_index_node_data["retrieval_model"] = dataset.retrieval_model + knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure + knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue] + knowledge_index_node_data["keyword_number"] = dataset.keyword_number + node["data"] = knowledge_index_node_data + updated = True + except Exception: + logging.exception("Failed to update knowledge node") + continue + + if updated: + data["nodes"] = nodes + return json.dumps(data) + return workflow_graph + + # Update published workflow + if published_workflow: + updated_graph = update_knowledge_nodes(published_workflow.graph) + if updated_graph != published_workflow.graph: + # Create new workflow version + workflow = Workflow.new( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + type=published_workflow.type, + version=str(datetime.datetime.now(datetime.UTC).replace(tzinfo=None)), + graph=updated_graph, + features=published_workflow.features, + created_by=updata_user_id, + environment_variables=published_workflow.environment_variables, + conversation_variables=published_workflow.conversation_variables, + rag_pipeline_variables=published_workflow.rag_pipeline_variables, + marked_name="", + marked_comment="", + ) + db.session.add(workflow) + + # Update draft workflow + if draft_workflow: + updated_graph = update_knowledge_nodes(draft_workflow.graph) + if updated_graph != draft_workflow.graph: + draft_workflow.graph = updated_graph + db.session.add(draft_workflow) + + # Commit all changes in one transaction + db.session.commit() + + except Exception: + logging.exception("Failed to update pipeline knowledge base node data") + db.session.rollback() + raise + @staticmethod def _handle_indexing_technique_change(dataset, data, filtered_data): """ diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index e28aa02593..0b43404b3d 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -65,7 +65,6 @@ from models.workflow import ( WorkflowType, ) from repositories.factory import DifyAPIRepositoryFactory -from services.dataset_service import DatasetService from services.datasource_provider_service import DatasourceProviderService from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, @@ -346,6 +345,8 @@ class RagPipelineService: graph = workflow.graph_dict nodes = graph.get("nodes", []) + from services.dataset_service import DatasetService + for node in nodes: if node.get("data", 
{}).get("type") == "knowledge-index": knowledge_configuration = node.get("data", {})