This commit is contained in:
Brandon 2026-05-09 10:28:00 +08:00 committed by GitHub
commit 5a0a8bd7df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 68 additions and 22 deletions

View File

@ -673,24 +673,25 @@ class ProviderManager:
quota_used=0,
is_valid=True,
)
db.session.add(new_provider_record)
db.session.commit()
with session_factory.create_session() as session:
session.add(new_provider_record)
session.commit()
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError:
db.session.rollback()
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == quota.quota_type,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
continue
with session_factory.create_session() as session:
existed_provider_record = session.scalar(stmt)
if not existed_provider_record:
continue
if not existed_provider_record.is_valid:
existed_provider_record.is_valid = True
db.session.commit()
if not existed_provider_record.is_valid:
existed_provider_record.is_valid = True
session.commit()
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)

View File

@ -6,11 +6,11 @@ from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast, final, override
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm.model_access import build_dify_model_access, fetch_model_config
from core.db.session_factory import session_factory
from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
@ -39,7 +39,6 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import (
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector
from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer
from extensions.ext_database import db
from graphon.entities.base_node_data import BaseNodeData
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import BuiltinNodeTypes, NodeType
@ -231,7 +230,7 @@ def fetch_memory(
if not node_data_memory or not conversation_id:
return None
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:

View File

@ -652,3 +652,49 @@ def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_conf
mock_setex.assert_called_once_with("tenant:tenant-id:model_load_balancing_enabled", 120, "True")
assert list(result["openai"]) == [openai_config]
assert list(result["anthropic"]) == [anthropic_config]
def test_init_trial_provider_records_uses_session_factory_not_flask_context() -> None:
"""Regression test for issue #35836.
_init_trial_provider_records must not use db.session (which requires a
Flask application context) because it is called from the parallel iteration
thread-pool where no Flask context is available.
"""
from core.entities.provider_entities import ProviderQuotaType
from models.provider import Provider, ProviderType
session = Mock()
session.add = Mock()
session.commit = Mock()
hosting_configuration = SimpleNamespace(
provider_map={
"openai": SimpleNamespace(
enabled=True,
quotas=[
SimpleNamespace(quota_type=ProviderQuotaType.TRIAL),
],
)
}
)
with (
patch("core.provider_manager.ext_hosting_provider.hosting_configuration", hosting_configuration),
patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)),
patch("core.provider_manager.ModelProviderID", side_effect=lambda x: SimpleNamespace(provider_name=x)),
):
# The dict must be a defaultdict(list) as returned by _get_all_providers.
from collections import defaultdict
result = ProviderManager._init_trial_provider_records("tenant-id", defaultdict(list))
# The session factory must have been called (not db.session).
session.add.assert_called_once()
session.commit.assert_called_once()
created_record = session.add.call_args.args[0]
assert isinstance(created_record, Provider)
assert created_record.tenant_id == "tenant-id"
assert created_record.provider_type == ProviderType.SYSTEM
assert created_record.quota_type == ProviderQuotaType.TRIAL
assert created_record in result.get("openai", [])

View File

@ -89,14 +89,15 @@ class TestFetchMemory:
assert result is None
def test_returns_none_when_conversation_does_not_exist(self, monkeypatch):
"""fetch_memory must use session_factory (not db.engine) so it works
outside the Flask application context (e.g. iteration parallel threads).
"""
class FakeSelect:
def where(self, *_args):
return self
class FakeSession:
def __init__(self, *_args, **_kwargs):
pass
def __enter__(self):
return self
@ -106,9 +107,9 @@ class TestFetchMemory:
def scalar(self, _stmt):
return None
monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine))
fake_session_factory = SimpleNamespace(create_session=MagicMock(return_value=FakeSession()))
monkeypatch.setattr(node_factory, "session_factory", fake_session_factory)
monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect()))
monkeypatch.setattr(node_factory, "Session", FakeSession)
result = node_factory.fetch_memory(
conversation_id="conversation-id",
@ -118,6 +119,7 @@ class TestFetchMemory:
)
assert result is None
fake_session_factory.create_session.assert_called_once()
def test_builds_token_buffer_memory_for_existing_conversation(self, monkeypatch):
conversation = sentinel.conversation
@ -128,9 +130,6 @@ class TestFetchMemory:
return self
class FakeSession:
def __init__(self, *_args, **_kwargs):
pass
def __enter__(self):
return self
@ -141,9 +140,9 @@ class TestFetchMemory:
return conversation
token_buffer_memory = MagicMock(return_value=memory)
monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine))
fake_session_factory = SimpleNamespace(create_session=MagicMock(return_value=FakeSession()))
monkeypatch.setattr(node_factory, "session_factory", fake_session_factory)
monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect()))
monkeypatch.setattr(node_factory, "Session", FakeSession)
monkeypatch.setattr(node_factory, "TokenBufferMemory", token_buffer_memory)
result = node_factory.fetch_memory(
@ -158,6 +157,7 @@ class TestFetchMemory:
conversation=conversation,
model_instance=sentinel.model_instance,
)
fake_session_factory.create_session.assert_called_once()
class TestDifyGraphInitContext: