diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py index 56730cf37a..3b673d6e1d 100644 --- a/api/controllers/inner_api/app/dsl.py +++ b/api/controllers/inner_api/app/dsl.py @@ -8,6 +8,7 @@ Go admin-api caller. from flask import request from flask_restx import Resource from pydantic import BaseModel, Field +from sqlalchemy import select from sqlalchemy.orm import Session from controllers.common.schema import register_schema_model @@ -87,7 +88,7 @@ class EnterpriseAppDSLExport(Resource): """Export an app's DSL as YAML.""" include_secret = request.args.get("include_secret", "false").lower() == "true" - app_model = db.session.query(App).filter_by(id=app_id).first() + app_model = db.session.get(App, app_id) if not app_model: return {"message": "app not found"}, 404 @@ -104,7 +105,7 @@ def _get_active_account(email: str) -> Account | None: Workspace membership is already validated by the Go admin-api caller. """ - account = db.session.query(Account).filter_by(email=email).first() + account = db.session.scalar(select(Account).where(Account.email == email).limit(1)) if account is None or account.status != AccountStatus.ACTIVE: return None return account diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index ff8f40407f..06c746990d 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -18,7 +18,7 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.model_runtime.entities.model_entities import ModelFeature from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from sqlalchemy import select +from sqlalchemy import func, select from core.agent.entities import AgentEntity, AgentToolEntity from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -104,11 +104,14 @@ class BaseAgentRunner(AppRunner): ) # get how many agent thoughts have been created self.agent_thought_count = ( - db.session.query(MessageAgentThought) - .where( - MessageAgentThought.message_id == self.message.id, + db.session.scalar( + select(func.count()) + .select_from(MessageAgentThought) + .where( + MessageAgentThought.message_id == self.message.id, + ) ) - .count() + or 0 ) db.session.close() diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 8de5cb1690..6a07119244 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,7 +1,7 @@ import logging from collections.abc import Sequence -from sqlalchemy import select +from sqlalchemy import select, update from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom @@ -70,23 +70,21 @@ class DatasetIndexToolCallbackHandler: ) child_chunk = db.session.scalar(child_chunk_stmt) if child_chunk: - _ = ( - db.session.query(DocumentSegment) + db.session.execute( + update(DocumentSegment) .where(DocumentSegment.id == child_chunk.segment_id) - .update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False - ) + .values(hit_count=DocumentSegment.hit_count + 1) ) else: - query = db.session.query(DocumentSegment).where( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]] if "dataset_id" in document.metadata: - query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + db.session.execute( + update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1) + ) db.session.commit() diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 17345dc203..20125ec6b3 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -19,7 +19,7 @@ def encrypt_token(tenant_id: str, token: str): from extensions.ext_database import db from models.account import Tenant - if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()): + if not (tenant := db.session.get(Tenant, tenant_id)): raise ValueError(f"Tenant with id {tenant_id} not found") assert tenant.encrypt_public_key is not None encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 3d94f1a596..d39630ad95 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -10,6 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from sqlalchemy import select from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload @@ -410,8 +411,8 @@ class LLMGenerator: model_config: ModelConfig, ideal_output: str | None, ): - last_run: Message | None = ( - db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() + last_run: Message | None = db.session.scalar( + select(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).limit(1) ) if not last_run: return LLMGenerator.__instruction_modify_common( diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 60d08b26c9..be11d2223c 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -227,7 +227,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): get app """ try: - app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first() + app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1)) except Exception: raise ValueError("app not found") diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 250dd91bfd..58190d1089 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -1,4 +1,4 @@ -from sqlalchemy import select +from sqlalchemy import delete, select from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -31,7 +31,7 @@ class ToolLabelManager: raise ValueError("Unsupported tool type") # delete old labels - db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete() + db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id)) # insert new labels for label in labels: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 584bae39b9..a58d310313 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -255,11 +255,11 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"no default provider for {provider_id}") else: - builtin_provider = ( - db.session.query(BuiltinToolProvider) + builtin_provider = db.session.scalar( + select(BuiltinToolProvider) .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .first() + .limit(1) ) if builtin_provider is None: @@ -818,13 +818,13 @@ class ToolManager: :return: the provider controller, the credentials """ - provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + provider: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.id == provider_id, ApiToolProvider.tenant_id == tenant_id, ) - .first() + .limit(1) ) if provider is None: @@ -872,13 +872,13 @@ class ToolManager: get api provider """ provider_name = provider - provider_obj: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + provider_obj: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider, ) - .first() + .limit(1) ) if provider_obj is None: @@ -964,10 +964,10 @@ class ToolManager: @classmethod def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: - workflow_provider: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) + workflow_provider: WorkflowToolProvider | None = db.session.scalar( + select(WorkflowToolProvider) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) - .first() + .limit(1) ) if workflow_provider is None: @@ -981,10 +981,10 @@ class ToolManager: @classmethod def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: - api_provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + api_provider: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) - .first() + .limit(1) ) if api_provider is None: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 6a77fda7ef..e63435db98 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -110,7 +110,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): context_list: list[RetrievalSourceMetadata] = [] resource_number = 1 for segment in sorted_segments: - dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + dataset = db.session.get(Dataset, segment.dataset_id) document_stmt = select(Document).where( Document.id == segment.document_id, Document.enabled == True, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index f3d390ed59..cbd8bdb36c 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -205,7 +205,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if self.return_resource: for record in records: segment = record.segment - dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + dataset = db.session.get(Dataset, segment.dataset_id) dataset_document_stmt = select(DatasetDocument).where( DatasetDocument.id == segment.document_id, DatasetDocument.enabled == True, diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py index 5862239142..4a5f91cc5d 100644 --- a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -64,18 +64,18 @@ class TestGetActiveAccount: def test_returns_active_account(self, mock_db): mock_account = MagicMock() mock_account.status = "active" - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account + mock_db.session.scalar.return_value = mock_account result = _get_active_account("user@example.com") assert result is mock_account - mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com") + mock_db.session.scalar.assert_called_once() @patch("controllers.inner_api.app.dsl.db") def test_returns_none_for_inactive_account(self, mock_db): mock_account = MagicMock() mock_account.status = "banned" - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account + mock_db.session.scalar.return_value = mock_account result = _get_active_account("banned@example.com") @@ -83,7 +83,7 @@ class TestGetActiveAccount: @patch("controllers.inner_api.app.dsl.db") def test_returns_none_for_nonexistent_email(self, mock_db): - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_db.session.scalar.return_value = None result = _get_active_account("missing@example.com") @@ -205,7 +205,7 @@ class TestEnterpriseAppDSLExport: @patch("controllers.inner_api.app.dsl.db") def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask): mock_app = MagicMock() - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n" unwrapped = inspect.unwrap(api_instance.get) @@ -221,7 +221,7 @@ class TestEnterpriseAppDSLExport: @patch("controllers.inner_api.app.dsl.db") def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask): mock_app = MagicMock() - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app mock_dsl_cls.export_dsl.return_value = "yaml-data" unwrapped = inspect.unwrap(api_instance.get) @@ -234,7 +234,7 @@ class TestEnterpriseAppDSLExport: @patch("controllers.inner_api.app.dsl.db") def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask): - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_db.session.get.return_value = None unwrapped = inspect.unwrap(api_instance.get) with app.test_request_context("?include_secret=false"): diff --git a/api/tests/unit_tests/core/agent/test_base_agent_runner.py b/api/tests/unit_tests/core/agent/test_base_agent_runner.py index 683cc0e36f..db4b293b16 100644 --- a/api/tests/unit_tests/core/agent/test_base_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_base_agent_runner.py @@ -621,7 +621,7 @@ class TestConvertDatasetRetrieverTool: class TestBaseAgentRunnerInit: def test_init_sets_stream_tool_call_and_files(self, mocker): session = mocker.MagicMock() - session.query.return_value.where.return_value.count.return_value = 2 + session.scalar.return_value = 2 mocker.patch.object(module.db, "session", session) mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[]) diff --git a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py index b37c4c57a1..8e5670e9be 100644 --- a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py +++ b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py @@ -114,13 +114,9 @@ class TestOnToolEnd: document = mocker.Mock() document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} - mock_query = mocker.Mock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - handler.on_tool_end([document]) - mock_query.update.assert_called_once() + mock_db.session.execute.assert_called_once() mock_db.session.commit.assert_called_once() def test_on_tool_end_non_parent_child_index(self, handler, mocker): @@ -138,13 +134,9 @@ class TestOnToolEnd: "dataset_id": "dataset-1", } - mock_query = mocker.Mock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - handler.on_tool_end([document]) - mock_query.update.assert_called_once() + mock_db.session.execute.assert_called_once() mock_db.session.commit.assert_called_once() def test_on_tool_end_empty_documents(self, handler): diff --git a/api/tests/unit_tests/core/helper/test_encrypter.py b/api/tests/unit_tests/core/helper/test_encrypter.py index 5890009742..f3ef7fccd0 100644 --- a/api/tests/unit_tests/core/helper/test_encrypter.py +++ b/api/tests/unit_tests/core/helper/test_encrypter.py @@ -38,13 +38,13 @@ class TestObfuscatedToken: class TestEncryptToken: - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_successful_encryption(self, mock_encrypt, mock_query): """Test successful token encryption""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_data" result = encrypt_token("tenant-123", "test_token") @@ -52,10 +52,10 @@ class TestEncryptToken: assert result == base64.b64encode(b"encrypted_data").decode() mock_encrypt.assert_called_with("test_token", "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") def test_tenant_not_found(self, mock_query): """Test error when tenant doesn't exist""" - mock_query.return_value.where.return_value.first.return_value = None + mock_query.return_value = None with pytest.raises(ValueError) as exc_info: encrypt_token("invalid-tenant", "test_token") @@ -119,7 +119,7 @@ class TestGetDecryptDecoding: class TestEncryptDecryptIntegration: - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") @patch("libs.rsa.decrypt") def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query): @@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration: # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # Setup mock encryption/decryption original_token = "test_token_123" @@ -146,14 +146,14 @@ class TestEncryptDecryptIntegration: class TestSecurity: """Critical security tests for encryption system""" - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_cross_tenant_isolation(self, mock_encrypt, mock_query): """Ensure tokens encrypted for one tenant cannot be used by another""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "tenant1_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_for_tenant1" # Encrypt token for tenant1 @@ -181,12 +181,12 @@ class TestSecurity: with pytest.raises(Exception, match="Decryption error"): decrypt_token("tenant-123", tampered) - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_encryption_randomness(self, mock_encrypt, mock_query): """Ensure same plaintext produces different ciphertext""" mock_tenant = MagicMock(encrypt_public_key="key") - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # Different outputs for same input mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"] @@ -205,13 +205,13 @@ class TestEdgeCases: # Test empty string (which is a valid str type) assert obfuscated_token("") == "" - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query): """Test encryption of empty token""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_empty" result = encrypt_token("tenant-123", "") @@ -219,13 +219,13 @@ class TestEdgeCases: assert result == base64.b64encode(b"encrypted_empty").decode() mock_encrypt.assert_called_with("", "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query): """Test tokens containing special/unicode characters""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_special" # Test various special characters @@ -242,13 +242,13 @@ class TestEdgeCases: assert result == base64.b64encode(b"encrypted_special").decode() mock_encrypt.assert_called_with(token, "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query): """Test behavior when token exceeds RSA encryption limits""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # RSA 2048-bit can only encrypt ~245 bytes # The actual limit depends on padding scheme diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 2c0a441125..62e714deb6 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -314,8 +314,8 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None # Mock __instruction_modify_common call via invoke_llm mock_response = MagicMock() @@ -328,12 +328,12 @@ class TestLLMGenerator: assert result == {"modified": "prompt"} def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: last_run = MagicMock() last_run.query = "q" last_run.answer = "a" last_run.error = "e" - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run + mock_scalar.return_value = last_run mock_response = MagicMock() mock_response.message.get_text_content.return_value = '{"modified": "prompt"}' @@ -483,8 +483,8 @@ class TestLLMGenerator: def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity): # Testing placeholders replacement via instruction_modify_legacy for convenience - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = '{"ok": true}' @@ -504,8 +504,8 @@ class TestLLMGenerator: assert "current_val" in user_msg_dict["instruction"] def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "No braces here" mock_model_instance.invoke_llm.return_value = mock_response @@ -516,8 +516,8 @@ class TestLLMGenerator: assert "Could not find a valid JSON object" in result["error"] def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "[1, 2, 3]" mock_model_instance.invoke_llm.return_value = mock_response @@ -556,8 +556,8 @@ class TestLLMGenerator: ) def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed") result = LLMGenerator.instruction_modify_legacy( @@ -566,8 +566,8 @@ class TestLLMGenerator: assert "Failed to generate code" in result["error"] def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_model_instance.invoke_llm.side_effect = Exception("Random error") result = LLMGenerator.instruction_modify_legacy( @@ -576,8 +576,8 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "No JSON here" diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py index c2778f082b..3feb4159ad 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py @@ -332,27 +332,21 @@ class TestPluginAppBackwardsInvocation: PluginAppBackwardsInvocation._get_user("uid") def test_get_app_returns_app(self, mocker): - query_chain = MagicMock() - query_chain.where.return_value = query_chain app_obj = MagicMock(id="app") - query_chain.first.return_value = app_obj - db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj))) mocker.patch("core.plugin.backwards_invocation.app.db", db) assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj def test_get_app_raises_when_missing(self, mocker): - query_chain = MagicMock() - query_chain.where.return_value = query_chain - query_chain.first.return_value = None - db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None))) mocker.patch("core.plugin.backwards_invocation.app.db", db) with pytest.raises(ValueError, match="app not found"): PluginAppBackwardsInvocation._get_app("app", "tenant") def test_get_app_raises_when_query_fails(self, mocker): - db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down")))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down")))) mocker.patch("core.plugin.backwards_invocation.app.db", db) with pytest.raises(ValueError, match="app not found"): diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py index 857f4aa178..8c0e7e9419 100644 --- a/api/tests/unit_tests/core/tools/test_tool_label_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py @@ -38,11 +38,9 @@ def test_tool_label_manager_filter_tool_labels(): def test_tool_label_manager_update_tool_labels_db(): controller = _api_controller("api-1") with patch("core.tools.tool_label_manager.db") as mock_db: - delete_query = mock_db.session.query.return_value.where.return_value - delete_query.delete.return_value = None ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"]) - delete_query.delete.assert_called_once() + mock_db.session.execute.assert_called_once() # only one valid unique label should be inserted. assert mock_db.session.add.call_count == 1 mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 844bc01e29..31b68f0b3f 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -220,9 +220,7 @@ def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks(): with patch.object(ToolManager, "get_builtin_provider", return_value=controller): with patch("core.helper.credential_utils.check_credential_policy_compliance"): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( - builtin_provider - ) + mock_db.session.scalar.return_value = builtin_provider encrypter = Mock() encrypter.decrypt.return_value = {"api_key": "secret"} cache = Mock() @@ -274,7 +272,7 @@ def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials( ) refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456) - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider + mock_db.session.scalar.return_value = builtin_provider encrypter = Mock() encrypter.decrypt.return_value = {"token": "old"} encrypter.encrypt.return_value = {"token": "encrypted"} @@ -698,12 +696,10 @@ def test_get_api_provider_controller_returns_controller_and_credentials(): privacy_policy="privacy", custom_disclaimer="disclaimer", ) - db_query = Mock() - db_query.where.return_value.first.return_value = provider controller = Mock() with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value = db_query + mock_db.session.scalar.return_value = provider with patch( "core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller ) as mock_from_db: @@ -730,12 +726,10 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels(): privacy_policy="privacy", custom_disclaimer="disclaimer", ) - db_query = Mock() - db_query.where.return_value.first.return_value = provider controller = Mock() with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value = db_query + mock_db.session.scalar.return_value = provider with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller): encrypter = Mock() encrypter.decrypt.return_value = {"api_key_value": "secret"} @@ -750,7 +744,7 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels(): def test_get_api_provider_controller_not_found_raises(): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"): ToolManager.get_api_provider_controller("tenant-1", "missing") @@ -809,14 +803,14 @@ def test_generate_tool_icon_urls_for_workflow_and_api(): workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}') api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}') with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider] + mock_db.session.scalar.side_effect = [workflow_provider, api_provider] assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"} assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"} def test_generate_tool_icon_urls_missing_workflow_and_api_use_default(): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525" assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525" diff --git a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py index 4ce73272bf..a93624123e 100644 --- a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py +++ b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py @@ -263,7 +263,7 @@ def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources(): ) db_session = Mock() db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high] - db_session.query.return_value.filter_by.return_value.first.return_value = dataset + db_session.get.return_value = dataset tool = SingleDatasetRetrieverTool( tenant_id="tenant-1", @@ -444,7 +444,7 @@ def test_multi_dataset_retriever_run_orders_segments_and_returns_resources(): ) db_session = Mock() db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1] - db_session.query.return_value.filter_by.return_value.first.side_effect = [ + db_session.get.side_effect = [ SimpleNamespace(id="dataset-2", name="Dataset Two"), SimpleNamespace(id="dataset-1", name="Dataset One"), ]