diff --git a/api/.importlinter b/api/.importlinter index 14c2b30101..0d9af6e065 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -58,8 +58,6 @@ ignore_imports = dify_graph.nodes.tool.tool_node -> extensions.ext_database dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - # TODO(QuantumGhost): use DI to avoid depending on global DB. - dify_graph.nodes.human_input.human_input_node -> extensions.ext_database [importlinter:contract:workflow-external-imports] name = Workflow External Imports @@ -153,8 +151,6 @@ ignore_imports = dify_graph.nodes.llm.file_saver -> extensions.ext_database dify_graph.nodes.llm.node -> extensions.ext_database dify_graph.nodes.tool.tool_node -> extensions.ext_database - dify_graph.nodes.human_input.human_input_node -> extensions.ext_database - dify_graph.nodes.human_input.human_input_node -> core.repositories.human_input_repository dify_graph.nodes.agent.agent_node -> models dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota dify_graph.nodes.llm.node -> models.model diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f57a0d9b3b..fbd5060b8c 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -735,7 +735,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _load_human_input_form_id(self, *, node_id: str) -> str | None: form_repository = HumanInputFormRepositoryImpl( - session_factory=db.engine, tenant_id=self._workflow_tenant_id, ) form = form_repository.get_form(self._workflow_run_id, node_id) diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index bd9afe36f0..6607a87032 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -4,9 +4,10 @@ from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any -from sqlalchemy import Engine, select -from sqlalchemy.orm import Session, selectinload, sessionmaker +from sqlalchemy import select +from sqlalchemy.orm import Session, selectinload +from core.db.session_factory import session_factory from dify_graph.nodes.human_input.entities import ( DeliveryChannelConfig, EmailDeliveryMethod, @@ -198,12 +199,9 @@ class _InvalidTimeoutStatusError(ValueError): class HumanInputFormRepositoryImpl: def __init__( self, - session_factory: sessionmaker | Engine, + *, tenant_id: str, ): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory self._tenant_id = tenant_id def _delivery_method_to_model( @@ -217,7 +215,7 @@ class HumanInputFormRepositoryImpl: id=delivery_id, form_id=form_id, delivery_method_type=delivery_method.type, - delivery_config_id=delivery_method.id, + delivery_config_id=str(delivery_method.id), channel_payload=delivery_method.model_dump_json(), ) recipients: list[HumanInputFormRecipient] = [] @@ -343,7 +341,7 @@ class HumanInputFormRepositoryImpl: def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: form_config: HumanInputNodeData = params.form_config - with self._session_factory(expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): # Generate unique form ID form_id = str(uuidv7()) start_time = naive_utc_now() @@ -435,7 +433,7 @@ class HumanInputFormRepositoryImpl: HumanInputForm.node_id == node_id, HumanInputForm.tenant_id == self._tenant_id, ) - with self._session_factory(expire_on_commit=False) as session: + with session_factory.create_session() as session: form_model: HumanInputForm | None = session.scalars(form_query).first() if form_model is None: return None @@ -448,18 +446,13 @@ class HumanInputFormRepositoryImpl: class HumanInputFormSubmissionRepository: """Repository for fetching and submitting human input forms.""" - def __init__(self, session_factory: sessionmaker | Engine): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - def get_by_token(self, form_token: str) -> HumanInputFormRecord | None: query = ( select(HumanInputFormRecipient) .options(selectinload(HumanInputFormRecipient.form)) .where(HumanInputFormRecipient.access_token == form_token) ) - with self._session_factory(expire_on_commit=False) as session: + with session_factory.create_session() as session: recipient_model = session.scalars(query).first() if recipient_model is None or recipient_model.form is None: return None @@ -478,7 +471,7 @@ class HumanInputFormSubmissionRepository: HumanInputFormRecipient.recipient_type == recipient_type, ) ) - with self._session_factory(expire_on_commit=False) as session: + with session_factory.create_session() as session: recipient_model = session.scalars(query).first() if recipient_model is None or recipient_model.form is None: return None @@ -494,7 +487,7 @@ class HumanInputFormSubmissionRepository: submission_user_id: str | None, submission_end_user_id: str | None, ) -> HumanInputFormRecord: - with self._session_factory(expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): form_model = session.get(HumanInputForm, form_id) if form_model is None: raise FormNotFoundError(f"form not found, id={form_id}") @@ -524,7 +517,7 @@ class HumanInputFormSubmissionRepository: timeout_status: HumanInputFormStatus, reason: str | None = None, ) -> HumanInputFormRecord: - with self._session_factory(expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): form_model = session.get(HumanInputForm, form_id) if form_model is None: raise FormNotFoundError(f"form not found, id={form_id}") diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 22d86748f1..1b4937769e 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -19,6 +19,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.rag.index_processor.index_processor import IndexProcessor from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.summary_index.summary_index import SummaryIndex +from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.tools.tool_file_manager import ToolFileManager from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, SystemVariableKey @@ -34,6 +35,7 @@ from dify_graph.nodes.code.limits import CodeNodeLimits from dify_graph.nodes.datasource import DatasourceNode from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config +from dify_graph.nodes.human_input.human_input_node import HumanInputNode from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from dify_graph.nodes.llm.entities import ModelConfig @@ -205,6 +207,15 @@ class DifyNodeFactory(NodeFactory): file_manager=self._http_request_file_manager, ) + if node_type == NodeType.HUMAN_INPUT: + return HumanInputNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + form_repository=HumanInputFormRepositoryImpl(tenant_id=self.graph_init_params.tenant_id), + ) + if node_type == NodeType.KNOWLEDGE_INDEX: return KnowledgeIndexNode( id=node_id, diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py index f41423f550..e54650898d 100644 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ b/api/dify_graph/nodes/human_input/human_input_node.py @@ -3,7 +3,6 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from dify_graph.node_events import ( @@ -21,7 +20,6 @@ from dify_graph.repositories.human_input_form_repository import ( HumanInputFormRepository, ) from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient @@ -66,7 +64,7 @@ class HumanInputNode(Node[HumanInputNodeData]): config: Mapping[str, Any], graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - form_repository: HumanInputFormRepository | None = None, + form_repository: HumanInputFormRepository, ) -> None: super().__init__( id=id, @@ -74,11 +72,6 @@ class HumanInputNode(Node[HumanInputNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - if form_repository is None: - form_repository = HumanInputFormRepositoryImpl( - session_factory=db.engine, - tenant_id=self.tenant_id, - ) self._form_repository = form_repository @classmethod diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index cfab723fef..2e74c50963 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -130,7 +130,7 @@ class HumanInputService: if isinstance(session_factory, Engine): session_factory = sessionmaker(bind=session_factory) self._session_factory = session_factory - self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory) + self._form_repository = form_repository or HumanInputFormSubmissionRepository() def get_form_by_token(self, form_token: str) -> Form | None: record = self._form_repository.get_by_token(form_token) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 3ea38c3535..21bc95136e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1015,7 +1015,7 @@ class WorkflowService: rendered_content: str, resolved_default_values: Mapping[str, Any], ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(session_factory=db.engine, tenant_id=app_model.tenant_id) + repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id) params = FormCreateParams( app_id=app_model.id, workflow_execution_id=None, @@ -1081,6 +1081,7 @@ class WorkflowService: config=node_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + form_repository=HumanInputFormRepositoryImpl(tenant_id=workflow.tenant_id), ) return node diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index 03441683b0..dd3b6a4530 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -58,7 +58,7 @@ def check_and_handle_human_input_timeouts(limit: int = 100) -> None: """Scan for expired human input forms and resume or end workflows.""" session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - form_repo = HumanInputFormSubmissionRepository(session_factory) + form_repo = HumanInputFormSubmissionRepository() service = HumanInputService(session_factory, form_repository=form_repo) now = naive_utc_now() global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 4b362d1abe..9d0fad4b12 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -100,7 +100,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["member1@example.com", "member2@example.com"], ) - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) params = _build_form_params( delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], ) @@ -129,7 +129,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["primary@example.com", "secondary@example.com"], ) - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) params = _build_form_params( delivery_methods=[ _build_email_delivery( @@ -173,7 +173,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["prefill@example.com"], ) - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) resolved_values = {"greeting": "Hello!"} params = FormCreateParams( app_id=str(uuid4()), @@ -210,7 +210,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["ui@example.com"], ) - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) params = FormCreateParams( app_id=str(uuid4()), workflow_execution_id=str(uuid4()), diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index c9bcba6639..0876a39f82 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -96,8 +96,7 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_methods=[delivery_method], ) - engine = db_session_with_containers.get_bind() - repo = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id) params = FormCreateParams( app_id=app_id, workflow_execution_id=workflow_execution_id, diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 7bf7c2e5f6..9af4d12664 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -5,7 +5,6 @@ from __future__ import annotations import dataclasses from datetime import datetime from types import SimpleNamespace -from unittest.mock import MagicMock import pytest @@ -35,7 +34,7 @@ from models.human_input import ( def _build_repository() -> HumanInputFormRepositoryImpl: - return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id") + return HumanInputFormRepositoryImpl(tenant_id="tenant-id") def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]: @@ -389,8 +388,21 @@ def _session_factory(session: _FakeSession): return _factory +def _patch_repo_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None: + """Patch repository's global session factory to return our fake session. + + The repositories under test now use a global session factory; patch its + create_session method so unit tests don't hit a real database. + """ + monkeypatch.setattr( + "core.repositories.human_input_repository.session_factory.create_session", + _session_factory(session), + raising=True, + ) + + class TestHumanInputFormRepositoryImplPublicMethods: - def test_get_form_returns_entity_and_recipients(self): + def test_get_form_returns_entity_and_recipients(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -408,7 +420,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: access_token="token-123", ) session = _FakeSession(scalars_results=[form, [recipient]]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -418,13 +431,14 @@ class TestHumanInputFormRepositoryImplPublicMethods: assert len(entity.recipients) == 1 assert entity.recipients[0].token == "token-123" - def test_get_form_returns_none_when_missing(self): + def test_get_form_returns_none_when_missing(self, monkeypatch: pytest.MonkeyPatch): session = _FakeSession(scalars_results=[None]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") assert repo.get_form("run-1", "node-1") is None - def test_get_form_returns_unsubmitted_state(self): + def test_get_form_returns_unsubmitted_state(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -436,7 +450,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: expiration_time=naive_utc_now(), ) session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -445,7 +460,7 @@ class TestHumanInputFormRepositoryImplPublicMethods: assert entity.selected_action_id is None assert entity.submitted_data is None - def test_get_form_returns_submission_when_completed(self): + def test_get_form_returns_submission_when_completed(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -460,7 +475,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: submitted_at=naive_utc_now(), ) session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -471,7 +487,7 @@ class TestHumanInputFormRepositoryImplPublicMethods: class TestHumanInputFormSubmissionRepository: - def test_get_by_token_returns_record(self): + def test_get_by_token_returns_record(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -490,7 +506,8 @@ class TestHumanInputFormSubmissionRepository: form=form, ) session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record = repo.get_by_token("token-123") @@ -499,7 +516,7 @@ class TestHumanInputFormSubmissionRepository: assert record.recipient_type == RecipientType.STANDALONE_WEB_APP assert record.submitted is False - def test_get_by_form_id_and_recipient_type_uses_recipient(self): + def test_get_by_form_id_and_recipient_type_uses_recipient(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -518,7 +535,8 @@ class TestHumanInputFormSubmissionRepository: form=form, ) session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record = repo.get_by_form_id_and_recipient_type( form_id=form.id, @@ -553,7 +571,8 @@ class TestHumanInputFormSubmissionRepository: forms={form.id: form}, recipients={recipient.id: recipient}, ) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record: HumanInputFormRecord = repo.mark_submitted( form_id=form.id, diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py index 6b07f88c41..bd0182a402 100644 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -47,7 +47,7 @@ class _FakeSessionFactory: class _FakeFormRepo: - def __init__(self, _session_factory, form_map: dict[str, Any] | None = None): + def __init__(self, form_map: dict[str, Any] | None = None): self.calls: list[dict[str, Any]] = [] self._form_map = form_map or {} @@ -149,9 +149,9 @@ def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pyt monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture)) form_map = {form.id: form for form in forms} - repo = _FakeFormRepo(None, form_map=form_map) + repo = _FakeFormRepo(form_map=form_map) - def _repo_factory(_session_factory): + def _repo_factory(): return repo service = _FakeService(None)