This commit is contained in:
jyong 2025-07-03 18:56:42 +08:00
parent f2960989c1
commit 44c2efcfe4
3 changed files with 26 additions and 13 deletions

View File

@ -22,7 +22,7 @@ class WorkflowVariablesConfigManager:
return variables
@classmethod
def convert_rag_pipeline_variable(cls, workflow: Workflow) -> list[RagPipelineVariableEntity]:
def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
"""
Convert workflow start variables to variables
@ -31,8 +31,9 @@ class WorkflowVariablesConfigManager:
variables = []
user_input_form = workflow.rag_pipeline_user_input_form()
# variables
# filter variables by start_node_id
for variable in user_input_form:
variables.append(RagPipelineVariableEntity.model_validate(variable))
if variable.get("belong_to_node_id") == start_node_id or variable.get("belong_to_node_id") == "shared":
variables.append(RagPipelineVariableEntity.model_validate(variable))
return variables

View File

@ -20,13 +20,13 @@ class PipelineConfig(WorkflowUIBasedAppConfig):
class PipelineConfigManager(BaseAppConfigManager):
@classmethod
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig:
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow, start_node_id: str) -> PipelineConfig:
pipeline_config = PipelineConfig(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id,
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow),
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow, start_node_id=start_node_id),
)
return pipeline_config

View File

@ -29,6 +29,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
@ -97,11 +98,6 @@ class PipelineGenerator(BaseAppGenerator):
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline,
workflow=workflow,
)
# Add null check for dataset
dataset = pipeline.dataset
if not dataset:
@ -111,6 +107,12 @@ class PipelineGenerator(BaseAppGenerator):
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline,
workflow=workflow,
start_node_id=start_node_id
)
documents = []
if invoke_from == InvokeFrom.PUBLISHED:
for datasource_info in datasource_info_list:
@ -308,6 +310,9 @@ class PipelineGenerator(BaseAppGenerator):
worker_thread.start()
draft_var_saver_factory = self._get_draft_var_saver_factory(
invoke_from,
)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
@ -317,6 +322,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
draft_var_saver_factory=draft_var_saver_factory,
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
@ -347,7 +353,9 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
workflow=workflow,
start_node_id=args.get("start_node_id","shared"))
dataset = pipeline.dataset
if not dataset:
@ -432,7 +440,9 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("Pipeline dataset is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
workflow=workflow,
start_node_id=args.get("start_node_id","shared"))
# init application generate entity
application_generate_entity = RagPipelineGenerateEntity(
@ -476,7 +486,7 @@ class PipelineGenerator(BaseAppGenerator):
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow=workflow,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
@ -539,6 +549,7 @@ class PipelineGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -560,6 +571,7 @@ class PipelineGenerator(BaseAppGenerator):
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
)
try: