mirror of
https://github.com/langgenius/dify.git
synced 2026-05-08 20:08:36 +08:00
fix: fix Working outside of application context (#35819)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
29388b2a89
commit
9d401442a3
@ -9,9 +9,9 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||
from core.entities.provider_entities import (
|
||||
@ -445,7 +445,7 @@ class ProviderManager:
|
||||
@staticmethod
|
||||
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
|
||||
provider_name_to_provider_records_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
|
||||
providers = session.scalars(stmt)
|
||||
for provider in providers:
|
||||
@ -462,7 +462,7 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_model_records_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
|
||||
provider_models = session.scalars(stmt)
|
||||
for provider_model in provider_models:
|
||||
@ -478,7 +478,7 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_preferred_provider_type_records_dict = {}
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
|
||||
preferred_provider_types = session.scalars(stmt)
|
||||
provider_name_to_preferred_provider_type_records_dict = {
|
||||
@ -496,7 +496,7 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_model_settings_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
|
||||
provider_model_settings = session.scalars(stmt)
|
||||
for provider_model_setting in provider_model_settings:
|
||||
@ -514,7 +514,7 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_model_credentials_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
|
||||
provider_model_credentials = session.scalars(stmt)
|
||||
for provider_model_credential in provider_model_credentials:
|
||||
@ -544,7 +544,7 @@ class ProviderManager:
|
||||
return {}
|
||||
|
||||
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
|
||||
provider_load_balancing_configs = session.scalars(stmt)
|
||||
for provider_load_balancing_config in provider_load_balancing_configs:
|
||||
@ -578,7 +578,7 @@ class ProviderManager:
|
||||
:param provider_name: provider name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = (
|
||||
select(ProviderCredential)
|
||||
.where(
|
||||
@ -608,7 +608,7 @@ class ProviderManager:
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = (
|
||||
select(ProviderModelCredential)
|
||||
.where(
|
||||
|
||||
@ -9,11 +9,11 @@ import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func, select, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
@ -82,7 +82,8 @@ class Provider(TypeBase):
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
|
||||
with session_factory.create_session() as session:
|
||||
return session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
@ -145,9 +146,10 @@ class ProviderModel(TypeBase):
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return db.session.scalar(
|
||||
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
|
||||
)
|
||||
with session_factory.create_session() as session:
|
||||
return session.scalar(
|
||||
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
|
||||
)
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
|
||||
@ -570,8 +570,7 @@ def test_get_all_providers_normalizes_provider_names_with_model_provider_id() ->
|
||||
session.scalars.return_value = [openai_provider, gemini_provider]
|
||||
|
||||
with (
|
||||
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
|
||||
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
|
||||
patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)),
|
||||
):
|
||||
result = ProviderManager._get_all_providers("tenant-id")
|
||||
|
||||
@ -595,8 +594,7 @@ def test_provider_grouping_helpers_group_records_by_provider_name(method_name: s
|
||||
session.scalars.return_value = [openai_primary, openai_secondary, anthropic_record]
|
||||
|
||||
with (
|
||||
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
|
||||
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
|
||||
patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)),
|
||||
):
|
||||
result = getattr(ProviderManager, method_name)("tenant-id")
|
||||
|
||||
@ -611,8 +609,7 @@ def test_get_all_preferred_model_providers_returns_mapping_by_provider_name() ->
|
||||
session.scalars.return_value = [openai_preference, anthropic_preference]
|
||||
|
||||
with (
|
||||
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
|
||||
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
|
||||
patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)),
|
||||
):
|
||||
result = ProviderManager._get_all_preferred_model_providers("tenant-id")
|
||||
|
||||
@ -626,13 +623,13 @@ def test_get_all_provider_load_balancing_configs_returns_empty_when_cached_flag_
|
||||
with (
|
||||
patch("core.provider_manager.redis_client.get", return_value=b"False"),
|
||||
patch("core.provider_manager.FeatureService.get_features") as mock_get_features,
|
||||
patch("core.provider_manager.Session") as mock_session_cls,
|
||||
patch("core.provider_manager.session_factory.create_session") as mock_create_session,
|
||||
):
|
||||
result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id")
|
||||
|
||||
assert result == {}
|
||||
mock_get_features.assert_not_called()
|
||||
mock_session_cls.assert_not_called()
|
||||
mock_create_session.assert_not_called()
|
||||
|
||||
|
||||
def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_configs() -> None:
|
||||
@ -642,14 +639,13 @@ def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_conf
|
||||
session.scalars.return_value = [openai_config, anthropic_config]
|
||||
|
||||
with (
|
||||
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
|
||||
patch("core.provider_manager.redis_client.get", return_value=None),
|
||||
patch("core.provider_manager.redis_client.setex") as mock_setex,
|
||||
patch(
|
||||
"core.provider_manager.FeatureService.get_features",
|
||||
return_value=SimpleNamespace(model_load_balancing_enabled=True),
|
||||
),
|
||||
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
|
||||
patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)),
|
||||
):
|
||||
result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user