refactor: human input node decouple db (#32900)

This commit is contained in:
wangxiaolei 2026-03-04 13:18:32 +08:00 committed by GitHub
parent b584434e28
commit e14b09d4db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 69 additions and 58 deletions

View File

@ -58,8 +58,6 @@ ignore_imports =
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
# TODO(QuantumGhost): use DI to avoid depending on global DB.
dify_graph.nodes.human_input.human_input_node -> extensions.ext_database
[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
@ -153,8 +151,6 @@ ignore_imports =
dify_graph.nodes.llm.file_saver -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.nodes.human_input.human_input_node -> extensions.ext_database
dify_graph.nodes.human_input.human_input_node -> core.repositories.human_input_repository
dify_graph.nodes.agent.agent_node -> models
dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
dify_graph.nodes.llm.node -> models.model

View File

@ -735,7 +735,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
form_repository = HumanInputFormRepositoryImpl(
session_factory=db.engine,
tenant_id=self._workflow_tenant_id,
)
form = form_repository.get_form(self._workflow_run_id, node_id)

View File

@ -4,9 +4,10 @@ from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from core.db.session_factory import session_factory
from dify_graph.nodes.human_input.entities import (
DeliveryChannelConfig,
EmailDeliveryMethod,
@ -198,12 +199,9 @@ class _InvalidTimeoutStatusError(ValueError):
class HumanInputFormRepositoryImpl:
def __init__(
self,
session_factory: sessionmaker | Engine,
*,
tenant_id: str,
):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
self._tenant_id = tenant_id
def _delivery_method_to_model(
@ -217,7 +215,7 @@ class HumanInputFormRepositoryImpl:
id=delivery_id,
form_id=form_id,
delivery_method_type=delivery_method.type,
delivery_config_id=delivery_method.id,
delivery_config_id=str(delivery_method.id),
channel_payload=delivery_method.model_dump_json(),
)
recipients: list[HumanInputFormRecipient] = []
@ -343,7 +341,7 @@ class HumanInputFormRepositoryImpl:
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
form_config: HumanInputNodeData = params.form_config
with self._session_factory(expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
# Generate unique form ID
form_id = str(uuidv7())
start_time = naive_utc_now()
@ -435,7 +433,7 @@ class HumanInputFormRepositoryImpl:
HumanInputForm.node_id == node_id,
HumanInputForm.tenant_id == self._tenant_id,
)
with self._session_factory(expire_on_commit=False) as session:
with session_factory.create_session() as session:
form_model: HumanInputForm | None = session.scalars(form_query).first()
if form_model is None:
return None
@ -448,18 +446,13 @@ class HumanInputFormRepositoryImpl:
class HumanInputFormSubmissionRepository:
"""Repository for fetching and submitting human input forms."""
def __init__(self, session_factory: sessionmaker | Engine):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(HumanInputFormRecipient.access_token == form_token)
)
with self._session_factory(expire_on_commit=False) as session:
with session_factory.create_session() as session:
recipient_model = session.scalars(query).first()
if recipient_model is None or recipient_model.form is None:
return None
@ -478,7 +471,7 @@ class HumanInputFormSubmissionRepository:
HumanInputFormRecipient.recipient_type == recipient_type,
)
)
with self._session_factory(expire_on_commit=False) as session:
with session_factory.create_session() as session:
recipient_model = session.scalars(query).first()
if recipient_model is None or recipient_model.form is None:
return None
@ -494,7 +487,7 @@ class HumanInputFormSubmissionRepository:
submission_user_id: str | None,
submission_end_user_id: str | None,
) -> HumanInputFormRecord:
with self._session_factory(expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
form_model = session.get(HumanInputForm, form_id)
if form_model is None:
raise FormNotFoundError(f"form not found, id={form_id}")
@ -524,7 +517,7 @@ class HumanInputFormSubmissionRepository:
timeout_status: HumanInputFormStatus,
reason: str | None = None,
) -> HumanInputFormRecord:
with self._session_factory(expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
form_model = session.get(HumanInputForm, form_id)
if form_model is None:
raise FormNotFoundError(f"form not found, id={form_id}")

View File

@ -19,6 +19,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.summary_index.summary_index import SummaryIndex
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, SystemVariableKey
@ -34,6 +35,7 @@ from dify_graph.nodes.code.limits import CodeNodeLimits
from dify_graph.nodes.datasource import DatasourceNode
from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config
from dify_graph.nodes.human_input.human_input_node import HumanInputNode
from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from dify_graph.nodes.llm.entities import ModelConfig
@ -205,6 +207,15 @@ class DifyNodeFactory(NodeFactory):
file_manager=self._http_request_file_manager,
)
if node_type == NodeType.HUMAN_INPUT:
return HumanInputNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
form_repository=HumanInputFormRepositoryImpl(tenant_id=self.graph_init_params.tenant_id),
)
if node_type == NodeType.KNOWLEDGE_INDEX:
return KnowledgeIndexNode(
id=node_id,

View File

@ -3,7 +3,6 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import (
@ -21,7 +20,6 @@ from dify_graph.repositories.human_input_form_repository import (
HumanInputFormRepository,
)
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
@ -66,7 +64,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
form_repository: HumanInputFormRepository | None = None,
form_repository: HumanInputFormRepository,
) -> None:
super().__init__(
id=id,
@ -74,11 +72,6 @@ class HumanInputNode(Node[HumanInputNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
if form_repository is None:
form_repository = HumanInputFormRepositoryImpl(
session_factory=db.engine,
tenant_id=self.tenant_id,
)
self._form_repository = form_repository
@classmethod

View File

@ -130,7 +130,7 @@ class HumanInputService:
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory)
self._form_repository = form_repository or HumanInputFormSubmissionRepository()
def get_form_by_token(self, form_token: str) -> Form | None:
record = self._form_repository.get_by_token(form_token)

View File

@ -1015,7 +1015,7 @@ class WorkflowService:
rendered_content: str,
resolved_default_values: Mapping[str, Any],
) -> tuple[str, list[DeliveryTestEmailRecipient]]:
repo = HumanInputFormRepositoryImpl(session_factory=db.engine, tenant_id=app_model.tenant_id)
repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id)
params = FormCreateParams(
app_id=app_model.id,
workflow_execution_id=None,
@ -1081,6 +1081,7 @@ class WorkflowService:
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
form_repository=HumanInputFormRepositoryImpl(tenant_id=workflow.tenant_id),
)
return node

View File

@ -58,7 +58,7 @@ def check_and_handle_human_input_timeouts(limit: int = 100) -> None:
"""Scan for expired human input forms and resume or end workflows."""
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
form_repo = HumanInputFormSubmissionRepository(session_factory)
form_repo = HumanInputFormSubmissionRepository()
service = HumanInputService(session_factory, form_repository=form_repo)
now = naive_utc_now()
global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS

View File

@ -100,7 +100,7 @@ class TestHumanInputFormRepositoryImplWithContainers:
member_emails=["member1@example.com", "member2@example.com"],
)
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id)
params = _build_form_params(
delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])],
)
@ -129,7 +129,7 @@ class TestHumanInputFormRepositoryImplWithContainers:
member_emails=["primary@example.com", "secondary@example.com"],
)
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id)
params = _build_form_params(
delivery_methods=[
_build_email_delivery(
@ -173,7 +173,7 @@ class TestHumanInputFormRepositoryImplWithContainers:
member_emails=["prefill@example.com"],
)
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id)
resolved_values = {"greeting": "Hello!"}
params = FormCreateParams(
app_id=str(uuid4()),
@ -210,7 +210,7 @@ class TestHumanInputFormRepositoryImplWithContainers:
member_emails=["ui@example.com"],
)
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id)
params = FormCreateParams(
app_id=str(uuid4()),
workflow_execution_id=str(uuid4()),

View File

@ -96,8 +96,7 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor
delivery_methods=[delivery_method],
)
engine = db_session_with_containers.get_bind()
repo = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id)
params = FormCreateParams(
app_id=app_id,
workflow_execution_id=workflow_execution_id,

View File

@ -5,7 +5,6 @@ from __future__ import annotations
import dataclasses
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
@ -35,7 +34,7 @@ from models.human_input import (
def _build_repository() -> HumanInputFormRepositoryImpl:
return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id")
return HumanInputFormRepositoryImpl(tenant_id="tenant-id")
def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]:
@ -389,8 +388,21 @@ def _session_factory(session: _FakeSession):
return _factory
def _patch_repo_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None:
"""Patch repository's global session factory to return our fake session.
The repositories under test now use a global session factory; patch its
create_session method so unit tests don't hit a real database.
"""
monkeypatch.setattr(
"core.repositories.human_input_repository.session_factory.create_session",
_session_factory(session),
raising=True,
)
class TestHumanInputFormRepositoryImplPublicMethods:
def test_get_form_returns_entity_and_recipients(self):
def test_get_form_returns_entity_and_recipients(self, monkeypatch: pytest.MonkeyPatch):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
@ -408,7 +420,8 @@ class TestHumanInputFormRepositoryImplPublicMethods:
access_token="token-123",
)
session = _FakeSession(scalars_results=[form, [recipient]])
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
_patch_repo_session_factory(monkeypatch, session)
repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id")
entity = repo.get_form(form.workflow_run_id, form.node_id)
@ -418,13 +431,14 @@ class TestHumanInputFormRepositoryImplPublicMethods:
assert len(entity.recipients) == 1
assert entity.recipients[0].token == "token-123"
def test_get_form_returns_none_when_missing(self):
def test_get_form_returns_none_when_missing(self, monkeypatch: pytest.MonkeyPatch):
session = _FakeSession(scalars_results=[None])
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
_patch_repo_session_factory(monkeypatch, session)
repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id")
assert repo.get_form("run-1", "node-1") is None
def test_get_form_returns_unsubmitted_state(self):
def test_get_form_returns_unsubmitted_state(self, monkeypatch: pytest.MonkeyPatch):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
@ -436,7 +450,8 @@ class TestHumanInputFormRepositoryImplPublicMethods:
expiration_time=naive_utc_now(),
)
session = _FakeSession(scalars_results=[form, []])
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
_patch_repo_session_factory(monkeypatch, session)
repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id")
entity = repo.get_form(form.workflow_run_id, form.node_id)
@ -445,7 +460,7 @@ class TestHumanInputFormRepositoryImplPublicMethods:
assert entity.selected_action_id is None
assert entity.submitted_data is None
def test_get_form_returns_submission_when_completed(self):
def test_get_form_returns_submission_when_completed(self, monkeypatch: pytest.MonkeyPatch):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
@ -460,7 +475,8 @@ class TestHumanInputFormRepositoryImplPublicMethods:
submitted_at=naive_utc_now(),
)
session = _FakeSession(scalars_results=[form, []])
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
_patch_repo_session_factory(monkeypatch, session)
repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id")
entity = repo.get_form(form.workflow_run_id, form.node_id)
@ -471,7 +487,7 @@ class TestHumanInputFormRepositoryImplPublicMethods:
class TestHumanInputFormSubmissionRepository:
def test_get_by_token_returns_record(self):
def test_get_by_token_returns_record(self, monkeypatch: pytest.MonkeyPatch):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
@ -490,7 +506,8 @@ class TestHumanInputFormSubmissionRepository:
form=form,
)
session = _FakeSession(scalars_result=recipient)
repo = HumanInputFormSubmissionRepository(_session_factory(session))
_patch_repo_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
record = repo.get_by_token("token-123")
@ -499,7 +516,7 @@ class TestHumanInputFormSubmissionRepository:
assert record.recipient_type == RecipientType.STANDALONE_WEB_APP
assert record.submitted is False
def test_get_by_form_id_and_recipient_type_uses_recipient(self):
def test_get_by_form_id_and_recipient_type_uses_recipient(self, monkeypatch: pytest.MonkeyPatch):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
@ -518,7 +535,8 @@ class TestHumanInputFormSubmissionRepository:
form=form,
)
session = _FakeSession(scalars_result=recipient)
repo = HumanInputFormSubmissionRepository(_session_factory(session))
_patch_repo_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
record = repo.get_by_form_id_and_recipient_type(
form_id=form.id,
@ -553,7 +571,8 @@ class TestHumanInputFormSubmissionRepository:
forms={form.id: form},
recipients={recipient.id: recipient},
)
repo = HumanInputFormSubmissionRepository(_session_factory(session))
_patch_repo_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
record: HumanInputFormRecord = repo.mark_submitted(
form_id=form.id,

View File

@ -47,7 +47,7 @@ class _FakeSessionFactory:
class _FakeFormRepo:
def __init__(self, _session_factory, form_map: dict[str, Any] | None = None):
def __init__(self, form_map: dict[str, Any] | None = None):
self.calls: list[dict[str, Any]] = []
self._form_map = form_map or {}
@ -149,9 +149,9 @@ def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pyt
monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture))
form_map = {form.id: form for form in forms}
repo = _FakeFormRepo(None, form_map=form_map)
repo = _FakeFormRepo(form_map=form_map)
def _repo_factory(_session_factory):
def _repo_factory():
return repo
service = _FakeService(None)