diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 23d402f914..93976bd6f5 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -1,6 +1,5 @@ import logging -import yaml from flask import request from flask_restful import Resource, reqparse from sqlalchemy.orm import Session diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 8d68c80c81..dd65c85cbc 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -15,7 +15,7 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolLabelEnum, ToolInvokeMessage +from core.tools.entities.tool_entities import ToolLabelEnum class DatasourceProviderType(enum.StrEnum): @@ -290,40 +290,13 @@ class WebSiteInfo(BaseModel): """ Website info """ - job_id: str = Field(..., description="The job id") - status: str = Field(..., description="The status of the job") + status: Optional[str] = Field(..., description="crawl job status") web_info_list: Optional[list[WebSiteInfoDetail]] = [] + total: Optional[int] = Field(default=0, description="The total number of websites") + completed: Optional[int] = Field(default=0, description="The number of completed websites") - -class GetWebsiteCrawlResponse(BaseModel): +class WebsiteCrawlMessage(BaseModel): """ Get website crawl response """ - - result: WebSiteInfo = WebSiteInfo(job_id="", status="", web_info_list=[]) - - -class DatasourceInvokeMessage(ToolInvokeMessage): - """ - Datasource Invoke Message. - """ - - class WebsiteCrawlMessage(BaseModel): - """ - Website crawl message - """ - - job_id: str = Field(..., description="The job id") - status: str = Field(..., description="The status of the job") - web_info_list: Optional[list[WebSiteInfoDetail]] = [] - - class OnlineDocumentMessage(BaseModel): - """ - Online document message - """ - - workspace_id: str = Field(..., description="The workspace id") - workspace_name: str = Field(..., description="The workspace name") - workspace_icon: str = Field(..., description="The workspace icon") - total: int = Field(..., description="The total number of documents") - pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") \ No newline at end of file + result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0) diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index e8256b3282..d0e442f31a 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Generator, Mapping from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin @@ -6,7 +6,7 @@ from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, DatasourceProviderType, - GetWebsiteCrawlResponse, + WebsiteCrawlMessage, ) from core.plugin.impl.datasource import PluginDatasourceManager @@ -31,12 +31,12 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - def _get_website_crawl( + def get_website_crawl( self, user_id: str, datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> GetWebsiteCrawlResponse: + ) -> Generator[WebsiteCrawlMessage, None, None]: manager = PluginDatasourceManager() return manager.get_website_crawl( diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index a65efb750e..0567f1a480 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -1,4 +1,3 @@ -from core.datasource.__base import datasource_provider from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 83b1a5760b..54325a545f 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,12 +1,12 @@ -from collections.abc import Mapping -from typing import Any, Generator +from collections.abc import Generator, Mapping +from typing import Any from core.datasource.entities.datasource_entities import ( DatasourceInvokeMessage, GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, GetOnlineDocumentPagesResponse, - GetWebsiteCrawlResponse, + WebsiteCrawlMessage, ) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import ( @@ -94,17 +94,17 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[WebsiteCrawlMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ datasource_provider_id = GenericProviderID(datasource_provider) - response = self._request_with_plugin_daemon_response_stream( + return self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl", - DatasourceInvokeMessage, + WebsiteCrawlMessage, data={ "user_id": user_id, "data": { @@ -119,7 +119,6 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - yield from response def get_online_document_pages( self, diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index c63d837106..49c8ec1e69 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,7 +1,7 @@ import datetime import logging -from collections.abc import Mapping import time +from collections.abc import Mapping from typing import Any, cast from sqlalchemy import func diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index df9fea805c..43b68b3b97 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -17,11 +17,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( DatasourceProviderType, GetOnlineDocumentPagesResponse, - GetWebsiteCrawlResponse, + WebsiteCrawlMessage, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin -from core.model_runtime.utils.encoders import jsonable_encoder from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable from core.workflow.entities.node_entities import NodeRunResult @@ -43,14 +42,14 @@ from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore -from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.enums import WorkflowRunTriggeredFrom from models.model import EndUser -from models.oauth import DatasourceProvider from models.workflow import ( Workflow, + WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, - WorkflowType, WorkflowNodeExecutionModel, + WorkflowType, ) from services.dataset_service import DatasetService from services.datasource_provider_service import DatasourceProviderService @@ -468,15 +467,16 @@ class RagPipelineService: case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + website_crawl_results: list[WebsiteCrawlMessage] = [] + for website_message in datasource_runtime.get_website_crawl( user_id=account.id, datasource_parameters={"job_id": job_id}, provider_type=datasource_runtime.datasource_provider_type(), - ) + ): + website_crawl_results.append(website_message) return { - "result": [result for result in website_crawl_result.result], - "job_id": website_crawl_result.result.job_id, - "status": website_crawl_result.result.status, + "result": [result for result in website_crawl_results.result], + "status": website_crawl_results.result.status, "provider_type": datasource_node_data.get("provider_type"), } case _: @@ -544,14 +544,15 @@ class RagPipelineService: case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + website_crawl_results: list[WebsiteCrawlMessage] = [] + for website_crawl_result in datasource_runtime.get_website_crawl( user_id=account.id, datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), - ) + ): + website_crawl_results.append(website_crawl_result) return { "result": [result.model_dump() for result in website_crawl_result.result.web_info_list] if website_crawl_result.result.web_info_list else [], - "job_id": website_crawl_result.result.job_id, "status": website_crawl_result.result.status, "provider_type": datasource_node_data.get("provider_type"), }