From 9d401442a37dd689e9714d9c49b49815ce0fe5c5 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Wed, 6 May 2026 10:40:19 +0800 Subject: [PATCH] fix: fix Working outside of application context (#35819) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/provider_manager.py | 18 +++++++++--------- api/models/provider.py | 12 +++++++----- .../unit_tests/core/test_provider_manager.py | 16 ++++++---------- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 8969825be4..b290ae456e 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -9,9 +9,9 @@ from typing import TYPE_CHECKING, Any from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session from configs import dify_config +from core.db.session_factory import session_factory from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_entities import ( @@ -445,7 +445,7 @@ class ProviderManager: @staticmethod def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: provider_name_to_provider_records_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True) providers = session.scalars(stmt) for provider in providers: @@ -462,7 +462,7 @@ class ProviderManager: :return: """ provider_name_to_provider_model_records_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) provider_models = session.scalars(stmt) for provider_model in provider_models: @@ -478,7 +478,7 @@ class ProviderManager: :return: """ provider_name_to_preferred_provider_type_records_dict = {} - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id) preferred_provider_types = session.scalars(stmt) provider_name_to_preferred_provider_type_records_dict = { @@ -496,7 +496,7 @@ class ProviderManager: :return: """ provider_name_to_provider_model_settings_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id) provider_model_settings = session.scalars(stmt) for provider_model_setting in provider_model_settings: @@ -514,7 +514,7 @@ class ProviderManager: :return: """ provider_name_to_provider_model_credentials_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id) provider_model_credentials = session.scalars(stmt) for provider_model_credential in provider_model_credentials: @@ -544,7 +544,7 @@ class ProviderManager: return {} provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id) provider_load_balancing_configs = session.scalars(stmt) for provider_load_balancing_config in provider_load_balancing_configs: @@ -578,7 +578,7 @@ class ProviderManager: :param provider_name: provider name :return: """ - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = ( select(ProviderCredential) .where( @@ -608,7 +608,7 @@ class ProviderManager: :param model_type: model type :return: """ - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = ( select(ProviderModelCredential) .where( diff --git a/api/models/provider.py b/api/models/provider.py index 2bb67d605b..8dc3ce4ff6 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -9,11 +9,11 @@ import sqlalchemy as sa from sqlalchemy import DateTime, String, func, select, text from sqlalchemy.orm import Mapped, mapped_column +from core.db.session_factory import session_factory from graphon.model_runtime.entities.model_entities import ModelType from libs.uuid_utils import uuidv7 from .base import TypeBase -from .engine import db from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType from .types import EnumText, LongText, StringUUID @@ -82,7 +82,8 @@ class Provider(TypeBase): @cached_property def credential(self): if self.credential_id: - return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id)) + with session_factory.create_session() as session: + return session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id)) @property def credential_name(self): @@ -145,9 +146,10 @@ class ProviderModel(TypeBase): @cached_property def credential(self): if self.credential_id: - return db.session.scalar( - select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id) - ) + with session_factory.create_session() as session: + return session.scalar( + select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id) + ) @property def credential_name(self): diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index a5a542c94f..02f12fb3b4 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -570,8 +570,7 @@ def test_get_all_providers_normalizes_provider_names_with_model_provider_id() -> session.scalars.return_value = [openai_provider, gemini_provider] with ( - patch("core.provider_manager.db", SimpleNamespace(engine=object())), - patch("core.provider_manager.Session", return_value=_build_session_context(session)), + patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)), ): result = ProviderManager._get_all_providers("tenant-id") @@ -595,8 +594,7 @@ def test_provider_grouping_helpers_group_records_by_provider_name(method_name: s session.scalars.return_value = [openai_primary, openai_secondary, anthropic_record] with ( - patch("core.provider_manager.db", SimpleNamespace(engine=object())), - patch("core.provider_manager.Session", return_value=_build_session_context(session)), + patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)), ): result = getattr(ProviderManager, method_name)("tenant-id") @@ -611,8 +609,7 @@ def test_get_all_preferred_model_providers_returns_mapping_by_provider_name() -> session.scalars.return_value = [openai_preference, anthropic_preference] with ( - patch("core.provider_manager.db", SimpleNamespace(engine=object())), - patch("core.provider_manager.Session", return_value=_build_session_context(session)), + patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)), ): result = ProviderManager._get_all_preferred_model_providers("tenant-id") @@ -626,13 +623,13 @@ def test_get_all_provider_load_balancing_configs_returns_empty_when_cached_flag_ with ( patch("core.provider_manager.redis_client.get", return_value=b"False"), patch("core.provider_manager.FeatureService.get_features") as mock_get_features, - patch("core.provider_manager.Session") as mock_session_cls, + patch("core.provider_manager.session_factory.create_session") as mock_create_session, ): result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") assert result == {} mock_get_features.assert_not_called() - mock_session_cls.assert_not_called() + mock_create_session.assert_not_called() def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_configs() -> None: @@ -642,14 +639,13 @@ def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_conf session.scalars.return_value = [openai_config, anthropic_config] with ( - patch("core.provider_manager.db", SimpleNamespace(engine=object())), patch("core.provider_manager.redis_client.get", return_value=None), patch("core.provider_manager.redis_client.setex") as mock_setex, patch( "core.provider_manager.FeatureService.get_features", return_value=SimpleNamespace(model_load_balancing_enabled=True), ), - patch("core.provider_manager.Session", return_value=_build_session_context(session)), + patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)), ): result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id")