mirror of
https://github.com/langgenius/dify.git
synced 2026-04-10 11:37:11 +08:00
refactor: core/tools, agent, callback_handler, encrypter, llm_generator, plugin, inner_api (#34205)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
parent
7cc81e9a43
commit
364d7ebc40
@ -8,6 +8,7 @@ Go admin-api caller.
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
@ -87,7 +88,7 @@ class EnterpriseAppDSLExport(Resource):
|
||||
"""Export an app's DSL as YAML."""
|
||||
include_secret = request.args.get("include_secret", "false").lower() == "true"
|
||||
|
||||
app_model = db.session.query(App).filter_by(id=app_id).first()
|
||||
app_model = db.session.get(App, app_id)
|
||||
if not app_model:
|
||||
return {"message": "app not found"}, 404
|
||||
|
||||
@ -104,7 +105,7 @@ def _get_active_account(email: str) -> Account | None:
|
||||
|
||||
Workspace membership is already validated by the Go admin-api caller.
|
||||
"""
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
if account is None or account.status != AccountStatus.ACTIVE:
|
||||
return None
|
||||
return account
|
||||
|
||||
@ -18,7 +18,7 @@ from graphon.model_runtime.entities import (
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
@ -104,11 +104,14 @@ class BaseAgentRunner(AppRunner):
|
||||
)
|
||||
# get how many agent thoughts have been created
|
||||
self.agent_thought_count = (
|
||||
db.session.query(MessageAgentThought)
|
||||
.where(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
db.session.scalar(
|
||||
select(func.count())
|
||||
.select_from(MessageAgentThought)
|
||||
.where(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
db.session.close()
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -70,23 +70,21 @@ class DatasetIndexToolCallbackHandler:
|
||||
)
|
||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
||||
if child_chunk:
|
||||
_ = (
|
||||
db.session.query(DocumentSegment)
|
||||
db.session.execute(
|
||||
update(DocumentSegment)
|
||||
.where(DocumentSegment.id == child_chunk.segment_id)
|
||||
.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
|
||||
)
|
||||
.values(hit_count=DocumentSegment.hit_count + 1)
|
||||
)
|
||||
else:
|
||||
query = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]]
|
||||
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
db.session.execute(
|
||||
update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ def encrypt_token(tenant_id: str, token: str):
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
|
||||
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
|
||||
if not (tenant := db.session.get(Tenant, tenant_id)):
|
||||
raise ValueError(f"Tenant with id {tenant_id} not found")
|
||||
assert tenant.encrypt_public_key is not None
|
||||
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||
|
||||
@ -10,6 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
@ -410,8 +411,8 @@ class LLMGenerator:
|
||||
model_config: ModelConfig,
|
||||
ideal_output: str | None,
|
||||
):
|
||||
last_run: Message | None = (
|
||||
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
|
||||
last_run: Message | None = db.session.scalar(
|
||||
select(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).limit(1)
|
||||
)
|
||||
if not last_run:
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
|
||||
@ -227,7 +227,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
get app
|
||||
"""
|
||||
try:
|
||||
app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first()
|
||||
app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
|
||||
except Exception:
|
||||
raise ValueError("app not found")
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
@ -31,7 +31,7 @@ class ToolLabelManager:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
# delete old labels
|
||||
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
|
||||
db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id))
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
|
||||
@ -255,11 +255,11 @@ class ToolManager:
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
builtin_provider = db.session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
@ -818,13 +818,13 @@ class ToolManager:
|
||||
|
||||
:return: the provider controller, the credentials
|
||||
"""
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
provider: ApiToolProvider | None = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.id == provider_id,
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
@ -872,13 +872,13 @@ class ToolManager:
|
||||
get api provider
|
||||
"""
|
||||
provider_name = provider
|
||||
provider_obj: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
provider_obj: ApiToolProvider | None = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider_obj is None:
|
||||
@ -964,10 +964,10 @@ class ToolManager:
|
||||
@classmethod
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
|
||||
try:
|
||||
workflow_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
workflow_provider: WorkflowToolProvider | None = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if workflow_provider is None:
|
||||
@ -981,10 +981,10 @@ class ToolManager:
|
||||
@classmethod
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
|
||||
try:
|
||||
api_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
api_provider: ApiToolProvider | None = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if api_provider is None:
|
||||
|
||||
@ -110,7 +110,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
context_list: list[RetrievalSourceMetadata] = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
dataset = db.session.get(Dataset, segment.dataset_id)
|
||||
document_stmt = select(Document).where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
|
||||
@ -205,7 +205,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
if self.return_resource:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
dataset = db.session.get(Dataset, segment.dataset_id)
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
|
||||
@ -64,18 +64,18 @@ class TestGetActiveAccount:
|
||||
def test_returns_active_account(self, mock_db):
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = "active"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
|
||||
result = _get_active_account("user@example.com")
|
||||
|
||||
assert result is mock_account
|
||||
mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com")
|
||||
mock_db.session.scalar.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_returns_none_for_inactive_account(self, mock_db):
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = "banned"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
|
||||
result = _get_active_account("banned@example.com")
|
||||
|
||||
@ -83,7 +83,7 @@ class TestGetActiveAccount:
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_returns_none_for_nonexistent_email(self, mock_db):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
result = _get_active_account("missing@example.com")
|
||||
|
||||
@ -205,7 +205,7 @@ class TestEnterpriseAppDSLExport:
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
|
||||
mock_app = MagicMock()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
|
||||
mock_db.session.get.return_value = mock_app
|
||||
mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n"
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
@ -221,7 +221,7 @@ class TestEnterpriseAppDSLExport:
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
|
||||
mock_app = MagicMock()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
|
||||
mock_db.session.get.return_value = mock_app
|
||||
mock_dsl_cls.export_dsl.return_value = "yaml-data"
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
@ -234,7 +234,7 @@ class TestEnterpriseAppDSLExport:
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
with app.test_request_context("?include_secret=false"):
|
||||
|
||||
@ -621,7 +621,7 @@ class TestConvertDatasetRetrieverTool:
|
||||
class TestBaseAgentRunnerInit:
|
||||
def test_init_sets_stream_tool_call_and_files(self, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.query.return_value.where.return_value.count.return_value = 2
|
||||
session.scalar.return_value = 2
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
|
||||
|
||||
@ -114,13 +114,9 @@ class TestOnToolEnd:
|
||||
document = mocker.Mock()
|
||||
document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
|
||||
|
||||
mock_query = mocker.Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_query.update.assert_called_once()
|
||||
mock_db.session.execute.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_tool_end_non_parent_child_index(self, handler, mocker):
|
||||
@ -138,13 +134,9 @@ class TestOnToolEnd:
|
||||
"dataset_id": "dataset-1",
|
||||
}
|
||||
|
||||
mock_query = mocker.Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_query.update.assert_called_once()
|
||||
mock_db.session.execute.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_tool_end_empty_documents(self, handler):
|
||||
|
||||
@ -38,13 +38,13 @@ class TestObfuscatedToken:
|
||||
|
||||
|
||||
class TestEncryptToken:
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_successful_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test successful token encryption"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_data"
|
||||
|
||||
result = encrypt_token("tenant-123", "test_token")
|
||||
@ -52,10 +52,10 @@ class TestEncryptToken:
|
||||
assert result == base64.b64encode(b"encrypted_data").decode()
|
||||
mock_encrypt.assert_called_with("test_token", "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
def test_tenant_not_found(self, mock_query):
|
||||
"""Test error when tenant doesn't exist"""
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
mock_query.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypt_token("invalid-tenant", "test_token")
|
||||
@ -119,7 +119,7 @@ class TestGetDecryptDecoding:
|
||||
|
||||
|
||||
class TestEncryptDecryptIntegration:
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
|
||||
@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration:
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
|
||||
# Setup mock encryption/decryption
|
||||
original_token = "test_token_123"
|
||||
@ -146,14 +146,14 @@ class TestEncryptDecryptIntegration:
|
||||
class TestSecurity:
|
||||
"""Critical security tests for encryption system"""
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
|
||||
"""Ensure tokens encrypted for one tenant cannot be used by another"""
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "tenant1_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_for_tenant1"
|
||||
|
||||
# Encrypt token for tenant1
|
||||
@ -181,12 +181,12 @@ class TestSecurity:
|
||||
with pytest.raises(Exception, match="Decryption error"):
|
||||
decrypt_token("tenant-123", tampered)
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_encryption_randomness(self, mock_encrypt, mock_query):
|
||||
"""Ensure same plaintext produces different ciphertext"""
|
||||
mock_tenant = MagicMock(encrypt_public_key="key")
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
|
||||
# Different outputs for same input
|
||||
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
|
||||
@ -205,13 +205,13 @@ class TestEdgeCases:
|
||||
# Test empty string (which is a valid str type)
|
||||
assert obfuscated_token("") == ""
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test encryption of empty token"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_empty"
|
||||
|
||||
result = encrypt_token("tenant-123", "")
|
||||
@ -219,13 +219,13 @@ class TestEdgeCases:
|
||||
assert result == base64.b64encode(b"encrypted_empty").decode()
|
||||
mock_encrypt.assert_called_with("", "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
|
||||
"""Test tokens containing special/unicode characters"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_special"
|
||||
|
||||
# Test various special characters
|
||||
@ -242,13 +242,13 @@ class TestEdgeCases:
|
||||
assert result == base64.b64encode(b"encrypted_special").decode()
|
||||
mock_encrypt.assert_called_with(token, "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
|
||||
"""Test behavior when token exceeds RSA encryption limits"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
|
||||
# RSA 2048-bit can only encrypt ~245 bytes
|
||||
# The actual limit depends on padding scheme
|
||||
|
||||
@ -314,8 +314,8 @@ class TestLLMGenerator:
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
|
||||
def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
|
||||
# Mock __instruction_modify_common call via invoke_llm
|
||||
mock_response = MagicMock()
|
||||
@ -328,12 +328,12 @@ class TestLLMGenerator:
|
||||
assert result == {"modified": "prompt"}
|
||||
|
||||
def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
last_run = MagicMock()
|
||||
last_run.query = "q"
|
||||
last_run.answer = "a"
|
||||
last_run.error = "e"
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run
|
||||
mock_scalar.return_value = last_run
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
|
||||
@ -483,8 +483,8 @@ class TestLLMGenerator:
|
||||
|
||||
def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity):
|
||||
# Testing placeholders replacement via instruction_modify_legacy for convenience
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"ok": true}'
|
||||
@ -504,8 +504,8 @@ class TestLLMGenerator:
|
||||
assert "current_val" in user_msg_dict["instruction"]
|
||||
|
||||
def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "No braces here"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
@ -516,8 +516,8 @@ class TestLLMGenerator:
|
||||
assert "Could not find a valid JSON object" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "[1, 2, 3]"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
@ -556,8 +556,8 @@ class TestLLMGenerator:
|
||||
)
|
||||
|
||||
def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed")
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
@ -566,8 +566,8 @@ class TestLLMGenerator:
|
||||
assert "Failed to generate code" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
@ -576,8 +576,8 @@ class TestLLMGenerator:
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "No JSON here"
|
||||
|
||||
@ -332,27 +332,21 @@ class TestPluginAppBackwardsInvocation:
|
||||
PluginAppBackwardsInvocation._get_user("uid")
|
||||
|
||||
def test_get_app_returns_app(self, mocker):
|
||||
query_chain = MagicMock()
|
||||
query_chain.where.return_value = query_chain
|
||||
app_obj = MagicMock(id="app")
|
||||
query_chain.first.return_value = app_obj
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj)))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj
|
||||
|
||||
def test_get_app_raises_when_missing(self, mocker):
|
||||
query_chain = MagicMock()
|
||||
query_chain.where.return_value = query_chain
|
||||
query_chain.first.return_value = None
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None)))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
PluginAppBackwardsInvocation._get_app("app", "tenant")
|
||||
|
||||
def test_get_app_raises_when_query_fails(self, mocker):
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down"))))
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down"))))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
|
||||
@ -38,11 +38,9 @@ def test_tool_label_manager_filter_tool_labels():
|
||||
def test_tool_label_manager_update_tool_labels_db():
|
||||
controller = _api_controller("api-1")
|
||||
with patch("core.tools.tool_label_manager.db") as mock_db:
|
||||
delete_query = mock_db.session.query.return_value.where.return_value
|
||||
delete_query.delete.return_value = None
|
||||
ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"])
|
||||
|
||||
delete_query.delete.assert_called_once()
|
||||
mock_db.session.execute.assert_called_once()
|
||||
# only one valid unique label should be inserted.
|
||||
assert mock_db.session.add.call_count == 1
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@ -220,9 +220,7 @@ def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks():
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
|
||||
with patch("core.helper.credential_utils.check_credential_policy_compliance"):
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
builtin_provider
|
||||
)
|
||||
mock_db.session.scalar.return_value = builtin_provider
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"api_key": "secret"}
|
||||
cache = Mock()
|
||||
@ -274,7 +272,7 @@ def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials(
|
||||
)
|
||||
refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456)
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider
|
||||
mock_db.session.scalar.return_value = builtin_provider
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"token": "old"}
|
||||
encrypter.encrypt.return_value = {"token": "encrypted"}
|
||||
@ -698,12 +696,10 @@ def test_get_api_provider_controller_returns_controller_and_credentials():
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="disclaimer",
|
||||
)
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = provider
|
||||
controller = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value = db_query
|
||||
mock_db.session.scalar.return_value = provider
|
||||
with patch(
|
||||
"core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller
|
||||
) as mock_from_db:
|
||||
@ -730,12 +726,10 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels():
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="disclaimer",
|
||||
)
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = provider
|
||||
controller = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value = db_query
|
||||
mock_db.session.scalar.return_value = provider
|
||||
with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller):
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"api_key_value": "secret"}
|
||||
@ -750,7 +744,7 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels():
|
||||
|
||||
def test_get_api_provider_controller_not_found_raises():
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"):
|
||||
ToolManager.get_api_provider_controller("tenant-1", "missing")
|
||||
|
||||
@ -809,14 +803,14 @@ def test_generate_tool_icon_urls_for_workflow_and_api():
|
||||
workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}')
|
||||
api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}')
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider]
|
||||
mock_db.session.scalar.side_effect = [workflow_provider, api_provider]
|
||||
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"}
|
||||
assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"}
|
||||
|
||||
|
||||
def test_generate_tool_icon_urls_missing_workflow_and_api_use_default():
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
|
||||
assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
|
||||
|
||||
|
||||
@ -263,7 +263,7 @@ def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources():
|
||||
)
|
||||
db_session = Mock()
|
||||
db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high]
|
||||
db_session.query.return_value.filter_by.return_value.first.return_value = dataset
|
||||
db_session.get.return_value = dataset
|
||||
|
||||
tool = SingleDatasetRetrieverTool(
|
||||
tenant_id="tenant-1",
|
||||
@ -444,7 +444,7 @@ def test_multi_dataset_retriever_run_orders_segments_and_returns_resources():
|
||||
)
|
||||
db_session = Mock()
|
||||
db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1]
|
||||
db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
db_session.get.side_effect = [
|
||||
SimpleNamespace(id="dataset-2", name="Dataset Two"),
|
||||
SimpleNamespace(id="dataset-1", name="Dataset One"),
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user