diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index b290ae456e..76cf0e8ac0 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -673,24 +673,25 @@ class ProviderManager: quota_used=0, is_valid=True, ) - db.session.add(new_provider_record) - db.session.commit() + with session_factory.create_session() as session: + session.add(new_provider_record) + session.commit() provider_name_to_provider_records_dict[provider_name].append(new_provider_record) except IntegrityError: - db.session.rollback() stmt = select(Provider).where( Provider.tenant_id == tenant_id, Provider.provider_name == ModelProviderID(provider_name).provider_name, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == quota.quota_type, ) - existed_provider_record = db.session.scalar(stmt) - if not existed_provider_record: - continue + with session_factory.create_session() as session: + existed_provider_record = session.scalar(stmt) + if not existed_provider_record: + continue - if not existed_provider_record.is_valid: - existed_provider_record.is_valid = True - db.session.commit() + if not existed_provider_record.is_valid: + existed_provider_record.is_valid = True + session.commit() provider_name_to_provider_records_dict[provider_name].append(existed_provider_record) diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 895953a3c1..c2c1fa15fd 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -6,11 +6,11 @@ from functools import lru_cache from typing import TYPE_CHECKING, Any, cast, final, override from sqlalchemy import select -from sqlalchemy.orm import Session from configs import dify_config from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.llm.model_access import build_dify_model_access, fetch_model_config +from core.db.session_factory import session_factory from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, @@ -39,7 +39,7 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import ( from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer -from extensions.ext_database import db + from graphon.entities.base_node_data import BaseNodeData from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes, NodeType @@ -231,7 +231,7 @@ def fetch_memory( if not node_data_memory or not conversation_id: return None - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) conversation = session.scalar(stmt) if not conversation: diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 02f12fb3b4..f52ff78ae5 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -652,3 +652,49 @@ def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_conf mock_setex.assert_called_once_with("tenant:tenant-id:model_load_balancing_enabled", 120, "True") assert list(result["openai"]) == [openai_config] assert list(result["anthropic"]) == [anthropic_config] + + +def test_init_trial_provider_records_uses_session_factory_not_flask_context() -> None: + """Regression test for issue #35836. + + _init_trial_provider_records must not use db.session (which requires a + Flask application context) because it is called from the parallel iteration + thread-pool where no Flask context is available. + """ + from core.entities.provider_entities import ProviderQuotaType + from models.provider import Provider, ProviderType + + session = Mock() + session.add = Mock() + session.commit = Mock() + + hosting_configuration = SimpleNamespace( + provider_map={ + "openai": SimpleNamespace( + enabled=True, + quotas=[ + SimpleNamespace(quota_type=ProviderQuotaType.TRIAL), + ], + ) + } + ) + + with ( + patch("core.provider_manager.ext_hosting_provider.hosting_configuration", hosting_configuration), + patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)), + patch("core.provider_manager.ModelProviderID", side_effect=lambda x: SimpleNamespace(provider_name=x)), + ): + # The dict must be a defaultdict(list) as returned by _get_all_providers. + from collections import defaultdict + + result = ProviderManager._init_trial_provider_records("tenant-id", defaultdict(list)) + + # The session factory must have been called (not db.session). + session.add.assert_called_once() + session.commit.assert_called_once() + created_record = session.add.call_args.args[0] + assert isinstance(created_record, Provider) + assert created_record.tenant_id == "tenant-id" + assert created_record.provider_type == ProviderType.SYSTEM + assert created_record.quota_type == ProviderQuotaType.TRIAL + assert created_record in result.get("openai", []) diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index 1821f72e0c..c9314cfb12 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -89,14 +89,15 @@ class TestFetchMemory: assert result is None def test_returns_none_when_conversation_does_not_exist(self, monkeypatch): + """fetch_memory must use session_factory (not db.engine) so it works + outside the Flask application context (e.g. iteration parallel threads). + """ + class FakeSelect: def where(self, *_args): return self class FakeSession: - def __init__(self, *_args, **_kwargs): - pass - def __enter__(self): return self @@ -106,9 +107,9 @@ class TestFetchMemory: def scalar(self, _stmt): return None - monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine)) + fake_session_factory = SimpleNamespace(create_session=MagicMock(return_value=FakeSession())) + monkeypatch.setattr(node_factory, "session_factory", fake_session_factory) monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect())) - monkeypatch.setattr(node_factory, "Session", FakeSession) result = node_factory.fetch_memory( conversation_id="conversation-id", @@ -118,6 +119,7 @@ class TestFetchMemory: ) assert result is None + fake_session_factory.create_session.assert_called_once() def test_builds_token_buffer_memory_for_existing_conversation(self, monkeypatch): conversation = sentinel.conversation @@ -128,9 +130,6 @@ class TestFetchMemory: return self class FakeSession: - def __init__(self, *_args, **_kwargs): - pass - def __enter__(self): return self @@ -141,9 +140,9 @@ class TestFetchMemory: return conversation token_buffer_memory = MagicMock(return_value=memory) - monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine)) + fake_session_factory = SimpleNamespace(create_session=MagicMock(return_value=FakeSession())) + monkeypatch.setattr(node_factory, "session_factory", fake_session_factory) monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect())) - monkeypatch.setattr(node_factory, "Session", FakeSession) monkeypatch.setattr(node_factory, "TokenBufferMemory", token_buffer_memory) result = node_factory.fetch_memory( @@ -158,6 +157,7 @@ class TestFetchMemory: conversation=conversation, model_instance=sentinel.model_instance, ) + fake_session_factory.create_session.assert_called_once() class TestDifyGraphInitContext: