diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py new file mode 100644 index 0000000000..4decf0f627 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -0,0 +1,416 @@ +import logging +from typing import Any, NoReturn + +from flask import Response +from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.app.error import ( + DraftWorkflowNotExist, +) +from controllers.console.datasets.wraps import get_rag_pipeline +from controllers.console.wraps import account_initialization_required, setup_required +from controllers.web.error import InvalidArgumentError, NotFoundError +from core.variables.segment_group import SegmentGroup +from core.variables.segments import ArrayFileSegment, FileSegment, Segment +from core.variables.types import SegmentType +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from factories.file_factory import build_from_mapping, build_from_mappings +from factories.variable_factory import build_segment_with_type +from libs.login import current_user, login_required +from models import db +from models.dataset import Pipeline +from models.workflow import WorkflowDraftVariable +from services.rag_pipeline.rag_pipeline import RagPipelineService +from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService + +logger = logging.getLogger(__name__) + + +def _convert_values_to_json_serializable_object(value: Segment) -> Any: + if isinstance(value, FileSegment): + return value.value.model_dump() + elif isinstance(value, ArrayFileSegment): + return [i.model_dump() for i in value.value] + elif isinstance(value, SegmentGroup): + return [_convert_values_to_json_serializable_object(i) for i in value.value] + else: + return value.value + + +def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: + value = variable.get_value() + # create a copy of the value to avoid affecting the model cache. + value = value.model_copy(deep=True) + # Refresh the url signature before returning it to client. + if isinstance(value, FileSegment): + file = value.value + file.remote_url = file.generate_url() + elif isinstance(value, ArrayFileSegment): + files = value.value + for file in files: + file.remote_url = file.generate_url() + return _convert_values_to_json_serializable_object(value) + + +def _create_pagination_parser(): + parser = reqparse.RequestParser() + parser.add_argument( + "page", + type=inputs.int_range(1, 100_000), + required=False, + default=1, + location="args", + help="the page of data requested", + ) + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + return parser + + +_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { + "id": fields.String, + "type": fields.String(attribute=lambda model: model.get_variable_type()), + "name": fields.String, + "description": fields.String, + "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), + "value_type": fields.String, + "edited": fields.Boolean(attribute=lambda model: model.edited), + "visible": fields.Boolean, +} + +_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict( + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, + value=fields.Raw(attribute=_serialize_var_value), +) + +_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { + "id": fields.String, + "type": fields.String(attribute=lambda _: "env"), + "name": fields.String, + "description": fields.String, + "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), + "value_type": fields.String, + "edited": fields.Boolean(attribute=lambda model: model.edited), + "visible": fields.Boolean, +} + +_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = { + "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)), +} + + +def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]: + return var_list.variables + + +_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = { + "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items), + "total": fields.Raw(), +} + +_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = { + "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items), +} + + +def _api_prerequisite(f): + """Common prerequisites for all draft workflow variable APIs. + + It ensures the following conditions are satisfied: + + - Dify has been property setup. + - The request user has logged in and initialized. + - The requested app is a workflow or a chat flow. + - The request user has the edit permission for the app. + """ + + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def wrapper(*args, **kwargs): + if not current_user.is_editor: + raise Forbidden() + return f(*args, **kwargs) + + return wrapper + + +class RagPipelineVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) + def get(self, pipeline: Pipeline): + """ + Get draft workflow + """ + parser = _create_pagination_parser() + args = parser.parse_args() + + # fetch draft workflow by app_model + rag_pipeline_service = RagPipelineService() + workflow_exist = rag_pipeline_service.is_workflow_exist(pipeline=pipeline) + if not workflow_exist: + raise DraftWorkflowNotExist() + + # fetch draft workflow by app_model + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService( + session=session, + ) + workflow_vars = draft_var_srv.list_variables_without_values( + app_id=pipeline.id, + page=args.page, + limit=args.limit, + ) + + return workflow_vars + + @_api_prerequisite + def delete(self, pipeline: Pipeline): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + draft_var_srv.delete_workflow_variables(pipeline.id) + db.session.commit() + return Response("", 204) + + +def validate_node_id(node_id: str) -> NoReturn | None: + if node_id in [ + CONVERSATION_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, + ]: + # NOTE(QuantumGhost): While we store the system and conversation variables as node variables + # with specific `node_id` in database, we still want to make the API separated. By disallowing + # accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`, + # we mitigate the risk that user of the API depending on the implementation detail of the API. + # + # ref: [Hyrum's Law](https://www.hyrumslaw.com/) + + raise InvalidArgumentError( + f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}", + ) + return None + + +class RagPipelineNodeVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + def get(self, pipeline: Pipeline, node_id: str): + validate_node_id(node_id) + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService( + session=session, + ) + node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id) + + return node_vars + + @_api_prerequisite + def delete(self, pipeline: Pipeline, node_id: str): + validate_node_id(node_id) + srv = WorkflowDraftVariableService(db.session()) + srv.delete_node_variables(pipeline.id, node_id) + db.session.commit() + return Response("", 204) + + +class RagPipelineVariableApi(Resource): + _PATCH_NAME_FIELD = "name" + _PATCH_VALUE_FIELD = "value" + + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + def get(self, pipeline: Pipeline, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != pipeline.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + return variable + + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + def patch(self, pipeline: Pipeline, variable_id: str): + # Request payload for file types: + # + # Local File: + # + # { + # "type": "image", + # "transfer_method": "local_file", + # "url": "", + # "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190" + # } + # + # Remote File: + # + # + # { + # "type": "image", + # "transfer_method": "remote_url", + # "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=", + # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" + # } + + parser = reqparse.RequestParser() + parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") + # Parse 'value' field as-is to maintain its original data structure + parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") + + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + args = parser.parse_args(strict=True) + + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != pipeline.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + + new_name = args.get(self._PATCH_NAME_FIELD, None) + raw_value = args.get(self._PATCH_VALUE_FIELD, None) + if new_name is None and raw_value is None: + return variable + + new_value = None + if raw_value is not None: + if variable.value_type == SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id) + elif variable.value_type == SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): + raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id) + new_value = build_segment_with_type(variable.value_type, raw_value) + draft_var_srv.update_variable(variable, name=new_name, value=new_value) + db.session.commit() + return variable + + @_api_prerequisite + def delete(self, pipeline: Pipeline, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != pipeline.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + draft_var_srv.delete_variable(variable) + db.session.commit() + return Response("", 204) + + +class RagPipelineVariableResetApi(Resource): + @_api_prerequisite + def put(self, pipeline: Pipeline, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + + rag_pipeline_service = RagPipelineService() + draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + if draft_workflow is None: + raise NotFoundError( + f"Draft workflow not found, pipeline_id={pipeline.id}", + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != pipeline.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + + resetted = draft_var_srv.reset_variable(draft_workflow, variable) + db.session.commit() + if resetted is None: + return Response("", 204) + else: + return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS) + + +def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList: + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService( + session=session, + ) + if node_id == CONVERSATION_VARIABLE_NODE_ID: + draft_vars = draft_var_srv.list_conversation_variables(pipeline.id) + elif node_id == SYSTEM_VARIABLE_NODE_ID: + draft_vars = draft_var_srv.list_system_variables(pipeline.id) + else: + draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id) + return draft_vars + + +class RagPipelineSystemVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + def get(self, pipeline: Pipeline): + return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID) + + +class RagPipelineEnvironmentVariableCollectionApi(Resource): + @_api_prerequisite + def get(self, pipeline: Pipeline): + """ + Get draft workflow + """ + # fetch draft workflow by app_model + rag_pipeline_service = RagPipelineService() + workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + if workflow is None: + raise DraftWorkflowNotExist() + + env_vars = workflow.environment_variables + env_vars_list = [] + for v in env_vars: + env_vars_list.append( + { + "id": v.id, + "type": "env", + "name": v.name, + "description": v.description, + "selector": v.selector, + "value_type": v.value_type.value, + "value": v.value, + # Do not track edited for env vars. + "edited": False, + "visible": True, + "editable": True, + } + ) + + return {"items": env_vars_list} + + +api.add_resource( + RagPipelineVariableCollectionApi, + "/rag/pipelines//workflows/draft/variables", +) +api.add_resource( + RagPipelineNodeVariableCollectionApi, + "/rag/pipelines//workflows/draft/nodes//variables", +) +api.add_resource( + RagPipelineVariableApi, "/rag/pipelines//workflows/draft/variables/" +) +api.add_resource( + RagPipelineVariableResetApi, "/rag/pipelines//workflows/draft/variables//reset" +) +api.add_resource( + RagPipelineSystemVariableCollectionApi, "/rag/pipelines//workflows/draft/system-variables" +) +api.add_resource( + RagPipelineEnvironmentVariableCollectionApi, + "/rag/pipelines//workflows/draft/environment-variables", +) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 8bae9dc466..bb7e27a4bc 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -959,6 +959,27 @@ class DatasourceListApi(Resource): return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id)) +class RagPipelineWorkflowLastRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + @marshal_with(workflow_run_node_execution_fields) + def get(self, pipeline: Pipeline, node_id: str): + rag_pipeline_service = RagPipelineService() + workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise NotFound("Workflow not found") + node_exec = rag_pipeline_service.get_node_last_run( + pipeline=pipeline, + workflow=workflow, + node_id=node_id, + ) + if node_exec is None: + raise NotFound("last run not found") + return node_exec + + api.add_resource( DraftRagPipelineApi, "/rag/pipelines//workflows/draft", @@ -1068,3 +1089,7 @@ api.add_resource( DraftRagPipelineFirstStepApi, "/rag/pipelines//workflows/draft/pre-processing/parameters", ) +api.add_resource( + RagPipelineWorkflowLastRunApi, + "/rag/pipelines//workflows/draft/nodes//last-run", +) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 8e98c67f12..2bc89ed99c 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -693,6 +693,13 @@ class PipelineGenerator(BaseAppGenerator): all_files, datasource_info, ) + else: + all_files.append( + { + "key": datasource_info.get("key", ""), + "bucket": datasource_info.get("bucket", None), + } + ) return all_files else: return datasource_info_list diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f1c1ee3663..9d273290bf 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1089,3 +1089,32 @@ class RagPipelineService: ) db.session.add(pipeline_customized_template) db.session.commit() + + def is_workflow_exist(self, pipeline: Pipeline) -> bool: + return ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == Workflow.VERSION_DRAFT, + ) + .count() + ) > 0 + + def get_node_last_run( + self, pipeline: Pipeline, workflow: Workflow, node_id: str + ) -> WorkflowNodeExecutionModel | None: + # TODO(QuantumGhost): This query is not fully covered by index. + criteria = ( + WorkflowNodeExecutionModel.tenant_id == pipeline.tenant_id, + WorkflowNodeExecutionModel.app_id == pipeline.id, + WorkflowNodeExecutionModel.workflow_id == workflow.id, + WorkflowNodeExecutionModel.node_id == node_id, + ) + node_exec = ( + db.session.query(WorkflowNodeExecutionModel) + .filter(*criteria) + .order_by(WorkflowNodeExecutionModel.created_at.desc()) + .first() + ) + return node_exec