diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index d00be3a573..898ed416c0 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -39,7 +39,7 @@ from libs import helper
 from libs.helper import TimestampField, uuid_value
 from libs.login import current_user, login_required
 from models.account import Account
-from models.dataset import Pipeline
+from models.dataset import Document, Pipeline
 from models.model import EndUser
 from services.errors.app import WorkflowHashNotEqualError
 from services.errors.llm import InvokeRateLimitError
diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
index 55bfdde009..c386e8f41e 100644
--- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
@@ -67,7 +67,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
     """Resource for datasource node run."""
 
     @service_api_ns.doc(shortcut="pipeline_datasource_node_run")
-    @service_api_ns.doc(description="Run a atasource node for a rag pipeline")
+    @service_api_ns.doc(description="Run a datasource node for a rag pipeline")
     @service_api_ns.doc(
         path={
             "dataset_id": "Dataset ID",
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index d970d7480c..d7641bc123 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -141,7 +141,7 @@ class KnowledgeIndexNode(Node):
         index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
         if original_document_id:
             segments = db.session.scalars(
-                select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+                select(DocumentSegment).where(DocumentSegment.document_id == original_document_id.value)
             ).all()
             if segments:
                 index_node_ids = [segment.index_node_id for segment in segments]
diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py
index da67801877..796d3ee3ae 100644
--- a/api/services/rag_pipeline/pipeline_generate_service.py
+++ b/api/services/rag_pipeline/pipeline_generate_service.py
@@ -4,7 +4,8 @@ from typing import Any, Union
 from configs import dify_config
 from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
-from models.dataset import Pipeline
+from extensions.ext_database import db
+from models.dataset import Document, Pipeline
 from models.model import Account, App, EndUser
 from models.workflow import Workflow
 from services.rag_pipeline.rag_pipeline import RagPipelineService
@@ -31,6 +32,9 @@ class PipelineGenerateService:
         """
         try:
             workflow = cls._get_workflow(pipeline, invoke_from)
+            if args.get("original_document_id"):
+                # update document status to waiting
+                cls.update_document_status(args.get("original_document_id", ""))
             return PipelineGenerator.convert_to_event_stream(
                 PipelineGenerator().generate(
                     pipeline=pipeline,
@@ -97,3 +101,15 @@ class PipelineGenerateService:
             raise ValueError("Workflow not published")
 
         return workflow
+
+    @classmethod
+    def update_document_status(cls, document_id: str):
+        """
+        Update document status to waiting
+        :param document_id: document id
+        """
+        document = db.session.query(Document).filter(Document.id == document_id).first()
+        if document:
+            document.indexing_status = "waiting"
+            db.session.add(document)
+            db.session.commit()
\ No newline at end of file
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 4f97e0f9bc..2da61c828c 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -7,6 +7,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
 from datetime import UTC, datetime
 from typing import Any, Optional, Union, cast
 from uuid import uuid4
+import uuid
 
 from flask_login import current_user
 from sqlalchemy import func, or_, select
@@ -14,6 +15,7 @@ from sqlalchemy.orm import Session, sessionmaker
 
 import contexts
 from configs import dify_config
+from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
 from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.datasource.entities.datasource_entities import (
@@ -55,14 +57,7 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.account import Account
-from models.dataset import (  # type: ignore
-    Dataset,
-    Document,
-    DocumentPipelineExecutionLog,
-    Pipeline,
-    PipelineCustomizedTemplate,
-    PipelineRecommendedPlugin,
-)
+from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin  # type: ignore
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import EndUser
 from models.workflow import (
@@ -1325,11 +1320,8 @@ class RagPipelineService:
         """
         Retry error document
         """
-        document_pipeline_excution_log = (
-            db.session.query(DocumentPipelineExecutionLog)
-            .filter(DocumentPipelineExecutionLog.document_id == document.id)
-            .first()
-        )
+        document_pipeline_excution_log = db.session.query(DocumentPipelineExecutionLog).filter(
+            DocumentPipelineExecutionLog.document_id == document.id).first()
         if not document_pipeline_excution_log:
             raise ValueError("Document pipeline execution log not found")
         pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
@@ -1348,11 +1340,11 @@ class RagPipelineService:
                 "start_node_id": document_pipeline_excution_log.datasource_node_id,
                 "datasource_type": document_pipeline_excution_log.datasource_type,
                 "datasource_info_list": [json.loads(document_pipeline_excution_log.datasource_info)],
+                "original_document_id": document.id,
             },
             invoke_from=InvokeFrom.PUBLISHED,
             streaming=False,
             call_depth=0,
             workflow_thread_pool_id=None,
             is_retry=True,
-            documents=[document],
         )
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py
index ff7848eea6..f4e9b52778 100644
--- a/api/tasks/retry_document_indexing_task.py
+++ b/api/tasks/retry_document_indexing_task.py
@@ -29,7 +29,6 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_id: str):
     Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
     """
     start_at = time.perf_counter()
-    print("sadaddadadaaaadadadadsdsadasdadasdasda")
     try:
         dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
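
Note for reviewers: a minimal sketch of how the new retry path can be exercised, assuming a Flask app context with the database initialized and an existing Document row; the demo_retry_resets_status function is illustrative only and is not part of this patch.

# Illustrative sketch, not part of the patch. Assumes a Flask app context
# with the database initialized and a seeded Document row.
from extensions.ext_database import db
from models.dataset import Document
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService

def demo_retry_resets_status(document_id: str) -> None:
    # Before the pipeline generator re-runs indexing for a retried document,
    # the service is expected to flip its status back to "waiting".
    PipelineGenerateService.update_document_status(document_id)
    document = db.session.query(Document).filter(Document.id == document_id).first()
    assert document is not None
    assert document.indexing_status == "waiting"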