fix: fix Working outside of application context (#35819)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei 2026-05-06 10:40:19 +08:00 committed by fatelei
parent 29388b2a89
commit 9d401442a3
No known key found for this signature in database
GPG Key ID: 2F91DA05646F4EED
3 changed files with 22 additions and 24 deletions

View File

@ -9,9 +9,9 @@ from typing import TYPE_CHECKING, Any
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import (
@ -445,7 +445,7 @@ class ProviderManager:
@staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
provider_name_to_provider_records_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
providers = session.scalars(stmt)
for provider in providers:
@ -462,7 +462,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_records_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
provider_models = session.scalars(stmt)
for provider_model in provider_models:
@ -478,7 +478,7 @@ class ProviderManager:
:return:
"""
provider_name_to_preferred_provider_type_records_dict = {}
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
preferred_provider_types = session.scalars(stmt)
provider_name_to_preferred_provider_type_records_dict = {
@ -496,7 +496,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_settings_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
provider_model_settings = session.scalars(stmt)
for provider_model_setting in provider_model_settings:
@ -514,7 +514,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_credentials_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
provider_model_credentials = session.scalars(stmt)
for provider_model_credential in provider_model_credentials:
@ -544,7 +544,7 @@ class ProviderManager:
return {}
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
provider_load_balancing_configs = session.scalars(stmt)
for provider_load_balancing_config in provider_load_balancing_configs:
@ -578,7 +578,7 @@ class ProviderManager:
:param provider_name: provider name
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = (
select(ProviderCredential)
.where(
@ -608,7 +608,7 @@ class ProviderManager:
:param model_type: model type
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = (
select(ProviderModelCredential)
.where(

View File

@ -9,11 +9,11 @@ import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, select, text
from sqlalchemy.orm import Mapped, mapped_column
from core.db.session_factory import session_factory
from graphon.model_runtime.entities.model_entities import ModelType
from libs.uuid_utils import uuidv7
from .base import TypeBase
from .engine import db
from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType
from .types import EnumText, LongText, StringUUID
@ -82,7 +82,8 @@ class Provider(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
with session_factory.create_session() as session:
return session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
@property
def credential_name(self):
@ -145,9 +146,10 @@ class ProviderModel(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
)
with session_factory.create_session() as session:
return session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
)
@property
def credential_name(self):

View File

@ -570,8 +570,7 @@ def test_get_all_providers_normalizes_provider_names_with_model_provider_id() ->
session.scalars.return_value = [openai_provider, gemini_provider]
with (
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)),
):
result = ProviderManager._get_all_providers("tenant-id")
@ -595,8 +594,7 @@ def test_provider_grouping_helpers_group_records_by_provider_name(method_name: s
session.scalars.return_value = [openai_primary, openai_secondary, anthropic_record]
with (
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)),
):
result = getattr(ProviderManager, method_name)("tenant-id")
@ -611,8 +609,7 @@ def test_get_all_preferred_model_providers_returns_mapping_by_provider_name() ->
session.scalars.return_value = [openai_preference, anthropic_preference]
with (
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)),
):
result = ProviderManager._get_all_preferred_model_providers("tenant-id")
@ -626,13 +623,13 @@ def test_get_all_provider_load_balancing_configs_returns_empty_when_cached_flag_
with (
patch("core.provider_manager.redis_client.get", return_value=b"False"),
patch("core.provider_manager.FeatureService.get_features") as mock_get_features,
patch("core.provider_manager.Session") as mock_session_cls,
patch("core.provider_manager.session_factory.create_session") as mock_create_session,
):
result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id")
assert result == {}
mock_get_features.assert_not_called()
mock_session_cls.assert_not_called()
mock_create_session.assert_not_called()
def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_configs() -> None:
@ -642,14 +639,13 @@ def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_conf
session.scalars.return_value = [openai_config, anthropic_config]
with (
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
patch("core.provider_manager.redis_client.get", return_value=None),
patch("core.provider_manager.redis_client.setex") as mock_setex,
patch(
"core.provider_manager.FeatureService.get_features",
return_value=SimpleNamespace(model_load_balancing_enabled=True),
),
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)),
):
result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id")