diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 49a6ea7b5f..4bce64e0a1 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -123,7 +123,8 @@ class DocumentAddByTextApi(DatasetApiResource):
                 args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
             )

-        upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name))
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id)
         data_source = {
             "type": "upload_file",
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@@ -201,7 +202,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
         name = args.get("name")
         if text is None or name is None:
             raise ValueError("Both text and name must be strings.")
-        upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name))
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id)
         data_source = {
             "type": "upload_file",
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index fdfceeb148..8751197767 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -48,7 +48,6 @@ from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFro
 from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
-from services.dataset_service import DocumentService
 from services.datasource_provider_service import DatasourceProviderService
 from services.feature_service import FeatureService
 from services.file_service import FileService
@@ -72,6 +71,7 @@ class PipelineGenerator(BaseAppGenerator):
         streaming: Literal[True],
         call_depth: int,
         workflow_thread_pool_id: Optional[str],
+        is_retry: bool = False,
     ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...

     @overload
@@ -86,6 +86,7 @@ class PipelineGenerator(BaseAppGenerator):
         streaming: Literal[False],
         call_depth: int,
         workflow_thread_pool_id: Optional[str],
+        is_retry: bool = False,
     ) -> Mapping[str, Any]: ...

     @overload
@@ -100,6 +101,7 @@ class PipelineGenerator(BaseAppGenerator):
         streaming: bool,
         call_depth: int,
         workflow_thread_pool_id: Optional[str],
+        is_retry: bool = False,
     ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

     def generate(
@@ -113,6 +115,7 @@ class PipelineGenerator(BaseAppGenerator):
         streaming: bool = True,
         call_depth: int = 0,
         workflow_thread_pool_id: Optional[str] = None,
+        is_retry: bool = False,
     ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:

         # Add null check for dataset
@@ -132,7 +135,8 @@ class PipelineGenerator(BaseAppGenerator):
             pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
         )
         documents = []
-        if invoke_from == InvokeFrom.PUBLISHED:
+        if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
+            from services.dataset_service import DocumentService
             for datasource_info in datasource_info_list:
                 position = DocumentService.get_documents_position(dataset.id)
                 document = self._build_document(
@@ -156,7 +160,7 @@ class PipelineGenerator(BaseAppGenerator):
         for i, datasource_info in enumerate(datasource_info_list):
             workflow_run_id = str(uuid.uuid4())
             document_id = None
-            if invoke_from == InvokeFrom.PUBLISHED:
+            if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
                 document_id = documents[i].id
                 document_pipeline_execution_log = DocumentPipelineExecutionLog(
                     document_id=document_id,
@@ -246,7 +250,7 @@ class PipelineGenerator(BaseAppGenerator):
             name = "rag_pipeline_invoke_entities.json"
             # Convert list to proper JSON string
             json_text = json.dumps(text)
-            upload_file = FileService(db.engine).upload_text(json_text, name)
+            upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
             features = FeatureService.get_features(dataset.tenant_id)
             if features.billing.subscription.plan == "sandbox":
                 tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 5ab71d2c20..03757fe4a5 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -543,24 +543,24 @@
         """
         if dataset.runtime_mode != "rag_pipeline":
             return
-        
+
         pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
         if not pipeline:
             return
-        
+
         try:
             rag_pipeline_service = RagPipelineService()
             published_workflow = rag_pipeline_service.get_published_workflow(pipeline)
             draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline)
-            
+
             # update knowledge nodes
             def update_knowledge_nodes(workflow_graph: str) -> str:
                 """Update knowledge-index nodes in workflow graph."""
                 data: dict[str, Any] = json.loads(workflow_graph)
-                
+
                 nodes = data.get("nodes", [])
                 updated = False
-                
+
                 for node in nodes:
                     if node.get("data", {}).get("type") == "knowledge-index":
                         try:
@@ -576,12 +576,12 @@ class DatasetService:
                         except Exception:
                             logging.exception("Failed to update knowledge node")
                             continue
-                
+
                 if updated:
                     data["nodes"] = nodes
                     return json.dumps(data)
                 return workflow_graph
-            
+
             # Update published workflow
             if published_workflow:
                 updated_graph = update_knowledge_nodes(published_workflow.graph)
@@ -602,17 +602,17 @@
                     marked_comment="",
                 )
                 db.session.add(workflow)
-            
+
             # Update draft workflow
             if draft_workflow:
                 updated_graph = update_knowledge_nodes(draft_workflow.graph)
                 if updated_graph != draft_workflow.graph:
                     draft_workflow.graph = updated_graph
                     db.session.add(draft_workflow)
-            
+
             # Commit all changes in one transaction
             db.session.commit()
-            
+
         except Exception:
             logging.exception("Failed to update pipeline knowledge base node data")
             db.session.rollback()
@@ -1360,7 +1360,7 @@ class DocumentService:
                 redis_client.setex(retry_indexing_cache_key, 600, 1)
         # trigger async task
         document_ids = [document.id for document in documents]
-        retry_document_indexing_task.delay(dataset_id, document_ids)
+        retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)

     @staticmethod
     def sync_website_document(dataset_id: str, document: Document):
diff --git a/api/services/file_service.py b/api/services/file_service.py
index 894b485cce..f9d4eb5686 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -120,33 +120,31 @@ class FileService:

         return file_size <= file_size_limit

-    def upload_text(self, text: str, text_name: str) -> UploadFile:
-        assert isinstance(current_user, Account)
-        assert current_user.current_tenant_id is not None
+    def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
         if len(text_name) > 200:
             text_name = text_name[:200]

         # user uuid as file name
         file_uuid = str(uuid.uuid4())
-        file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"
+        file_key = "upload_files/" + tenant_id + "/" + file_uuid + ".txt"

         # save file to storage
         storage.save(file_key, text.encode("utf-8"))

         # save file to db
         upload_file = UploadFile(
-            tenant_id=current_user.current_tenant_id,
+            tenant_id=tenant_id,
             storage_type=dify_config.STORAGE_TYPE,
             key=file_key,
             name=text_name,
             size=len(text),
             extension="txt",
             mime_type="text/plain",
-            created_by=current_user.id,
+            created_by=user_id,
             created_by_role=CreatorUserRole.ACCOUNT,
             created_at=naive_utc_now(),
             used=True,
-            used_by=current_user.id,
+            used_by=user_id,
             used_at=naive_utc_now(),
         )

diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 0b43404b3d..a9aca31439 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -5,8 +5,9 @@ import threading
 import time
 from collections.abc import Callable, Generator, Mapping, Sequence
 from datetime import UTC, datetime
-from typing import Any, Optional, cast
+from typing import Any, Optional, Union, cast
 from uuid import uuid4
+import uuid

 from flask_login import current_user
 from sqlalchemy import func, or_, select
@@ -14,6 +15,8 @@ from sqlalchemy.orm import Session, sessionmaker

 import contexts
 from configs import dify_config
+from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
+from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.datasource.entities.datasource_entities import (
     DatasourceMessage,
@@ -54,7 +57,7 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.account import Account
-from models.dataset import Document, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin  # type: ignore
+from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin  # type: ignore
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import EndUser
 from models.workflow import (
@@ -1312,3 +1315,35 @@ class RagPipelineService:
             "installed_recommended_plugins": installed_plugin_list,
             "uninstalled_recommended_plugins": uninstalled_plugin_list,
         }
+
+    def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]):
+        """
+        Retry indexing for a document whose pipeline run ended in error.
+        """
+        document_pipeline_execution_log = db.session.query(DocumentPipelineExecutionLog).filter(
+            DocumentPipelineExecutionLog.document_id == document.id).first()
+        if not document_pipeline_execution_log:
+            raise ValueError("Document pipeline execution log not found")
+        pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
+        if not pipeline:
+            raise ValueError("Pipeline not found")
+        # re-run the published workflow for this document
+        workflow = self.get_published_workflow(pipeline)
+        if not workflow:
+            raise ValueError("Workflow not found")
+        PipelineGenerator().generate(
+            pipeline=pipeline,
+            workflow=workflow,
+            user=user,
+            args={
+                "inputs": document_pipeline_execution_log.input_data,
+                "start_node_id": document_pipeline_execution_log.datasource_node_id,
+                "datasource_type": document_pipeline_execution_log.datasource_type,
+                "datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)],
+            },
+            invoke_from=InvokeFrom.PUBLISHED,
+            streaming=False,
+            call_depth=0,
+            workflow_thread_pool_id=None,
+            is_retry=True,
+        )
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py
index c52218caae..1899f93ff7 100644
--- a/api/tasks/retry_document_indexing_task.py
+++ b/api/tasks/retry_document_indexing_task.py
@@ -9,32 +9,43 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs.datetime_utils import naive_utc_now
+from models.account import Account, Tenant
 from models.dataset import Dataset, Document, DocumentSegment
 from services.feature_service import FeatureService
+from services.rag_pipeline.rag_pipeline import RagPipelineService

 logger = logging.getLogger(__name__)


 @shared_task(queue="dataset")
-def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
+def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_id: str):
     """
     Async process document
     :param dataset_id:
     :param document_ids:
+    :param user_id:

-    Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
+    Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
     """
     start_at = time.perf_counter()
     try:
         dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
         if not dataset:
             logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
             return
-        tenant_id = dataset.tenant_id
+        user = db.session.query(Account).where(Account.id == user_id).first()
+        if not user:
+            logger.info(click.style(f"User not found: {user_id}", fg="red"))
+            return
+        tenant = db.session.query(Tenant).filter(Tenant.id == dataset.tenant_id).first()
+        if not tenant:
+            raise ValueError("Tenant not found")
+        user.current_tenant = tenant
+
         for document_id in document_ids:
             retry_indexing_cache_key = f"document_{document_id}_is_retried"
             # check document limit
-            features = FeatureService.get_features(tenant_id)
+            features = FeatureService.get_features(tenant.id)
             try:
                 if features.billing.enabled:
                     vector_space = features.vector_space
@@ -84,8 +95,12 @@
                 db.session.add(document)
                 db.session.commit()

-                indexing_runner = IndexingRunner()
-                indexing_runner.run([document])
+                if dataset.runtime_mode == "rag_pipeline":
+                    rag_pipeline_service = RagPipelineService()
+                    rag_pipeline_service.retry_error_document(dataset, document, user)
+                else:
+                    indexing_runner = IndexingRunner()
+                    indexing_runner.run([document])
                 redis_client.delete(retry_indexing_cache_key)
             except Exception as ex:
                 document.indexing_status = "error"
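
A minimal caller-side sketch of the reworked entry points above; the wrapper names (`_upload_note_text`, `_schedule_retry`) are illustrative only and not part of this patch. `FileService.upload_text` no longer resolves the actor from `flask_login.current_user`, so callers pass `user_id` and `tenant_id` explicitly, and `retry_document_indexing_task` now takes the acting user's id so the Celery worker can rebuild the tenant context and route `rag_pipeline` datasets through the published pipeline retry path.

```python
# Illustrative sketch only: shows the new call shapes introduced by this diff.
from extensions.ext_database import db
from services.file_service import FileService
from tasks.retry_document_indexing_task import retry_document_indexing_task


def _upload_note_text(text: str, name: str, user_id: str, tenant_id: str):
    # upload_text now takes the acting user and tenant explicitly instead of
    # reading flask_login.current_user inside the service.
    return FileService(db.engine).upload_text(
        text=text, text_name=name, user_id=user_id, tenant_id=tenant_id
    )


def _schedule_retry(dataset_id: str, document_ids: list[str], user_id: str) -> None:
    # The task signature gained user_id; the worker loads the Account, binds it
    # to the dataset's tenant, and retries rag_pipeline datasets through
    # RagPipelineService.retry_error_document(...) instead of IndexingRunner.
    retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
```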