From 739ebf211799e15b217c06012757c493e95583ec Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Tue, 17 Jun 2025 18:24:09 +0800 Subject: [PATCH 1/2] feat(datasource): change datasource result type to event-stream --- .../rag_pipeline/rag_pipeline_workflow.py | 4 +-- api/core/rag/entities/event.py | 30 ++++++++++++++++++ .../workflow/graph_engine/entities/event.py | 7 ----- api/services/rag_pipeline/rag_pipeline.py | 31 +++++++++++-------- 4 files changed, 50 insertions(+), 22 deletions(-) create mode 100644 api/core/rag/entities/event.py diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 616803247c..7909bb9609 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -406,10 +406,10 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): args = parser.parse_args() inputs = args.get("inputs") - if inputs == None: + if inputs is None: raise ValueError("missing inputs") datasource_type = args.get("datasource_type") - if datasource_type == None: + if datasource_type is None: raise ValueError("missing datasource_type") rag_pipeline_service = RagPipelineService() diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py new file mode 100644 index 0000000000..8e644fcf85 --- /dev/null +++ b/api/core/rag/entities/event.py @@ -0,0 +1,30 @@ +from collections.abc import Mapping +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel, Field + + +class DatasourceStreamEvent(Enum): + """ + Datasource Stream event + """ + PROCESSING = "processing" + COMPLETED = "completed" + + +class BaseDatasourceEvent(BaseModel): + pass + +class DatasourceCompletedEvent(BaseDatasourceEvent): + event: str = DatasourceStreamEvent.COMPLETED.value + data: Mapping[str,Any] | list = Field(..., 
description="result") + total: Optional[int] = Field(..., description="total") + completed: Optional[int] = Field(..., description="completed") + time_consuming: Optional[float] = Field(..., description="time consuming") + +class DatasourceProcessingEvent(BaseDatasourceEvent): + event: str = DatasourceStreamEvent.PROCESSING.value + total: Optional[int] = Field(..., description="total") + completed: Optional[int] = Field(..., description="completed") + diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index fbf591eb8f..063216dd49 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -275,10 +275,3 @@ class AgentLogEvent(BaseAgentEvent): InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent -class DatasourceRunEvent(BaseModel): - status: str = Field(..., description="status") - data: Mapping[str,Any] | list = Field(..., description="result") - total: Optional[int] = Field(..., description="total") - completed: Optional[int] = Field(..., description="completed") - time_consuming: Optional[float] = Field(..., description="time consuming") - diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index a5f2135100..ccb920238d 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -21,6 +21,7 @@ from core.datasource.entities.datasource_entities import ( ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin +from core.rag.entities.event import BaseDatasourceEvent, DatasourceCompletedEvent, DatasourceProcessingEvent from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables 
import Variable from core.workflow.entities.node_entities import NodeRunResult @@ -30,7 +31,7 @@ from core.workflow.entities.workflow_node_execution import ( ) from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.graph_engine.entities.event import DatasourceRunEvent, InNodeEvent +from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event.event import RunCompletedEvent @@ -486,7 +487,7 @@ class RagPipelineService: def run_datasource_workflow_node( self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool - ) -> Generator[str, None, None]: + ) -> Generator[BaseDatasourceEvent, None, None]: """ Run published workflow datasource """ @@ -542,12 +543,11 @@ class RagPipelineService: start_time = time.time() for message in online_document_result: end_time = time.time() - online_document_event = DatasourceRunEvent( - status="completed", + online_document_event = DatasourceCompletedEvent( data=message.result, time_consuming=round(end_time - start_time, 2) ) - yield json.dumps(online_document_event.model_dump()) + yield online_document_event.model_dump() case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) @@ -559,14 +559,19 @@ class RagPipelineService: start_time = time.time() for message in website_crawl_result: end_time = time.time() - crawl_event = DatasourceRunEvent( - status=message.result.status, - data=message.result.web_info_list, - total=message.result.total, - completed=message.result.completed, - time_consuming = round(end_time - start_time, 2) - ) - yield json.dumps(crawl_event.model_dump()) + if message.result.status == "completed": + crawl_event = DatasourceCompletedEvent( + data=message.result.web_info_list, + 
total=message.result.total, completed=message.result.completed, time_consuming=round(end_time - start_time, 2) ) else: crawl_event = DatasourceProcessingEvent( total=message.result.total, completed=message.result.completed, ) yield crawl_event.model_dump() case _: raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") From cf66d111ba9b4d9946607429a9198909cdcc1a53 Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Tue, 17 Jun 2025 18:29:02 +0800 Subject: [PATCH 2/2] feat(datasource): rename stream event values to datasource_processing/datasource_completed --- api/core/rag/entities/event.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py index 8e644fcf85..4921c94557 100644 --- a/api/core/rag/entities/event.py +++ b/api/core/rag/entities/event.py @@ -9,8 +9,8 @@ class DatasourceStreamEvent(Enum): """ Datasource Stream event """ - PROCESSING = "processing" - COMPLETED = "completed" + PROCESSING = "datasource_processing" + COMPLETED = "datasource_completed" class BaseDatasourceEvent(BaseModel):