diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 45b8c0624a..e3f50bd91e 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -498,6 +498,9 @@ class RagPipelineDraftNodeRunApi(Resource): pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user ) + if workflow_node_execution is None: + raise ValueError("Workflow node execution not found") + return workflow_node_execution diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 649e881848..986acc7896 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -116,9 +116,9 @@ workflow_run_node_execution_fields = { "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), "finished_at": TimestampField, - "inputs_truncated": fields.Boolean, - "outputs_truncated": fields.Boolean, - "process_data_truncated": fields.Boolean, + # "inputs_truncated": fields.Boolean, + # "outputs_truncated": fields.Boolean, + # "process_data_truncated": fields.Boolean, } workflow_run_node_execution_list_fields = { diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 3e361cab10..9b0273e67e 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -10,7 +10,7 @@ from uuid import uuid4 from flask_login import current_user from sqlalchemy import func, or_, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker import contexts from configs import dify_config @@ -33,6 +33,7 @@ from core.rag.entities.event import ( DatasourceErrorEvent, DatasourceProcessingEvent, ) +from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable from core.workflow.entities.variable_pool import VariablePool @@ -63,6 +64,7 @@ from models.workflow import ( WorkflowRun, WorkflowType, ) +from repositories.factory import DifyAPIRepositoryFactory from services.dataset_service import DatasetService from services.datasource_provider_service import DatasourceProviderService from services.entities.knowledge_entities.rag_pipeline_entities import ( @@ -78,6 +80,16 @@ logger = logging.getLogger(__name__) class RagPipelineService: + + + def __init__(self, session_maker: sessionmaker | None = None): + """Initialize RagPipelineService with repository dependencies.""" + if session_maker is None: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) + @classmethod def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict: if type == "built-in": @@ -390,7 +402,7 @@ class RagPipelineService: def run_draft_workflow_node( self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account - ) -> WorkflowNodeExecutionModel: + ) -> WorkflowNodeExecutionModel | None: """ Run draft workflow node """ @@ -435,7 +447,8 @@ class RagPipelineService: workflow_node_execution.workflow_id = draft_workflow.id # Create repository and save the node execution - repository = SQLAlchemyWorkflowNodeExecutionRepository( + + repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=db.engine, user=account, app_id=pipeline.id, @@ -444,16 +457,17 @@ class RagPipelineService: repository.save(workflow_node_execution) # Convert node_execution to WorkflowNodeExecution after save - workflow_node_execution_db_model = repository.to_db_model(workflow_node_execution) + workflow_node_execution_db_model = self._node_execution_service_repo.get_execution_by_id(workflow_node_execution.id) with Session(bind=db.engine) as session, session.begin(): draft_var_saver = DraftVariableSaver( session=session, app_id=pipeline.id, - node_id=workflow_node_execution_db_model.node_id, - node_type=NodeType(workflow_node_execution_db_model.node_type), + node_id=workflow_node_execution.node_id, + node_type=NodeType(workflow_node_execution.node_type), enclosing_node_id=enclosing_node_id, node_execution_id=workflow_node_execution.id, + user=account, ) draft_var_saver.save( process_data=workflow_node_execution.process_data,