From 76c418c0b7a02d0214acea5522fa775bd6df31c6 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 3 Jul 2025 14:03:06 +0800 Subject: [PATCH] r2 --- api/services/rag_pipeline/rag_pipeline.py | 32 +++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 0e1fad600f..f1c1ee3663 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -20,9 +20,12 @@ from core.datasource.entities.datasource_entities import ( DatasourceProviderType, GetOnlineDocumentPageContentRequest, OnlineDocumentPagesMessage, + OnlineDriveBrowseFilesRequest, + OnlineDriveBrowseFilesResponse, WebsiteCrawlMessage, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.rag.entities.event import ( BaseDatasourceEvent, @@ -525,6 +528,35 @@ class RagPipelineService: except Exception as e: logger.exception("Error during online document.") yield DatasourceErrorEvent(error=str(e)).model_dump() + case DatasourceProviderType.ONLINE_DRIVE: + datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime) + online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = ( + datasource_runtime.online_drive_browse_files( + user_id=account.id, + request=OnlineDriveBrowseFilesRequest( + bucket=user_inputs.get("bucket"), + prefix=user_inputs.get("prefix"), + max_keys=user_inputs.get("max_keys", 20), + start_after=user_inputs.get("start_after"), + ), + provider_type=datasource_runtime.datasource_provider_type(), + ) + ) + start_time = time.time() + start_event = DatasourceProcessingEvent( + total=0, + completed=0, + ) + yield start_event.model_dump() + for message in online_drive_result: + end_time = time.time() + online_drive_event = DatasourceCompletedEvent( + data=message.result, + time_consuming=round(end_time - start_time, 2), + total=None, + completed=None, + ) + yield online_drive_event.model_dump() case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (