mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 12:59:18 +08:00
Merge 7260cb0c4e into 271019006e
This commit is contained in:
commit
5a0a8bd7df
@ -673,24 +673,25 @@ class ProviderManager:
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
)
|
||||
db.session.add(new_provider_record)
|
||||
db.session.commit()
|
||||
with session_factory.create_session() as session:
|
||||
session.add(new_provider_record)
|
||||
session.commit()
|
||||
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
stmt = select(Provider).where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == ModelProviderID(provider_name).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == quota.quota_type,
|
||||
)
|
||||
existed_provider_record = db.session.scalar(stmt)
|
||||
if not existed_provider_record:
|
||||
continue
|
||||
with session_factory.create_session() as session:
|
||||
existed_provider_record = session.scalar(stmt)
|
||||
if not existed_provider_record:
|
||||
continue
|
||||
|
||||
if not existed_provider_record.is_valid:
|
||||
existed_provider_record.is_valid = True
|
||||
db.session.commit()
|
||||
if not existed_provider_record.is_valid:
|
||||
existed_provider_record.is_valid = True
|
||||
session.commit()
|
||||
|
||||
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
|
||||
|
||||
|
||||
@ -6,11 +6,11 @@ from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, cast, final, override
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.app.llm.model_access import build_dify_model_access, fetch_model_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.code_executor.code_executor import (
|
||||
CodeExecutionError,
|
||||
CodeExecutor,
|
||||
@ -39,7 +39,6 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import (
|
||||
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector
|
||||
from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer
|
||||
from extensions.ext_database import db
|
||||
from graphon.entities.base_node_data import BaseNodeData
|
||||
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
@ -231,7 +230,7 @@ def fetch_memory(
|
||||
if not node_data_memory or not conversation_id:
|
||||
return None
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
|
||||
@ -652,3 +652,49 @@ def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_conf
|
||||
mock_setex.assert_called_once_with("tenant:tenant-id:model_load_balancing_enabled", 120, "True")
|
||||
assert list(result["openai"]) == [openai_config]
|
||||
assert list(result["anthropic"]) == [anthropic_config]
|
||||
|
||||
|
||||
def test_init_trial_provider_records_uses_session_factory_not_flask_context() -> None:
|
||||
"""Regression test for issue #35836.
|
||||
|
||||
_init_trial_provider_records must not use db.session (which requires a
|
||||
Flask application context) because it is called from the parallel iteration
|
||||
thread-pool where no Flask context is available.
|
||||
"""
|
||||
from core.entities.provider_entities import ProviderQuotaType
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
session = Mock()
|
||||
session.add = Mock()
|
||||
session.commit = Mock()
|
||||
|
||||
hosting_configuration = SimpleNamespace(
|
||||
provider_map={
|
||||
"openai": SimpleNamespace(
|
||||
enabled=True,
|
||||
quotas=[
|
||||
SimpleNamespace(quota_type=ProviderQuotaType.TRIAL),
|
||||
],
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.provider_manager.ext_hosting_provider.hosting_configuration", hosting_configuration),
|
||||
patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)),
|
||||
patch("core.provider_manager.ModelProviderID", side_effect=lambda x: SimpleNamespace(provider_name=x)),
|
||||
):
|
||||
# The dict must be a defaultdict(list) as returned by _get_all_providers.
|
||||
from collections import defaultdict
|
||||
|
||||
result = ProviderManager._init_trial_provider_records("tenant-id", defaultdict(list))
|
||||
|
||||
# The session factory must have been called (not db.session).
|
||||
session.add.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
created_record = session.add.call_args.args[0]
|
||||
assert isinstance(created_record, Provider)
|
||||
assert created_record.tenant_id == "tenant-id"
|
||||
assert created_record.provider_type == ProviderType.SYSTEM
|
||||
assert created_record.quota_type == ProviderQuotaType.TRIAL
|
||||
assert created_record in result.get("openai", [])
|
||||
|
||||
@ -89,14 +89,15 @@ class TestFetchMemory:
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_conversation_does_not_exist(self, monkeypatch):
|
||||
"""fetch_memory must use session_factory (not db.engine) so it works
|
||||
outside the Flask application context (e.g. iteration parallel threads).
|
||||
"""
|
||||
|
||||
class FakeSelect:
|
||||
def where(self, *_args):
|
||||
return self
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
@ -106,9 +107,9 @@ class TestFetchMemory:
|
||||
def scalar(self, _stmt):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine))
|
||||
fake_session_factory = SimpleNamespace(create_session=MagicMock(return_value=FakeSession()))
|
||||
monkeypatch.setattr(node_factory, "session_factory", fake_session_factory)
|
||||
monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect()))
|
||||
monkeypatch.setattr(node_factory, "Session", FakeSession)
|
||||
|
||||
result = node_factory.fetch_memory(
|
||||
conversation_id="conversation-id",
|
||||
@ -118,6 +119,7 @@ class TestFetchMemory:
|
||||
)
|
||||
|
||||
assert result is None
|
||||
fake_session_factory.create_session.assert_called_once()
|
||||
|
||||
def test_builds_token_buffer_memory_for_existing_conversation(self, monkeypatch):
|
||||
conversation = sentinel.conversation
|
||||
@ -128,9 +130,6 @@ class TestFetchMemory:
|
||||
return self
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
@ -141,9 +140,9 @@ class TestFetchMemory:
|
||||
return conversation
|
||||
|
||||
token_buffer_memory = MagicMock(return_value=memory)
|
||||
monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine))
|
||||
fake_session_factory = SimpleNamespace(create_session=MagicMock(return_value=FakeSession()))
|
||||
monkeypatch.setattr(node_factory, "session_factory", fake_session_factory)
|
||||
monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect()))
|
||||
monkeypatch.setattr(node_factory, "Session", FakeSession)
|
||||
monkeypatch.setattr(node_factory, "TokenBufferMemory", token_buffer_memory)
|
||||
|
||||
result = node_factory.fetch_memory(
|
||||
@ -158,6 +157,7 @@ class TestFetchMemory:
|
||||
conversation=conversation,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
fake_session_factory.create_session.assert_called_once()
|
||||
|
||||
|
||||
class TestDifyGraphInitContext:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user