From 44c2efcfe4e0e25ae6183c06fe45a73170210343 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 3 Jul 2025 18:56:42 +0800 Subject: [PATCH] r2 --- .../variables/manager.py | 7 +++-- .../apps/pipeline/pipeline_config_manager.py | 4 +-- .../app/apps/pipeline/pipeline_generator.py | 28 +++++++++++++------ 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index 1c63874ee3..5be87a3bb6 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -22,7 +22,7 @@ class WorkflowVariablesConfigManager: return variables @classmethod - def convert_rag_pipeline_variable(cls, workflow: Workflow) -> list[RagPipelineVariableEntity]: + def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]: """ Convert workflow start variables to variables @@ -31,8 +31,9 @@ class WorkflowVariablesConfigManager: variables = [] user_input_form = workflow.rag_pipeline_user_input_form() - # variables + # filter variables by start_node_id for variable in user_input_form: - variables.append(RagPipelineVariableEntity.model_validate(variable)) + if variable.get("belong_to_node_id") == start_node_id or variable.get("belong_to_node_id") == "shared": + variables.append(RagPipelineVariableEntity.model_validate(variable)) return variables diff --git a/api/core/app/apps/pipeline/pipeline_config_manager.py b/api/core/app/apps/pipeline/pipeline_config_manager.py index b83fc1800f..a86cad78dc 100644 --- a/api/core/app/apps/pipeline/pipeline_config_manager.py +++ b/api/core/app/apps/pipeline/pipeline_config_manager.py @@ -20,13 +20,13 @@ class PipelineConfig(WorkflowUIBasedAppConfig): class PipelineConfigManager(BaseAppConfigManager): @classmethod - def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig: + def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow, start_node_id: str) -> PipelineConfig: pipeline_config = PipelineConfig( tenant_id=pipeline.tenant_id, app_id=pipeline.id, app_mode=AppMode.RAG_PIPELINE, workflow_id=workflow.id, - rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow), + rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow, start_node_id=start_node_id), ) return pipeline_config diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 13acc4ef38..6f0d670100 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -29,6 +29,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.index_processor.constant.built_in_field import BuiltInField from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db @@ -97,11 +98,6 @@ class PipelineGenerator(BaseAppGenerator): call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: - # convert to app config - pipeline_config = PipelineConfigManager.get_pipeline_config( - pipeline=pipeline, - workflow=workflow, - ) # Add null check for dataset dataset = pipeline.dataset if not dataset: @@ -111,6 +107,12 @@ class PipelineGenerator(BaseAppGenerator): datasource_type: str = args["datasource_type"] datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000) + # convert to app config + pipeline_config = PipelineConfigManager.get_pipeline_config( + pipeline=pipeline, + workflow=workflow, + start_node_id=start_node_id + ) documents = [] if invoke_from == InvokeFrom.PUBLISHED: for datasource_info in datasource_info_list: @@ -308,6 +310,9 @@ class PipelineGenerator(BaseAppGenerator): worker_thread.start() + draft_var_saver_factory = self._get_draft_var_saver_factory( + invoke_from, + ) # return response or stream generator response = self._handle_response( application_generate_entity=application_generate_entity, @@ -317,6 +322,7 @@ class PipelineGenerator(BaseAppGenerator): workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, stream=streaming, + draft_var_saver_factory=draft_var_saver_factory, ) return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) @@ -347,7 +353,9 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("inputs is required") # convert to app config - pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) + pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, + workflow=workflow, + start_node_id=args.get("start_node_id","shared")) dataset = pipeline.dataset if not dataset: @@ -432,7 +440,9 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("Pipeline dataset is required") # convert to app config - pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) + pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, + workflow=workflow, + start_node_id=args.get("start_node_id","shared")) # init application generate entity application_generate_entity = RagPipelineGenerateEntity( @@ -476,7 +486,7 @@ class PipelineGenerator(BaseAppGenerator): return self._generate( flask_app=current_app._get_current_object(), # type: ignore pipeline=pipeline, - workflow=workflow, + workflow_id=workflow.id, user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, @@ -539,6 +549,7 @@ class PipelineGenerator(BaseAppGenerator): user: Union[Account, EndUser], workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, + draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -560,6 +571,7 @@ class PipelineGenerator(BaseAppGenerator): stream=stream, workflow_node_execution_repository=workflow_node_execution_repository, workflow_execution_repository=workflow_execution_repository, + draft_var_saver_factory=draft_var_saver_factory, ) try: