refactor: core/tools, agent, callback_handler, encrypter, llm_generator, plugin, inner_api (#34205)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Renzo 2026-03-28 11:14:43 +01:00 committed by GitHub
parent 7cc81e9a43
commit 364d7ebc40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 99 additions and 118 deletions

View File

@ -8,6 +8,7 @@ Go admin-api caller.
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
@ -87,7 +88,7 @@ class EnterpriseAppDSLExport(Resource):
"""Export an app's DSL as YAML."""
include_secret = request.args.get("include_secret", "false").lower() == "true"
app_model = db.session.query(App).filter_by(id=app_id).first()
app_model = db.session.get(App, app_id)
if not app_model:
return {"message": "app not found"}, 404
@ -104,7 +105,7 @@ def _get_active_account(email: str) -> Account | None:
Workspace membership is already validated by the Go admin-api caller.
"""
account = db.session.query(Account).filter_by(email=email).first()
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if account is None or account.status != AccountStatus.ACTIVE:
return None
return account

View File

@ -18,7 +18,7 @@ from graphon.model_runtime.entities import (
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from graphon.model_runtime.entities.model_entities import ModelFeature
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import select
from sqlalchemy import func, select
from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@ -104,11 +104,14 @@ class BaseAgentRunner(AppRunner):
)
# get how many agent thoughts have been created
self.agent_thought_count = (
db.session.query(MessageAgentThought)
.where(
MessageAgentThought.message_id == self.message.id,
db.session.scalar(
select(func.count())
.select_from(MessageAgentThought)
.where(
MessageAgentThought.message_id == self.message.id,
)
)
.count()
or 0
)
db.session.close()

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy import select, update
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
@ -70,23 +70,21 @@ class DatasetIndexToolCallbackHandler:
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
_ = (
db.session.query(DocumentSegment)
db.session.execute(
update(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
.values(hit_count=DocumentSegment.hit_count + 1)
)
else:
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]]
if "dataset_id" in document.metadata:
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
db.session.execute(
update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1)
)
db.session.commit()

View File

@ -19,7 +19,7 @@ def encrypt_token(tenant_id: str, token: str):
from extensions.ext_database import db
from models.account import Tenant
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
if not (tenant := db.session.get(Tenant, tenant_id)):
raise ValueError(f"Tenant with id {tenant_id} not found")
assert tenant.encrypt_public_key is not None
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)

View File

@ -10,6 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult
from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from sqlalchemy import select
from core.app.app_config.entities import ModelConfig
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
@ -410,8 +411,8 @@ class LLMGenerator:
model_config: ModelConfig,
ideal_output: str | None,
):
last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
last_run: Message | None = db.session.scalar(
select(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).limit(1)
)
if not last_run:
return LLMGenerator.__instruction_modify_common(

View File

@ -227,7 +227,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get app
"""
try:
app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first()
app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
except Exception:
raise ValueError("app not found")

View File

@ -1,4 +1,4 @@
from sqlalchemy import select
from sqlalchemy import delete, select
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@ -31,7 +31,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type")
# delete old labels
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id))
# insert new labels
for label in labels:

View File

@ -255,11 +255,11 @@ class ToolManager:
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
builtin_provider = db.session.scalar(
select(BuiltinToolProvider)
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
.limit(1)
)
if builtin_provider is None:
@ -818,13 +818,13 @@ class ToolManager:
:return: the provider controller, the credentials
"""
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
provider: ApiToolProvider | None = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if provider is None:
@ -872,13 +872,13 @@ class ToolManager:
get api provider
"""
provider_name = provider
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
provider_obj: ApiToolProvider | None = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
)
.first()
.limit(1)
)
if provider_obj is None:
@ -964,10 +964,10 @@ class ToolManager:
@classmethod
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
workflow_provider: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
.limit(1)
)
if workflow_provider is None:
@ -981,10 +981,10 @@ class ToolManager:
@classmethod
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
api_provider: ApiToolProvider | None = db.session.scalar(
select(ApiToolProvider)
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
.limit(1)
)
if api_provider is None:

View File

@ -110,7 +110,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
context_list: list[RetrievalSourceMetadata] = []
resource_number = 1
for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
dataset = db.session.get(Dataset, segment.dataset_id)
document_stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,

View File

@ -205,7 +205,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if self.return_resource:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
dataset = db.session.get(Dataset, segment.dataset_id)
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,

View File

@ -64,18 +64,18 @@ class TestGetActiveAccount:
def test_returns_active_account(self, mock_db):
mock_account = MagicMock()
mock_account.status = "active"
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
mock_db.session.scalar.return_value = mock_account
result = _get_active_account("user@example.com")
assert result is mock_account
mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com")
mock_db.session.scalar.assert_called_once()
@patch("controllers.inner_api.app.dsl.db")
def test_returns_none_for_inactive_account(self, mock_db):
mock_account = MagicMock()
mock_account.status = "banned"
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
mock_db.session.scalar.return_value = mock_account
result = _get_active_account("banned@example.com")
@ -83,7 +83,7 @@ class TestGetActiveAccount:
@patch("controllers.inner_api.app.dsl.db")
def test_returns_none_for_nonexistent_email(self, mock_db):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
result = _get_active_account("missing@example.com")
@ -205,7 +205,7 @@ class TestEnterpriseAppDSLExport:
@patch("controllers.inner_api.app.dsl.db")
def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
mock_app = MagicMock()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
mock_db.session.get.return_value = mock_app
mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n"
unwrapped = inspect.unwrap(api_instance.get)
@ -221,7 +221,7 @@ class TestEnterpriseAppDSLExport:
@patch("controllers.inner_api.app.dsl.db")
def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
mock_app = MagicMock()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
mock_db.session.get.return_value = mock_app
mock_dsl_cls.export_dsl.return_value = "yaml-data"
unwrapped = inspect.unwrap(api_instance.get)
@ -234,7 +234,7 @@ class TestEnterpriseAppDSLExport:
@patch("controllers.inner_api.app.dsl.db")
def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
mock_db.session.get.return_value = None
unwrapped = inspect.unwrap(api_instance.get)
with app.test_request_context("?include_secret=false"):

View File

@ -621,7 +621,7 @@ class TestConvertDatasetRetrieverTool:
class TestBaseAgentRunnerInit:
def test_init_sets_stream_tool_call_and_files(self, mocker):
session = mocker.MagicMock()
session.query.return_value.where.return_value.count.return_value = 2
session.scalar.return_value = 2
mocker.patch.object(module.db, "session", session)
mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])

View File

@ -114,13 +114,9 @@ class TestOnToolEnd:
document = mocker.Mock()
document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
mock_query = mocker.Mock()
mock_db.session.query.return_value = mock_query
mock_query.where.return_value = mock_query
handler.on_tool_end([document])
mock_query.update.assert_called_once()
mock_db.session.execute.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_on_tool_end_non_parent_child_index(self, handler, mocker):
@ -138,13 +134,9 @@ class TestOnToolEnd:
"dataset_id": "dataset-1",
}
mock_query = mocker.Mock()
mock_db.session.query.return_value = mock_query
mock_query.where.return_value = mock_query
handler.on_tool_end([document])
mock_query.update.assert_called_once()
mock_db.session.execute.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_on_tool_end_empty_documents(self, handler):

View File

@ -38,13 +38,13 @@ class TestObfuscatedToken:
class TestEncryptToken:
@patch("models.engine.db.session.query")
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_successful_encryption(self, mock_encrypt, mock_query):
"""Test successful token encryption"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_query.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_data"
result = encrypt_token("tenant-123", "test_token")
@ -52,10 +52,10 @@ class TestEncryptToken:
assert result == base64.b64encode(b"encrypted_data").decode()
mock_encrypt.assert_called_with("test_token", "mock_public_key")
@patch("models.engine.db.session.query")
@patch("extensions.ext_database.db.session.get")
def test_tenant_not_found(self, mock_query):
"""Test error when tenant doesn't exist"""
mock_query.return_value.where.return_value.first.return_value = None
mock_query.return_value = None
with pytest.raises(ValueError) as exc_info:
encrypt_token("invalid-tenant", "test_token")
@ -119,7 +119,7 @@ class TestGetDecryptDecoding:
class TestEncryptDecryptIntegration:
@patch("models.engine.db.session.query")
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
@patch("libs.rsa.decrypt")
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration:
# Setup mock tenant
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_query.return_value = mock_tenant
# Setup mock encryption/decryption
original_token = "test_token_123"
@ -146,14 +146,14 @@ class TestEncryptDecryptIntegration:
class TestSecurity:
"""Critical security tests for encryption system"""
@patch("models.engine.db.session.query")
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
"""Ensure tokens encrypted for one tenant cannot be used by another"""
# Setup mock tenant
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "tenant1_public_key"
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_query.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_for_tenant1"
# Encrypt token for tenant1
@ -181,12 +181,12 @@ class TestSecurity:
with pytest.raises(Exception, match="Decryption error"):
decrypt_token("tenant-123", tampered)
@patch("models.engine.db.session.query")
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_encryption_randomness(self, mock_encrypt, mock_query):
"""Ensure same plaintext produces different ciphertext"""
mock_tenant = MagicMock(encrypt_public_key="key")
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_query.return_value = mock_tenant
# Different outputs for same input
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
@ -205,13 +205,13 @@ class TestEdgeCases:
# Test empty string (which is a valid str type)
assert obfuscated_token("") == ""
@patch("models.engine.db.session.query")
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
"""Test encryption of empty token"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_query.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_empty"
result = encrypt_token("tenant-123", "")
@ -219,13 +219,13 @@ class TestEdgeCases:
assert result == base64.b64encode(b"encrypted_empty").decode()
mock_encrypt.assert_called_with("", "mock_public_key")
@patch("models.engine.db.session.query")
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
"""Test tokens containing special/unicode characters"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_query.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_special"
# Test various special characters
@ -242,13 +242,13 @@ class TestEdgeCases:
assert result == base64.b64encode(b"encrypted_special").decode()
mock_encrypt.assert_called_with(token, "mock_public_key")
@patch("models.engine.db.session.query")
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
"""Test behavior when token exceeds RSA encryption limits"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_query.return_value = mock_tenant
# RSA 2048-bit can only encrypt ~245 bytes
# The actual limit depends on padding scheme

View File

@ -314,8 +314,8 @@ class TestLLMGenerator:
assert "An unexpected error occurred" in result["error"]
def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
# Mock __instruction_modify_common call via invoke_llm
mock_response = MagicMock()
@ -328,12 +328,12 @@ class TestLLMGenerator:
assert result == {"modified": "prompt"}
def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
last_run = MagicMock()
last_run.query = "q"
last_run.answer = "a"
last_run.error = "e"
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run
mock_scalar.return_value = last_run
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
@ -483,8 +483,8 @@ class TestLLMGenerator:
def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity):
# Testing placeholders replacement via instruction_modify_legacy for convenience
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"ok": true}'
@ -504,8 +504,8 @@ class TestLLMGenerator:
assert "current_val" in user_msg_dict["instruction"]
def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "No braces here"
mock_model_instance.invoke_llm.return_value = mock_response
@ -516,8 +516,8 @@ class TestLLMGenerator:
assert "Could not find a valid JSON object" in result["error"]
def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "[1, 2, 3]"
mock_model_instance.invoke_llm.return_value = mock_response
@ -556,8 +556,8 @@ class TestLLMGenerator:
)
def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed")
result = LLMGenerator.instruction_modify_legacy(
@ -566,8 +566,8 @@ class TestLLMGenerator:
assert "Failed to generate code" in result["error"]
def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
result = LLMGenerator.instruction_modify_legacy(
@ -576,8 +576,8 @@ class TestLLMGenerator:
assert "An unexpected error occurred" in result["error"]
def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "No JSON here"

View File

@ -332,27 +332,21 @@ class TestPluginAppBackwardsInvocation:
PluginAppBackwardsInvocation._get_user("uid")
def test_get_app_returns_app(self, mocker):
query_chain = MagicMock()
query_chain.where.return_value = query_chain
app_obj = MagicMock(id="app")
query_chain.first.return_value = app_obj
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj)))
mocker.patch("core.plugin.backwards_invocation.app.db", db)
assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj
def test_get_app_raises_when_missing(self, mocker):
query_chain = MagicMock()
query_chain.where.return_value = query_chain
query_chain.first.return_value = None
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None)))
mocker.patch("core.plugin.backwards_invocation.app.db", db)
with pytest.raises(ValueError, match="app not found"):
PluginAppBackwardsInvocation._get_app("app", "tenant")
def test_get_app_raises_when_query_fails(self, mocker):
db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down"))))
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down"))))
mocker.patch("core.plugin.backwards_invocation.app.db", db)
with pytest.raises(ValueError, match="app not found"):

View File

@ -38,11 +38,9 @@ def test_tool_label_manager_filter_tool_labels():
def test_tool_label_manager_update_tool_labels_db():
controller = _api_controller("api-1")
with patch("core.tools.tool_label_manager.db") as mock_db:
delete_query = mock_db.session.query.return_value.where.return_value
delete_query.delete.return_value = None
ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"])
delete_query.delete.assert_called_once()
mock_db.session.execute.assert_called_once()
# only one valid unique label should be inserted.
assert mock_db.session.add.call_count == 1
mock_db.session.commit.assert_called_once()

View File

@ -220,9 +220,7 @@ def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks():
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
with patch("core.helper.credential_utils.check_credential_policy_compliance"):
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
builtin_provider
)
mock_db.session.scalar.return_value = builtin_provider
encrypter = Mock()
encrypter.decrypt.return_value = {"api_key": "secret"}
cache = Mock()
@ -274,7 +272,7 @@ def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials(
)
refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456)
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider
mock_db.session.scalar.return_value = builtin_provider
encrypter = Mock()
encrypter.decrypt.return_value = {"token": "old"}
encrypter.encrypt.return_value = {"token": "encrypted"}
@ -698,12 +696,10 @@ def test_get_api_provider_controller_returns_controller_and_credentials():
privacy_policy="privacy",
custom_disclaimer="disclaimer",
)
db_query = Mock()
db_query.where.return_value.first.return_value = provider
controller = Mock()
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.query.return_value = db_query
mock_db.session.scalar.return_value = provider
with patch(
"core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller
) as mock_from_db:
@ -730,12 +726,10 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels():
privacy_policy="privacy",
custom_disclaimer="disclaimer",
)
db_query = Mock()
db_query.where.return_value.first.return_value = provider
controller = Mock()
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.query.return_value = db_query
mock_db.session.scalar.return_value = provider
with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller):
encrypter = Mock()
encrypter.decrypt.return_value = {"api_key_value": "secret"}
@ -750,7 +744,7 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels():
def test_get_api_provider_controller_not_found_raises():
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"):
ToolManager.get_api_provider_controller("tenant-1", "missing")
@ -809,14 +803,14 @@ def test_generate_tool_icon_urls_for_workflow_and_api():
workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}')
api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}')
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider]
mock_db.session.scalar.side_effect = [workflow_provider, api_provider]
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"}
assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"}
def test_generate_tool_icon_urls_missing_workflow_and_api_use_default():
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525"

View File

@ -263,7 +263,7 @@ def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources():
)
db_session = Mock()
db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high]
db_session.query.return_value.filter_by.return_value.first.return_value = dataset
db_session.get.return_value = dataset
tool = SingleDatasetRetrieverTool(
tenant_id="tenant-1",
@ -444,7 +444,7 @@ def test_multi_dataset_retriever_run_orders_segments_and_returns_resources():
)
db_session = Mock()
db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1]
db_session.query.return_value.filter_by.return_value.first.side_effect = [
db_session.get.side_effect = [
SimpleNamespace(id="dataset-2", name="Dataset Two"),
SimpleNamespace(id="dataset-1", name="Dataset One"),
]