diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 42592c6c9a..35d912bfcc 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1051,11 +1051,12 @@ class DocumentPipelineExecutionLogApi(DocumentResource): .first() ) if not log: - return {"datasource_info": None, - "datasource_type": None, - "input_data": None, - "datasource_node_id": None, - }, 200 + return { + "datasource_info": None, + "datasource_type": None, + "input_data": None, + "datasource_node_id": None, + }, 200 return { "datasource_info": json.loads(log.datasource_info), "datasource_type": log.datasource_type, @@ -1086,5 +1087,6 @@ api.add_resource(DocumentRetryApi, "/datasets//retry") api.add_resource(DocumentRenameApi, "/datasets//documents//rename") api.add_resource(WebsiteDocumentSyncApi, "/datasets//documents//website-sync") -api.add_resource(DocumentPipelineExecutionLogApi, - "/datasets//documents//pipeline-execution-log") +api.add_resource( + DocumentPipelineExecutionLogApi, "/datasets//documents//pipeline-execution-log" +) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 7f7b6a7867..124d45f513 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -96,7 +96,7 @@ class DatasourceAuth(Resource): parser = reqparse.RequestParser() parser.add_argument("provider", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test") parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index 32b5f68364..bb02c659b8 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -48,7 +48,8 @@ class DataSourceContentPreviewApi(Resource): ) return preview_content, 200 + api.add_resource( DataSourceContentPreviewApi, - "/rag/pipelines//workflows/published/datasource/nodes//preview" + "/rag/pipelines//workflows/published/datasource/nodes//preview", ) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index fe7c75ce96..cbb382beb3 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,4 +1,3 @@ -from ast import Str from collections.abc import Sequence from enum import Enum, StrEnum from typing import Any, Literal, Optional @@ -128,14 +127,17 @@ class VariableEntity(BaseModel): def convert_none_options(cls, v: Any) -> Sequence[str]: return v or [] + class RagPipelineVariableEntity(VariableEntity): """ Rag Pipeline Variable Entity. """ + tooltips: Optional[str] = None placeholder: Optional[str] = None belong_to_node_id: str + class ExternalDataVariableEntity(BaseModel): """ External Data Variable Entity. diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index b2530ec422..1c63874ee3 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,5 +1,3 @@ -from typing import Any - from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/pipeline/pipeline_config_manager.py b/api/core/app/apps/pipeline/pipeline_config_manager.py index f410457bc6..b83fc1800f 100644 --- a/api/core/app/apps/pipeline/pipeline_config_manager.py +++ b/api/core/app/apps/pipeline/pipeline_config_manager.py @@ -13,6 +13,7 @@ class PipelineConfig(WorkflowUIBasedAppConfig): """ Pipeline Config Entity. """ + rag_pipeline_variables: list[RagPipelineVariableEntity] = [] pass diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 402fd92358..52afb78ee5 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -47,6 +47,7 @@ class PipelineRunner(WorkflowBasedAppRunner): def _get_app_id(self) -> str: return self.application_generate_entity.app_config.app_id + def run(self) -> None: """ Run application @@ -114,9 +115,9 @@ class PipelineRunner(WorkflowBasedAppRunner): for v in workflow.rag_pipeline_variables: rag_pipeline_variable = RAGPipelineVariable(**v) if ( - (rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id or rag_pipeline_variable.belong_to_node_id == "shared") - and rag_pipeline_variable.variable in inputs - ): + rag_pipeline_variable.belong_to_node_id + in (self.application_generate_entity.start_node_id, "shared") + ) and rag_pipeline_variable.variable in inputs: rag_pipeline_variables.append( RAGPipelineVariableInput( variable=rag_pipeline_variable, diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 37f194e0af..3a68f45f61 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -10,8 +10,12 @@ from core.variables import Segment, SegmentGroup, Variable from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.segments import FileSegment, NoneSegment from core.variables.variables import RAGPipelineVariableInput -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, \ - SYSTEM_VARIABLE_NODE_ID, RAG_PIPELINE_VARIABLE_NODE_ID +from core.workflow.constants import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) from core.workflow.enums import SystemVariableKey from factories import variable_factory diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 77eba1a7ce..01f6f51648 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -462,6 +462,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): inputs=parameters_for_log, ) ) + @classmethod def version(cls) -> str: return "1" diff --git a/api/models/workflow.py b/api/models/workflow.py index 3c87903bb3..638885be8d 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -323,13 +323,11 @@ class Workflow(Base): return variables def rag_pipeline_user_input_form(self) -> list: - # get user_input_form from start node variables: list[Any] = self.rag_pipeline_variables return variables - @property def unique_hash(self) -> str: """ diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4a7620bd15..f7941fa49f 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -344,10 +344,10 @@ class DatasetService: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise ValueError("Dataset not found") - # check if dataset name is exists + # check if dataset name is exists if ( db.session.query(Dataset) - .filter( + .filter( Dataset.id != dataset_id, Dataset.name == data.get("name", dataset.name), Dataset.tenant_id == dataset.tenant_id, @@ -470,7 +470,7 @@ class DatasetService: filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) # update Retrieval model filtered_data["retrieval_model"] = data["retrieval_model"] - # update icon info + # update icon info if data.get("icon_info"): filtered_data["icon_info"] = data.get("icon_info") diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index bca0081417..228c18b7c2 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -32,14 +32,10 @@ class DatasourceProviderService: :param credentials: """ # check name is exist - datasource_provider = ( - db.session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, name=name) - .first() - ) + datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=name).first() if datasource_provider: raise ValueError("Authorization name is already exists") - + credential_valid = self.provider_manager.validate_provider_credentials( tenant_id=tenant_id, user_id=current_user.id, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 6427c526d6..0370826c12 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -20,9 +20,12 @@ from core.datasource.entities.datasource_entities import ( DatasourceProviderType, GetOnlineDocumentPageContentRequest, OnlineDocumentPagesMessage, + OnlineDriveBrowseFilesRequest, + OnlineDriveBrowseFilesResponse, WebsiteCrawlMessage, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.rag.entities.event import ( BaseDatasourceEvent, @@ -31,8 +34,9 @@ from core.rag.entities.event import ( DatasourceProcessingEvent, ) from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.variables.variables import Variable +from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput, Variable from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, @@ -381,6 +385,17 @@ class RagPipelineService: # run draft workflow node start_at = time.perf_counter() + rag_pipeline_variables = [] + if draft_workflow.rag_pipeline_variables: + for v in draft_workflow.rag_pipeline_variables: + rag_pipeline_variable = RAGPipelineVariable(**v) + if rag_pipeline_variable.variable in user_inputs: + rag_pipeline_variables.append( + RAGPipelineVariableInput( + variable=rag_pipeline_variable, + value=user_inputs[rag_pipeline_variable.variable], + ) + ) workflow_node_execution = self._handle_node_run_result( getter=lambda: WorkflowEntry.single_step_run( @@ -388,6 +403,12 @@ class RagPipelineService: node_id=node_id, user_inputs=user_inputs, user_id=account.id, + variable_pool=VariablePool( + user_inputs=user_inputs, + environment_variables=draft_workflow.environment_variables, + conversation_variables=draft_workflow.conversation_variables, + rag_pipeline_variables=rag_pipeline_variables, + ), ), start_at=start_at, tenant_id=pipeline.tenant_id, @@ -413,6 +434,17 @@ class RagPipelineService: # run draft workflow node start_at = time.perf_counter() + rag_pipeline_variables = [] + if published_workflow.rag_pipeline_variables: + for v in published_workflow.rag_pipeline_variables: + rag_pipeline_variable = RAGPipelineVariable(**v) + if rag_pipeline_variable.variable in user_inputs: + rag_pipeline_variables.append( + RAGPipelineVariableInput( + variable=rag_pipeline_variable, + value=user_inputs[rag_pipeline_variable.variable], + ) + ) workflow_node_execution = self._handle_node_run_result( getter=lambda: WorkflowEntry.single_step_run( @@ -420,6 +452,12 @@ class RagPipelineService: node_id=node_id, user_inputs=user_inputs, user_id=account.id, + variable_pool=VariablePool( + user_inputs=user_inputs, + environment_variables=published_workflow.environment_variables, + conversation_variables=published_workflow.conversation_variables, + rag_pipeline_variables=rag_pipeline_variables, + ), ), start_at=start_at, tenant_id=pipeline.tenant_id, @@ -511,6 +549,33 @@ class RagPipelineService: except Exception as e: logger.exception("Error during online document.") yield DatasourceErrorEvent(error=str(e)).model_dump() + case DatasourceProviderType.ONLINE_DRIVE: + datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime) + online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = datasource_runtime.online_drive_browse_files( + user_id=account.id, + request=OnlineDriveBrowseFilesRequest( + bucket=user_inputs.get("bucket"), + prefix=user_inputs.get("prefix"), + max_keys=user_inputs.get("max_keys", 20), + start_after=user_inputs.get("start_after"), + ), + provider_type=datasource_runtime.datasource_provider_type(), + ) + start_time = time.time() + start_event = DatasourceProcessingEvent( + total=0, + completed=0, + ) + yield start_event.model_dump() + for message in online_drive_result: + end_time = time.time() + online_drive_event = DatasourceCompletedEvent( + data=message.result, + time_consuming=round(end_time - start_time, 2), + total=None, + completed=None, + ) + yield online_drive_event.model_dump() case DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = ( @@ -631,7 +696,7 @@ class RagPipelineService: except Exception as e: logger.exception("Error during get online document content.") raise RuntimeError(str(e)) - #TODO Online Drive + # TODO Online Drive case _: raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") except Exception as e: diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 8a73c73a1b..282728153a 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -86,8 +86,9 @@ class ToolTransformService: ) else: provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider.type.value, provider_name=provider.name, - icon=provider.declaration.identity.icon + provider_type=provider.type.value, + provider_name=provider.name, + icon=provider.declaration.identity.icon, ) @classmethod