From a7b53b33eeb6882558e7459b31585a3158883d02 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sat, 20 Jun 2026 04:52:59 +0900 Subject: [PATCH 01/70] chore: move one db.session (#37656) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../console/app/agent_app_feature.py | 5 +- api/controllers/console/auth/activate.py | 2 +- .../console/auth/forgot_password.py | 2 +- api/controllers/console/auth/login.py | 3 +- api/controllers/console/auth/oauth.py | 2 +- .../inner_api/workspace/workspace.py | 2 +- api/services/account_service.py | 18 ++-- api/services/agent_app_feature_service.py | 13 +-- .../controllers/console/auth/test_oauth.py | 14 +-- .../controllers/openapi/conftest.py | 3 +- .../services/test_account_service.py | 88 ++++++++++--------- .../console/auth/test_account_activation.py | 68 +++++++------- .../inner_api/workspace/test_workspace.py | 4 +- .../services/test_account_service.py | 32 +++++-- .../test_agent_app_feature_service.py | 5 +- 15 files changed, 144 insertions(+), 117 deletions(-) diff --git a/api/controllers/console/app/agent_app_feature.py b/api/controllers/console/app/agent_app_feature.py index 5d2b77c97f1..d155dae6ac3 100644 --- a/api/controllers/console/app/agent_app_feature.py +++ b/api/controllers/console/app/agent_app_feature.py @@ -29,6 +29,7 @@ from controllers.console.wraps import ( with_current_user, ) from events.app_event import app_model_config_was_updated +from extensions.ext_database import db from libs.helper import dump_response from libs.login import login_required from models import Account @@ -90,9 +91,7 @@ class AgentAppFeatureConfigResource(Resource): args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {}) new_app_model_config = AgentAppFeatureConfigService.update_features( - app_model=app_model, - account=current_user, - config=args.model_dump(exclude_none=True), + app_model=app_model, account=current_user, config=args.model_dump(exclude_none=True), session=db.session ) app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index f61bb8f6802..c9142d85ede 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -174,7 +174,7 @@ class ActivateApi(Resource): RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token) if membership_id is None: - TenantService.create_tenant_member(tenant, account, str(role)) + TenantService.create_tenant_member(tenant, account, db.session, role=role) if setup_fields: account.name = setup_fields[0] diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index c34dd1ac859..d82f63c11db 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -202,6 +202,6 @@ class ForgotPasswordResetApi(Resource): and FeatureService.get_system_features().is_allow_create_workspace ): tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role="owner") + TenantService.create_tenant_member(tenant, account, db.session, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 6a1b4c6769e..053f313ba53 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -35,6 +35,7 @@ from controllers.console.wraps import ( with_current_user, ) from events.tenant_event import tenant_was_created +from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip from libs.helper import timezone as validate_timezone_string from libs.token import ( @@ -299,7 +300,7 @@ class EmailCodeLoginApi(Resource): raise NotAllowedCreateWorkspace() else: new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(new_tenant, account, role="owner") + TenantService.create_tenant_member(new_tenant, account, db.session, role="owner") account.current_tenant = new_tenant tenant_was_created.send(new_tenant) diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 31649812fe8..78d1583fde9 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -246,7 +246,7 @@ def _generate_account( raise WorkSpaceNotAllowedCreateError() else: new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(new_tenant, account, role="owner") + TenantService.create_tenant_member(new_tenant, account, db.session, role="owner") account.current_tenant = new_tenant tenant_was_created.send(new_tenant) diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index ef0a46db63a..dd93616e6b1 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -48,7 +48,7 @@ class EnterpriseWorkspace(Resource): return {"message": "owner account not found."}, 404 tenant = TenantService.create_tenant(args.name, is_from_dashboard=True) - TenantService.create_tenant_member(tenant, account, role="owner") + TenantService.create_tenant_member(tenant, account, db.session, role="owner") tenant_was_created.send(tenant) diff --git a/api/services/account_service.py b/api/services/account_service.py index a608f544747..21b5f1eedba 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1280,7 +1280,7 @@ class TenantService: tenant = TenantService.create_tenant(name=name, is_setup=is_setup) else: tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup) - TenantService.create_tenant_member(tenant, account, role="owner") + TenantService.create_tenant_member(tenant, account, db.session, role="owner") if dify_config.RBAC_ENABLED: owner_role_id = AccountService._resolve_legacy_role_id(str(tenant.id), account.id, TenantAccountRole.OWNER) RBACService.MemberRoles.replace( @@ -1294,14 +1294,16 @@ class TenantService: tenant_was_created.send(tenant) @staticmethod - def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin: + def create_tenant_member( + tenant: Tenant, account: Account, session: scoped_session, role: str = "normal" + ) -> TenantAccountJoin: """Create tenant member""" if role == TenantAccountRole.OWNER: if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]): logger.error("Tenant %s has already an owner.", tenant.id) raise Exception("Tenant already has an owner.") - ta = db.session.scalar( + ta = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .limit(1) @@ -1310,9 +1312,9 @@ class TenantService: ta.role = TenantAccountRole(role) else: ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole(role)) - db.session.add(ta) + session.add(ta) - db.session.commit() + session.commit() if dify_config.BILLING_ENABLED: BillingService.clean_billing_info_cache(tenant.id) return ta @@ -1915,7 +1917,7 @@ class RegisterService: ): try: tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role="owner") + TenantService.create_tenant_member(tenant, account, db.session, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) except Exception: @@ -1970,7 +1972,7 @@ class RegisterService: status=AccountStatus.PENDING, is_setup=True, ) - TenantService.create_tenant_member(tenant, account, tenant_join_role) + TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role) TenantService.switch_tenant(account, tenant.id) requires_setup = True else: @@ -1983,7 +1985,7 @@ class RegisterService: requires_setup = account.status == AccountStatus.PENDING if not ta and (account.status == AccountStatus.PENDING or dify_config.RBAC_ENABLED): - TenantService.create_tenant_member(tenant, account, tenant_join_role) + TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role) # Support resend invitation email when the account is pending status if account.status != AccountStatus.PENDING: diff --git a/api/services/agent_app_feature_service.py b/api/services/agent_app_feature_service.py index cc0fe67802d..b8e98653c8e 100644 --- a/api/services/agent_app_feature_service.py +++ b/api/services/agent_app_feature_service.py @@ -13,6 +13,8 @@ from __future__ import annotations from typing import Any, cast +from sqlalchemy.orm import scoped_session + from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager @@ -21,7 +23,6 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor SuggestedQuestionsAfterAnswerConfigManager, ) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.account import Account from models.model import App, AppModelConfig, AppModelConfigDict @@ -67,7 +68,9 @@ class AgentAppFeatureConfigService: return cast(AppModelConfigDict, filtered) @classmethod - def update_features(cls, *, app_model: App, account: Account, config: dict[str, Any]) -> AppModelConfig: + def update_features( + cls, *, app_model: App, account: Account, config: dict[str, Any], session: scoped_session + ) -> AppModelConfig: """Persist the presentation features as a new app_model_config version. Returns the new ``AppModelConfig`` row (now referenced by the app); the @@ -82,13 +85,13 @@ class AgentAppFeatureConfigService: updated_by=account.id, ).from_model_config_dict(validated) - db.session.add(new_config) - db.session.flush() + session.add(new_config) + session.flush() app_model.app_model_config_id = new_config.id app_model.updated_by = account.id app_model.updated_at = naive_utc_now() - db.session.commit() + session.commit() return new_config diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 9ef6b903066..d043c0d413a 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -2,7 +2,7 @@ from __future__ import annotations -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -655,11 +655,11 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.tenant_was_created") def test_should_create_workspace_for_account_without_tenant( self, - mock_event, - mock_account_service, - mock_feature_service, - mock_tenant_service, - mock_get_account, + mock_event: MagicMock, + mock_account_service: MagicMock, + mock_feature_service: MagicMock, + mock_tenant_service: MagicMock, + mock_get_account: MagicMock, app: Flask, user_info: OAuthUserInfo, mock_account, @@ -678,6 +678,6 @@ class TestAccountGeneration: assert oauth_new_user is False mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace") mock_tenant_service.create_tenant_member.assert_called_once_with( - mock_new_tenant, mock_account, role="owner" + mock_new_tenant, mock_account, ANY, role="owner" ) mock_event.send.assert_called_once_with(mock_new_tenant) diff --git a/api/tests/test_containers_integration_tests/controllers/openapi/conftest.py b/api/tests/test_containers_integration_tests/controllers/openapi/conftest.py index d961479f55b..5fe0f787524 100644 --- a/api/tests/test_containers_integration_tests/controllers/openapi/conftest.py +++ b/api/tests/test_containers_integration_tests/controllers/openapi/conftest.py @@ -12,6 +12,7 @@ from flask import Flask from sqlalchemy.orm import Session from controllers.openapi.auth.data import AuthData +from extensions.ext_database import db from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx from models import Account, Tenant from services.account_service import AccountService, TenantService @@ -58,7 +59,7 @@ def add_tenant_for_account(account: Account, *, role: str = "normal", name: str with patch("services.account_service.FeatureService") as mock_feature_service: mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True tenant = TenantService.create_tenant(name=name) - TenantService.create_tenant_member(tenant, account, role=role) + TenantService.create_tenant_member(tenant, account, db.session, role=role) return tenant diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 61dca361f3e..a2f5370cb76 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -1276,7 +1276,7 @@ class TestTenantService: ) # Create tenant member - tenant_member = TenantService.create_tenant_member(tenant, account, role="admin") + tenant_member = TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="admin") assert tenant_member.tenant_id == tenant.id assert tenant_member.account_id == account.id @@ -1317,11 +1317,11 @@ class TestTenantService: ) # Create first owner - TenantService.create_tenant_member(tenant, account1, role="owner") + TenantService.create_tenant_member(tenant, account1, db_session_with_containers, role="owner") # Try to create second owner (should fail) with pytest.raises(Exception, match="Tenant already has an owner"): - TenantService.create_tenant_member(tenant, account2, role="owner") + TenantService.create_tenant_member(tenant, account2, db_session_with_containers, role="owner") def test_create_tenant_member_existing_member( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1349,11 +1349,11 @@ class TestTenantService: ) # Create member with initial role - tenant_member1 = TenantService.create_tenant_member(tenant, account, role="normal") + tenant_member1 = TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="normal") assert tenant_member1.role == "normal" # Update member role - tenant_member2 = TenantService.create_tenant_member(tenant, account, role="editor") + tenant_member2 = TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="editor") assert tenant_member2.tenant_id == tenant_member1.tenant_id assert tenant_member2.account_id == tenant_member1.account_id assert tenant_member2.role == "editor" @@ -1384,8 +1384,8 @@ class TestTenantService: tenant2 = TenantService.create_tenant(name=tenant2_name) # Add account to both tenants - TenantService.create_tenant_member(tenant1, account, role="normal") - TenantService.create_tenant_member(tenant2, account, role="admin") + TenantService.create_tenant_member(tenant1, account, db_session_with_containers, role="normal") + TenantService.create_tenant_member(tenant2, account, db_session_with_containers, role="admin") # Get join tenants join_tenants = TenantService.get_join_tenants(account) @@ -1421,7 +1421,7 @@ class TestTenantService: tenant = TenantService.create_tenant(name=tenant_name) # Add account to tenant and set as current - TenantService.create_tenant_member(tenant, account, role="owner") + TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="owner") account.current_tenant = tenant db_session_with_containers.commit() @@ -1486,8 +1486,8 @@ class TestTenantService: tenant2 = TenantService.create_tenant(name=tenant2_name) # Add account to both tenants - TenantService.create_tenant_member(tenant1, account, role="owner") - TenantService.create_tenant_member(tenant2, account, role="admin") + TenantService.create_tenant_member(tenant1, account, db_session_with_containers, role="owner") + TenantService.create_tenant_member(tenant2, account, db_session_with_containers, role="admin") # Set initial current tenant account.current_tenant = tenant1 @@ -1588,8 +1588,8 @@ class TestTenantService: ) # Add members with different roles - TenantService.create_tenant_member(tenant, owner_account, role="owner") - TenantService.create_tenant_member(tenant, admin_account, role="admin") + TenantService.create_tenant_member(tenant, owner_account, db_session_with_containers, role="owner") + TenantService.create_tenant_member(tenant, admin_account, db_session_with_containers, role="admin") # Check if tenant has owner role from models.account import TenantAccountRole @@ -1648,7 +1648,7 @@ class TestTenantService: ) # Add account to tenant with specific role - TenantService.create_tenant_member(tenant, account, role="editor") + TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="editor") # Get user role user_role = TenantService.get_user_role(account, tenant) @@ -1690,8 +1690,8 @@ class TestTenantService: ) # Add members with different roles - TenantService.create_tenant_member(tenant, owner_account, role="owner") - TenantService.create_tenant_member(tenant, member_account, role="normal") + TenantService.create_tenant_member(tenant, owner_account, db_session_with_containers, role="owner") + TenantService.create_tenant_member(tenant, member_account, db_session_with_containers, role="normal") # Check owner permission to add member (should succeed) TenantService.check_member_permission(tenant, owner_account, member_account, "add") @@ -1723,7 +1723,7 @@ class TestTenantService: ) # Add account to tenant - TenantService.create_tenant_member(tenant, account, role="owner") + TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="owner") # Try to check permission with invalid action with pytest.raises(Exception, match="Invalid action"): @@ -1755,7 +1755,7 @@ class TestTenantService: ) # Add account to tenant - TenantService.create_tenant_member(tenant, account, role="owner") + TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="owner") # Try to check permission to operate self with pytest.raises(Exception, match="Cannot operate self"): @@ -1796,8 +1796,8 @@ class TestTenantService: ) # Add members with different roles - TenantService.create_tenant_member(tenant, owner_account, role="owner") - TenantService.create_tenant_member(tenant, member_account, role="normal") + TenantService.create_tenant_member(tenant, owner_account, db_session_with_containers, role="owner") + TenantService.create_tenant_member(tenant, member_account, db_session_with_containers, role="normal") app = App( tenant_id=tenant.id, @@ -1876,7 +1876,7 @@ class TestTenantService: ) # Add account to tenant - TenantService.create_tenant_member(tenant, account, role="owner") + TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="owner") # Try to remove self with pytest.raises(Exception, match="Cannot operate self"): @@ -1917,7 +1917,7 @@ class TestTenantService: ) # Add only owner to tenant - TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, owner_account, db_session_with_containers, role="owner") # Try to remove non-member with pytest.raises(Exception, match="Member not in tenant"): @@ -1956,8 +1956,8 @@ class TestTenantService: ) # Add members with different roles - TenantService.create_tenant_member(tenant, owner_account, role="owner") - TenantService.create_tenant_member(tenant, member_account, role="normal") + TenantService.create_tenant_member(tenant, owner_account, db_session_with_containers, role="owner") + TenantService.create_tenant_member(tenant, member_account, db_session_with_containers, role="normal") # Update member role TenantService.update_member_role(tenant, member_account, "admin", owner_account) @@ -2005,8 +2005,8 @@ class TestTenantService: ) # Add members with different roles - TenantService.create_tenant_member(tenant, owner_account, role="owner") - TenantService.create_tenant_member(tenant, member_account, role="admin") + TenantService.create_tenant_member(tenant, owner_account, db_session_with_containers, role="owner") + TenantService.create_tenant_member(tenant, member_account, db_session_with_containers, role="admin") # Update member role to owner TenantService.update_member_role(tenant, member_account, "owner", owner_account) @@ -2062,8 +2062,8 @@ class TestTenantService: ) # Add members with different roles - TenantService.create_tenant_member(tenant, owner_account, role="owner") - TenantService.create_tenant_member(tenant, member_account, role="admin") + TenantService.create_tenant_member(tenant, owner_account, db_session_with_containers, role="owner") + TenantService.create_tenant_member(tenant, member_account, db_session_with_containers, role="admin") # Try to update member role to already assigned role with pytest.raises(Exception, match="The provided role is already assigned to the member"): @@ -2160,7 +2160,7 @@ class TestTenantService: password=password, ) existing_tenant = TenantService.create_tenant(name=existing_tenant_name) - TenantService.create_tenant_member(existing_tenant, account, role="owner") + TenantService.create_tenant_member(existing_tenant, account, db_session_with_containers, role="owner") account.current_tenant = existing_tenant db_session_with_containers.commit() @@ -2243,9 +2243,9 @@ class TestTenantService: ) # Add members with different roles - TenantService.create_tenant_member(tenant, owner_account, role="owner") - TenantService.create_tenant_member(tenant, admin_account, role="admin") - TenantService.create_tenant_member(tenant, normal_account, role="normal") + TenantService.create_tenant_member(tenant, owner_account, db_session_with_containers, role="owner") + TenantService.create_tenant_member(tenant, admin_account, db_session_with_containers, role="admin") + TenantService.create_tenant_member(tenant, normal_account, db_session_with_containers, role="normal") # Get tenant members members = TenantService.get_tenant_members(tenant) @@ -2309,9 +2309,11 @@ class TestTenantService: ) # Add members with different roles - TenantService.create_tenant_member(tenant, owner_account, role="owner") - TenantService.create_tenant_member(tenant, dataset_operator_account, role="dataset_operator") - TenantService.create_tenant_member(tenant, normal_account, role="normal") + TenantService.create_tenant_member(tenant, owner_account, db_session_with_containers, role="owner") + TenantService.create_tenant_member( + tenant, dataset_operator_account, db_session_with_containers, role="dataset_operator" + ) + TenantService.create_tenant_member(tenant, normal_account, db_session_with_containers, role="normal") # Get dataset operator members dataset_operators = TenantService.get_dataset_operator_members(tenant) @@ -2742,7 +2744,7 @@ class TestRegisterService: interface_language="en-US", password=inviter_password, ) - TenantService.create_tenant_member(tenant, inviter, role="owner") + TenantService.create_tenant_member(tenant, inviter, db_session_with_containers, role="owner") # Mock the email task with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail: @@ -2808,7 +2810,7 @@ class TestRegisterService: interface_language="en-US", password=inviter_password, ) - TenantService.create_tenant_member(tenant, inviter, role="owner") + TenantService.create_tenant_member(tenant, inviter, db_session_with_containers, role="owner") # Create existing account existing_account = AccountService.create_account( @@ -2877,7 +2879,7 @@ class TestRegisterService: interface_language="en-US", password=inviter_password, ) - TenantService.create_tenant_member(tenant, inviter, role="owner") + TenantService.create_tenant_member(tenant, inviter, db_session_with_containers, role="owner") # Create existing account with pending status existing_account = AccountService.create_account( @@ -2891,7 +2893,7 @@ class TestRegisterService: db_session_with_containers.commit() # Add existing account to tenant - TenantService.create_tenant_member(tenant, existing_account, role="normal") + TenantService.create_tenant_member(tenant, existing_account, db_session_with_containers, role="normal") # Mock the email task with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail: @@ -2967,7 +2969,7 @@ class TestRegisterService: interface_language="en-US", password=inviter_password, ) - TenantService.create_tenant_member(tenant, inviter, role="owner") + TenantService.create_tenant_member(tenant, inviter, db_session_with_containers, role="owner") # Create existing account with active status existing_account = AccountService.create_account( @@ -2981,7 +2983,7 @@ class TestRegisterService: db_session_with_containers.commit() # Add existing account to tenant - TenantService.create_tenant_member(tenant, existing_account, role="normal") + TenantService.create_tenant_member(tenant, existing_account, db_session_with_containers, role="normal") # Execute invitation (should fail for active member) with pytest.raises(AccountAlreadyInTenantError, match="Account already in tenant."): @@ -3193,7 +3195,7 @@ class TestRegisterService: interface_language="en-US", password=password, ) - TenantService.create_tenant_member(tenant, account, role="normal") + TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="normal") # Generate a real token token = RegisterService.generate_invite_token(tenant, account) @@ -3313,7 +3315,7 @@ class TestRegisterService: interface_language="en-US", password=password, ) - TenantService.create_tenant_member(tenant, account, role="normal") + TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="normal") # Create a real token but with mismatched account ID from extensions.ext_redis import redis_client @@ -3363,7 +3365,7 @@ class TestRegisterService: interface_language="en-US", password=password, ) - TenantService.create_tenant_member(tenant, account, role="normal") + TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="normal") # Change tenant status to non-normal tenant.status = TenantStatus.ARCHIVE diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index 1422afd7524..6ef7f442591 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -15,7 +15,7 @@ from flask import Flask from controllers.console.auth.activate import ActivateApi, ActivateCheckApi from controllers.console.error import AccountInFreezeError, AlreadyActivateError -from models.account import AccountStatus +from models.account import AccountStatus, TenantAccountRole class TestActivateCheckApi: @@ -201,11 +201,11 @@ class TestActivateApi: @patch("controllers.console.auth.activate.db") def test_successful_account_activation( self, - mock_db, - mock_revoke_token, - mock_get_invitation, + mock_db: MagicMock, + mock_revoke_token: MagicMock, + mock_get_invitation: MagicMock, app: Flask, - mock_invitation, + mock_invitation: MagicMock, mock_account, ): """ @@ -448,11 +448,11 @@ class TestActivateApi: @patch("controllers.console.auth.activate.db") def test_activation_returns_success_response( self, - mock_db, - mock_revoke_token, - mock_get_invitation, + mock_db: MagicMock, + mock_revoke_token: MagicMock, + mock_get_invitation: MagicMock, app: Flask, - mock_invitation, + mock_invitation: MagicMock, ): """ Test that activation returns a success response without authentication tokens. @@ -488,11 +488,11 @@ class TestActivateApi: @patch("controllers.console.auth.activate.db") def test_activation_without_workspace_id( self, - mock_db, - mock_revoke_token, - mock_get_invitation, + mock_db: MagicMock, + mock_revoke_token: MagicMock, + mock_get_invitation: MagicMock, app: Flask, - mock_invitation, + mock_invitation: MagicMock, ): """ Test account activation without workspace_id. @@ -528,12 +528,12 @@ class TestActivateApi: @patch("controllers.console.auth.activate.db") def test_activation_normalizes_email_before_lookup( self, - mock_db, - mock_revoke_token, - mock_get_invitation, + mock_db: MagicMock, + mock_revoke_token: MagicMock, + mock_get_invitation: MagicMock, app: Flask, - mock_invitation, - mock_account, + mock_invitation: MagicMock, + mock_account: MagicMock, ): """Ensure uppercase emails are normalized before lookup and revocation.""" mock_get_invitation.return_value = mock_invitation @@ -563,14 +563,14 @@ class TestActivateApi: @patch("controllers.console.auth.activate.db") def test_activation_for_existing_active_account_creates_membership_on_acceptance( self, - mock_db, - mock_revoke_token, - mock_get_invitation, - mock_create_tenant_member, + mock_db: MagicMock, + mock_revoke_token: MagicMock, + mock_get_invitation: MagicMock, + mock_create_tenant_member: MagicMock, app: Flask, - mock_invitation, - mock_account, - mock_switch_tenant, + mock_invitation: MagicMock, + mock_account: MagicMock, + mock_switch_tenant: MagicMock, ): mock_account.status = AccountStatus.ACTIVE mock_invitation["data"]["role"] = "admin" @@ -590,7 +590,9 @@ class TestActivateApi: response = ActivateApi().post() assert response["result"] == "success" - mock_create_tenant_member.assert_called_once_with(mock_invitation["tenant"], mock_account, "admin") + mock_create_tenant_member.assert_called_once_with( + mock_invitation["tenant"], mock_account, mock_db.session, role=TenantAccountRole.ADMIN + ) mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id) mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") @@ -600,14 +602,14 @@ class TestActivateApi: @patch("controllers.console.auth.activate.db") def test_activation_legacy_active_member_invitation_does_not_require_setup( self, - mock_db, - mock_revoke_token, - mock_get_invitation, - mock_create_tenant_member, + mock_db: MagicMock, + mock_revoke_token: MagicMock, + mock_get_invitation: MagicMock, + mock_create_tenant_member: MagicMock, app: Flask, - mock_invitation, - mock_account, - mock_switch_tenant, + mock_invitation: MagicMock, + mock_account: MagicMock, + mock_switch_tenant: MagicMock, ): mock_account.status = AccountStatus.ACTIVE mock_get_invitation.return_value = mock_invitation diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py index 7d2193adc69..25ae0778d4b 100644 --- a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py @@ -116,7 +116,9 @@ class TestEnterpriseWorkspace: assert result["tenant"]["id"] == "tenant-id" assert result["tenant"]["name"] == "My Workspace" mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True) - mock_tenant_svc.create_tenant_member.assert_called_once_with(mock_tenant, mock_account, role="owner") + mock_tenant_svc.create_tenant_member.assert_called_once_with( + mock_tenant, mock_account, mock_db.session, role="owner" + ) mock_event.send.assert_called_once_with(mock_tenant) @patch("controllers.inner_api.workspace.workspace.db") diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 0bfa4afb5ba..3b5c6cc9bd6 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -794,7 +794,9 @@ class TestTenantService: mock_db_dependencies["db"].session.add = MagicMock() # Execute test - result = TenantService.create_tenant_member(mock_tenant, mock_account, "normal") + result = TenantService.create_tenant_member( + mock_tenant, mock_account, mock_db_dependencies["db"].session, "normal" + ) # Verify member was created with correct parameters assert result is not None @@ -1483,7 +1485,9 @@ class TestRegisterService: timezone=None, ) mock_create_tenant.assert_called_once_with("Test User's Workspace") - mock_create_member.assert_called_once_with(mock_tenant, mock_account, role="owner") + mock_create_member.assert_called_once_with( + mock_tenant, mock_account, mock_db_dependencies["db"].session, role="owner" + ) mock_event.send.assert_called_once_with(mock_tenant) self._assert_database_operations_called(mock_db_dependencies["db"]) @@ -1863,7 +1867,9 @@ class TestRegisterService: ) mock_lookup.assert_called_once_with(mixed_email) mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") - mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal") + mock_create_member.assert_called_once_with( + mock_tenant, mock_new_account, mock_db_dependencies["db"].session, "normal" + ) mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) mock_generate_token.assert_called_once_with( mock_tenant, mock_new_account, "normal", requires_setup=True @@ -1910,7 +1916,9 @@ class TestRegisterService: # Verify results assert result == "invite-token-123" - mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal") + mock_create_member.assert_called_once_with( + mock_tenant, mock_existing_account, mock_db_dependencies["db"].session, "normal" + ) mock_generate_token.assert_called_once_with( mock_tenant, mock_existing_account, "normal", requires_setup=True ) @@ -2044,7 +2052,7 @@ class TestRegisterService: assert result == "rbac-token" mock_create_member.assert_called_once_with( - mock_tenant, mock_new_account, TenantAccountRole.NORMAL.value + mock_tenant, mock_new_account, mock_db_dependencies["db"].session, TenantAccountRole.NORMAL.value ) mock_rbac_service.MemberRoles.replace.assert_called_once_with( tenant_id=str(mock_tenant.id), @@ -2089,7 +2097,10 @@ class TestRegisterService: assert result == "rbac-token" mock_create_member.assert_called_once_with( - mock_tenant, mock_existing_account, TenantAccountRole.NORMAL.value + mock_tenant, + mock_existing_account, + mock_db_dependencies["db"].session, + TenantAccountRole.NORMAL.value, ) mock_rbac_service.MemberRoles.replace.assert_called_once_with( tenant_id=str(mock_tenant.id), @@ -2133,7 +2144,10 @@ class TestRegisterService: ) mock_create_member.assert_called_once_with( - mock_tenant, mock_existing_account, TenantAccountRole.NORMAL.value + mock_tenant, + mock_existing_account, + mock_db_dependencies["db"].session, + TenantAccountRole.NORMAL.value, ) mock_rbac_service.MemberRoles.replace.assert_called_once_with( tenant_id=str(mock_tenant.id), @@ -2180,7 +2194,9 @@ class TestRegisterService: ) assert result == "legacy-token" - mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "editor") + mock_create_member.assert_called_once_with( + mock_tenant, mock_new_account, mock_db_dependencies["db"].session, "editor" + ) mock_rbac_service.MemberRoles.replace.assert_not_called() # ==================== Token Management Tests ==================== diff --git a/api/tests/unit_tests/services/test_agent_app_feature_service.py b/api/tests/unit_tests/services/test_agent_app_feature_service.py index 5503540356f..3d9337d79fb 100644 --- a/api/tests/unit_tests/services/test_agent_app_feature_service.py +++ b/api/tests/unit_tests/services/test_agent_app_feature_service.py @@ -12,7 +12,6 @@ from typing import Any import pytest -from services import agent_app_feature_service as svc_mod from services.agent_app_feature_service import AgentAppFeatureConfigService TENANT_ID = "11111111-1111-1111-1111-111111111111" @@ -89,9 +88,8 @@ class _FakeWriteSession: class TestUpdateFeatures: - def test_persists_new_app_model_config_version(self, monkeypatch: pytest.MonkeyPatch): + def test_persists_new_app_model_config_version(self): session = _FakeWriteSession() - monkeypatch.setattr(svc_mod.db, "session", session) app_model = SimpleNamespace( tenant_id=TENANT_ID, id="app-1", app_model_config_id=None, updated_by=None, updated_at=None ) @@ -101,6 +99,7 @@ class TestUpdateFeatures: app_model=app_model, # type: ignore[arg-type] account=account, # type: ignore[arg-type] config={"opening_statement": "Hi!", "suggested_questions_after_answer": {"enabled": True}}, + session=session, ) # New row carries the features but no Soul-owned model/prompt/agent_mode. From 517b27c2b4ed5c1b7981e28f7b2611a9bb5b8edf Mon Sep 17 00:00:00 2001 From: Escape0707 Date: Sat, 20 Jun 2026 21:12:14 +0900 Subject: [PATCH 02/70] test: migrate hit testing dump record tests (#37672) --- .../.ruff.toml | 1 - .../pyrefly.toml | 1 - .../services/test_hit_testing_service.py | 335 ++++++++++++++---- .../test_hit_testing_service_dump_records.py | 102 ------ 4 files changed, 264 insertions(+), 175 deletions(-) delete mode 100644 api/tests/unit_tests/services/test_hit_testing_service_dump_records.py diff --git a/api/tests/test_containers_integration_tests/.ruff.toml b/api/tests/test_containers_integration_tests/.ruff.toml index 250cf103ab9..68e3f9af4bd 100644 --- a/api/tests/test_containers_integration_tests/.ruff.toml +++ b/api/tests/test_containers_integration_tests/.ruff.toml @@ -9,7 +9,6 @@ extend-select = ["ANN401", "ARG", "TID251"] "models/test_types_enum_text.py" = ["ANN401", "TID251"] "services/test_app_dsl_service.py" = ["ANN401", "TID251", "ARG"] "services/test_file_service_zip_and_lookup.py" = ["ANN401", "TID251", "ARG"] -"services/test_hit_testing_service.py" = ["ANN401", "TID251"] "trigger/conftest.py" = ["ANN401", "TID251"] "trigger/test_trigger_e2e.py" = ["ANN401", "TID251", "ARG"] "controllers/console/app/test_app_apis.py" = ["ARG"] diff --git a/api/tests/test_containers_integration_tests/pyrefly.toml b/api/tests/test_containers_integration_tests/pyrefly.toml index 06ea10036f5..92c84320d9a 100644 --- a/api/tests/test_containers_integration_tests/pyrefly.toml +++ b/api/tests/test_containers_integration_tests/pyrefly.toml @@ -100,7 +100,6 @@ project-excludes = [ "services/test_feature_service.py", "services/test_feedback_service.py", "services/test_file_service.py", - "services/test_hit_testing_service.py", "services/test_human_input_delivery_test.py", "services/test_human_input_delivery_test_service.py", "services/test_message_export_service.py", diff --git a/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py index f332ba05ec2..2d23ae8f68f 100644 --- a/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py @@ -1,26 +1,79 @@ from __future__ import annotations import json -from typing import Any, cast -from unittest.mock import ANY, MagicMock, patch +from datetime import datetime +from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from pydantic import BaseModel, ConfigDict, TypeAdapter from sqlalchemy import func, select from sqlalchemy.orm import Session +from core.rag.embedding.retrieval import RetrievalSegments from core.rag.models.document import Document -from models.dataset import Dataset, DatasetQuery +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from models.dataset import Dataset, DatasetQuery, DocumentSegment +from models.dataset import Document as DatasetDocument +from models.enums import DataSourceType, DocumentCreatedFrom, SegmentStatus from services.hit_testing_service import HitTestingService -def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset: - tenant_id = str(uuid4()) - created_by = str(uuid4()) +class _QueryResponse(BaseModel): + content: str + + +class _RetrieveRecordResponse(BaseModel): + content: str | None = None + title: str | None = None + + model_config = ConfigDict(extra="allow") + + +class _RetrieveResponse(BaseModel): + query: _QueryResponse + records: list[_RetrieveRecordResponse] + + +class _DumpedDocumentResponse(BaseModel): + id: str + data_source_type: str + name: str + doc_type: str | None + doc_metadata: dict[str, object] | None + + +class _DumpedSegmentResponse(BaseModel): + id: str + document_id: str + created_at: datetime | None = None + document: _DumpedDocumentResponse | None = None + + model_config = ConfigDict(extra="allow") + + +class _DumpedRetrievalRecordResponse(BaseModel): + segment: _DumpedSegmentResponse + score: float + + model_config = ConfigDict(extra="allow") + + +_DUMPED_RETRIEVAL_RECORDS = TypeAdapter(list[_DumpedRetrievalRecordResponse]) + + +def _create_dataset( + db_session: Session, + *, + provider: str = "vendor", + tenant_id: str | None = None, + created_by: str | None = None, + name: str = "test-dataset", +) -> Dataset: ds = Dataset( - tenant_id=kwargs.get("tenant_id", tenant_id), - name=kwargs.get("name", "test-dataset"), - created_by=kwargs.get("created_by", created_by), + tenant_id=tenant_id or str(uuid4()), + name=name, + created_by=created_by or str(uuid4()), provider=provider, ) db_session.add(ds) @@ -29,36 +82,106 @@ def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: return ds +def _create_dataset_document( + db_session: Session, + *, + name: str = "guide.md", + data_source_type: str = DataSourceType.UPLOAD_FILE, + doc_type: str | None = None, + doc_metadata: dict[str, object] | None = None, +) -> DatasetDocument: + tenant_id = str(uuid4()) + created_by = str(uuid4()) + dataset = Dataset( + tenant_id=tenant_id, + name=f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=created_by, + ) + db_session.add(dataset) + db_session.flush() + + document = DatasetDocument( + tenant_id=tenant_id, + dataset_id=dataset.id, + position=1, + data_source_type=data_source_type, + batch=f"batch-{uuid4()}", + name=name, + created_from=DocumentCreatedFrom.WEB, + created_by=created_by, + doc_type=doc_type, + doc_metadata=doc_metadata, + ) + db_session.add(document) + db_session.commit() + db_session.refresh(document) + return document + + +def _build_segment( + *, + document_id: str, + tenant_id: str | None = None, + dataset_id: str | None = None, + created_by: str | None = None, +) -> DocumentSegment: + return DocumentSegment( + tenant_id=tenant_id or str(uuid4()), + dataset_id=dataset_id or str(uuid4()), + document_id=document_id, + created_by=created_by or str(uuid4()), + position=1, + content="segment content", + word_count=2, + tokens=2, + status=SegmentStatus.COMPLETED, + ) + + +def _create_segment(db_session: Session, *, document: DatasetDocument | None = None) -> DocumentSegment: + segment = _build_segment( + tenant_id=document.tenant_id if document else None, + dataset_id=document.dataset_id if document else None, + document_id=document.id if document else str(uuid4()), + created_by=document.created_by if document else None, + ) + db_session.add(segment) + db_session.commit() + db_session.refresh(segment) + return segment + + class TestHitTestingService: # ── Utility methods (pure logic, no DB) ──────────────────────────── - def test_escape_query_for_search_should_escape_double_quotes(self): + def test_escape_query_for_search_should_escape_double_quotes(self) -> None: query = 'test "query" with quotes' result = HitTestingService.escape_query_for_search(query) assert result == 'test \\"query\\" with quotes' - def test_hit_testing_args_check_should_pass_with_valid_query(self): + def test_hit_testing_args_check_should_pass_with_valid_query(self) -> None: HitTestingService.hit_testing_args_check({"query": "valid query"}) - def test_hit_testing_args_check_should_pass_with_valid_attachments(self): + def test_hit_testing_args_check_should_pass_with_valid_attachments(self) -> None: HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]}) - def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): + def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self) -> None: with pytest.raises(ValueError, match="Query or attachment_ids is required"): HitTestingService.hit_testing_args_check({}) - def test_hit_testing_args_check_should_raise_error_when_query_too_long(self): + def test_hit_testing_args_check_should_raise_error_when_query_too_long(self) -> None: with pytest.raises(ValueError, match="Query cannot exceed 250 characters"): HitTestingService.hit_testing_args_check({"query": "a" * 251}) - def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): + def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self) -> None: with pytest.raises(ValueError, match="Attachment_ids must be a list"): HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"}) # ── Response formatting ──────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") - def test_compact_retrieve_response_should_format_correctly(self, mock_format): + def test_compact_retrieve_response_should_format_correctly(self, mock_format: MagicMock) -> None: query = "test query" mock_doc = MagicMock(spec=Document) @@ -66,50 +189,49 @@ class TestHitTestingService: mock_record.model_dump.return_value = {"content": "formatted content"} mock_format.return_value = [mock_record] - result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc])) + response = _RetrieveResponse.model_validate(HitTestingService.compact_retrieve_response(query, [mock_doc])) - assert cast(dict[str, Any], result["query"])["content"] == query - assert len(result["records"]) == 1 - assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content" + assert response.query.content == query + assert len(response.records) == 1 + assert response.records[0].content == "formatted content" mock_format.assert_called_once_with([mock_doc]) def test_compact_external_retrieve_response_should_return_records_for_external_provider( self, db_session_with_containers: Session - ): + ) -> None: dataset = _create_dataset(db_session_with_containers, provider="external") documents = [ {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, ] - result = cast( - dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents) + response = _RetrieveResponse.model_validate( + HitTestingService.compact_external_retrieve_response(dataset, "test query", documents) ) - assert cast(dict[str, Any], result["query"])["content"] == "test query" - assert len(result["records"]) == 2 - assert cast(dict[str, Any], result["records"][0])["content"] == "c1" - assert cast(dict[str, Any], result["records"][1])["title"] == "t2" + assert response.query.content == "test query" + assert len(response.records) == 2 + assert response.records[0].content == "c1" + assert response.records[1].title == "t2" def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider( self, db_session_with_containers: Session - ): + ) -> None: dataset = _create_dataset(db_session_with_containers, provider="vendor") - result = cast( - dict[str, Any], - HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]), + response = _RetrieveResponse.model_validate( + HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]) ) - assert cast(dict[str, Any], result["query"])["content"] == "test query" - assert result["records"] == [] + assert response.query.content == "test query" + assert response.records == [] # ── External retrieve (real DB) ──────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") def test_external_retrieve_should_succeed_for_external_provider( - self, mock_ext_retrieve, db_session_with_containers: Session - ): + self, mock_ext_retrieve: MagicMock, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers, provider="external") account_id = str(uuid4()) account = MagicMock() @@ -118,19 +240,18 @@ class TestHitTestingService: before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 - result = cast( - dict[str, Any], + response = _RetrieveResponse.model_validate( HitTestingService.external_retrieve( dataset=dataset, query='test "query"', account=account, external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, - ), + ) ) - assert cast(dict[str, Any], result["query"])["content"] == 'test "query"' - assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" + assert response.query.content == 'test "query"' + assert response.records[0].content == "ext content" mock_ext_retrieve.assert_called_once_with( dataset_id=dataset.id, query='test \\"query\\"', @@ -142,37 +263,44 @@ class TestHitTestingService: after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 assert after_count == before_count + 1 - def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session): + def test_external_retrieve_should_return_empty_for_non_external_provider( + self, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers, provider="vendor") account = MagicMock() - result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account)) + response = _RetrieveResponse.model_validate(HitTestingService.external_retrieve(dataset, "test query", account)) - assert cast(dict[str, Any], result["query"])["content"] == "test query" - assert result["records"] == [] + assert response.query.content == "test query" + assert response.records == [] # ── Retrieve (real DB) ───────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") def test_retrieve_should_use_default_model_when_none_provided( - self, mock_retrieve, db_session_with_containers: Session - ): + self, mock_retrieve: MagicMock, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers) dataset.retrieval_model = None account = MagicMock() account.id = str(uuid4()) - mock_retrieve.return_value = [] + retrieved_documents: list[Document] = [] + mock_retrieve.return_value = retrieved_documents + external_retrieval_model: dict[str, object] = {} before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 - result = cast( - dict[str, Any], + response = _RetrieveResponse.model_validate( HitTestingService.retrieve( - dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={} - ), + dataset=dataset, + query="test query", + account=account, + retrieval_model=None, + external_retrieval_model=external_retrieval_model, + ) ) - assert cast(dict[str, Any], result["query"])["content"] == "test query" + assert response.query.content == "test query" mock_retrieve.assert_called_once() assert mock_retrieve.call_args.kwargs["top_k"] == 4 @@ -183,11 +311,12 @@ class TestHitTestingService: @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") def test_retrieve_should_handle_metadata_filtering( - self, mock_get_meta, mock_retrieve, db_session_with_containers: Session - ): + self, mock_get_meta: MagicMock, mock_retrieve: MagicMock, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers) account = MagicMock() account.id = str(uuid4()) + external_retrieval_model: dict[str, object] = {} retrieval_model = { "search_method": "semantic_search", @@ -197,14 +326,15 @@ class TestHitTestingService: "score_threshold_enabled": False, } mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string") - mock_retrieve.return_value = [] + retrieved_documents: list[Document] = [] + mock_retrieve.return_value = retrieved_documents HitTestingService.retrieve( dataset=dataset, query="test query", account=account, retrieval_model=retrieval_model, - external_retrieval_model={}, + external_retrieval_model=external_retrieval_model, ) mock_get_meta.assert_called_once() @@ -214,10 +344,11 @@ class TestHitTestingService: @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") def test_retrieve_should_return_empty_if_metadata_filtering_fails( - self, mock_get_meta, mock_retrieve, db_session_with_containers: Session - ): + self, mock_get_meta: MagicMock, mock_retrieve: MagicMock, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers) account = MagicMock() + external_retrieval_model: dict[str, object] = {} retrieval_model = { "search_method": "semantic_search", @@ -226,28 +357,31 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - mock_get_meta.return_value = ({}, "condition_string") + empty_document_ids: dict[str, list[str]] = {} + mock_get_meta.return_value = (empty_document_ids, "condition_string") - result = cast( - dict[str, Any], + response = _RetrieveResponse.model_validate( HitTestingService.retrieve( dataset=dataset, query="test query", account=account, retrieval_model=retrieval_model, - external_retrieval_model={}, - ), + external_retrieval_model=external_retrieval_model, + ) ) - assert result["records"] == [] + assert response.records == [] mock_retrieve.assert_not_called() @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session): + def test_retrieve_should_handle_attachments( + self, mock_retrieve: MagicMock, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers) account = MagicMock() account.id = str(uuid4()) attachment_ids = ["att1", "att2"] + external_retrieval_model: dict[str, object] = {} retrieval_model = { "search_method": "semantic_search", @@ -255,19 +389,20 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - mock_retrieve.return_value = [] + retrieved_documents: list[Document] = [] + mock_retrieve.return_value = retrieved_documents HitTestingService.retrieve( dataset=dataset, query="test query", account=account, retrieval_model=retrieval_model, - external_retrieval_model={}, + external_retrieval_model=external_retrieval_model, attachment_ids=attachment_ids, ) mock_retrieve.assert_called_once_with( - retrieval_method=ANY, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, dataset_id=dataset.id, query="test query", attachment_ids=attachment_ids, @@ -295,10 +430,13 @@ class TestHitTestingService: assert query_content[1]["content"] == "att1" @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session): + def test_retrieve_should_handle_reranking_and_threshold( + self, mock_retrieve: MagicMock, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers) account = MagicMock() account.id = str(uuid4()) + external_retrieval_model: dict[str, object] = {} retrieval_model = { "search_method": "hybrid_search", @@ -310,14 +448,15 @@ class TestHitTestingService: "score_threshold": 0.5, "weights": {"vector": 0.5, "keyword": 0.5}, } - mock_retrieve.return_value = [] + retrieved_documents: list[Document] = [] + mock_retrieve.return_value = retrieved_documents HitTestingService.retrieve( dataset=dataset, query="test query", account=account, retrieval_model=retrieval_model, - external_retrieval_model={}, + external_retrieval_model=external_retrieval_model, ) mock_retrieve.assert_called_once() @@ -326,3 +465,57 @@ class TestHitTestingService: assert kwargs["reranking_model"] == {"provider": "test"} assert kwargs["reranking_mode"] == "weighted_sum" assert kwargs["weights"] == {"vector": 0.5, "keyword": 0.5} + + def test_dump_dataset_document_returns_frontend_required_fields(self, db_session_with_containers: Session) -> None: + document = _create_dataset_document(db_session_with_containers, doc_metadata={"source": "manual"}) + + assert HitTestingService._dump_dataset_document(document) == { + "id": document.id, + "data_source_type": "upload_file", + "name": "guide.md", + "doc_type": None, + "doc_metadata": {"source": "manual"}, + } + + def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self) -> None: + segment = _build_segment(document_id="") + record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95}) + + records = _DUMPED_RETRIEVAL_RECORDS.validate_python(HitTestingService._dump_retrieval_records([record])) + + assert len(records) == 1 + assert records[0].segment.id == segment.id + assert records[0].segment.document_id == "" + assert records[0].score == 0.95 + + def test_dump_retrieval_records_injects_documents(self, db_session_with_containers: Session) -> None: + document = _create_dataset_document(db_session_with_containers) + segment = _create_segment(db_session_with_containers, document=document) + record = RetrievalSegments.model_validate({"segment": segment, "score": 0.9}) + + records = _DUMPED_RETRIEVAL_RECORDS.validate_python(HitTestingService._dump_retrieval_records([record])) + + assert len(records) == 1 + dumped_segment = records[0].segment + assert dumped_segment.id == segment.id + assert dumped_segment.document_id == document.id + assert dumped_segment.created_at == segment.created_at + assert dumped_segment.document == _DumpedDocumentResponse( + id=document.id, + data_source_type="upload_file", + name="guide.md", + doc_type=None, + doc_metadata=None, + ) + assert records[0].score == 0.9 + + def test_dump_retrieval_records_skips_records_with_missing_documents( + self, db_session_with_containers: Session, caplog: pytest.LogCaptureFixture + ) -> None: + segment = _create_segment(db_session_with_containers) + record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95}) + + result = HitTestingService._dump_retrieval_records([record]) + + assert result == [] + assert "Skipping hit-testing records with missing documents" in caplog.text diff --git a/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py b/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py deleted file mode 100644 index 5dd0194fd01..00000000000 --- a/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py +++ /dev/null @@ -1,102 +0,0 @@ -from datetime import datetime -from unittest.mock import Mock, patch - -import pytest - -from services.hit_testing_service import HitTestingService - - -def _retrieval_record(payload: dict): - record = Mock() - record.model_dump.return_value = payload - segment = payload.get("segment") - if isinstance(segment, dict): - record.segment = Mock() - record.segment.id = segment.get("id") - record.segment.document_id = segment.get("document_id") - record.segment.created_at = datetime(2024, 1, 1, 0, 0, 0) - else: - record.segment = None - return record - - -def _dataset_document( - document_id: str = "document-1", - name: str = "guide.md", - data_source_type: str = "upload_file", - doc_type: str | None = None, - doc_metadata: dict | None = None, -): - document = Mock() - document.id = document_id - document.name = name - document.data_source_type = data_source_type - document.doc_type = doc_type - document.doc_metadata = doc_metadata - return document - - -class TestHitTestingServiceDumpRecords: - def test_dump_dataset_document_returns_frontend_required_fields(self): - document = _dataset_document(doc_metadata={"source": "manual"}) - - assert HitTestingService._dump_dataset_document(document) == { - "id": "document-1", - "data_source_type": "upload_file", - "name": "guide.md", - "doc_type": None, - "doc_metadata": {"source": "manual"}, - } - - def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self): - record = _retrieval_record({"segment": {"id": "segment-1", "document_id": None}, "score": 0.95}) - record.segment.document_id = None - - assert HitTestingService._dump_retrieval_records([record]) == [ - {"segment": {"id": "segment-1", "document_id": None}, "score": 0.95} - ] - - def test_dump_retrieval_records_injects_documents(self): - record_with_document = _retrieval_record( - { - "segment": { - "id": "segment-1", - "document_id": "document-1", - }, - "score": 0.9, - } - ) - scalars_result = Mock() - scalars_result.all.return_value = [_dataset_document()] - - with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result): - result = HitTestingService._dump_retrieval_records([record_with_document]) - - assert result[0]["segment"]["document"] == { - "id": "document-1", - "data_source_type": "upload_file", - "name": "guide.md", - "doc_type": None, - "doc_metadata": None, - } - - assert result[0]["segment"]["created_at"] == datetime(2024, 1, 1, 0, 0, 0) - - def test_dump_retrieval_records_skips_records_with_missing_documents(self, caplog: pytest.LogCaptureFixture): - record = _retrieval_record( - { - "segment": { - "id": "segment-1", - "document_id": "missing-document", - }, - "score": 0.95, - } - ) - scalars_result = Mock() - scalars_result.all.return_value = [] - - with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result): - result = HitTestingService._dump_retrieval_records([record]) - - assert result == [] - assert "Skipping hit-testing records with missing documents" in caplog.text From b3e724dce13918143867f2dfff3c9e4833177049 Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Sat, 20 Jun 2026 20:25:08 +0800 Subject: [PATCH 03/70] fix(agent): return conflict for duplicate agent names (#37686) --- api/services/app_service.py | 16 +++++- .../unit_tests/services/test_app_service.py | 57 +++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/api/services/app_service.py b/api/services/app_service.py index 0f346433265..941855b8321 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination from pydantic import BaseModel, Field from sqlalchemy import ColumnElement, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, scoped_session from configs import dify_config @@ -27,6 +28,7 @@ from models import Account, AppStar from models.agent import Agent, AgentIconType, AgentScope, AgentSource, AgentStatus from models.model import App, AppMode, AppModelConfig, IconType, Site from models.tools import ApiToolProvider +from services.agent.errors import AgentNameConflictError from services.billing_service import BillingService from services.enterprise import rbac_service as enterprise_rbac_service from services.enterprise.enterprise_service import EnterpriseService @@ -593,6 +595,16 @@ class AppService: if updated_at is not None: agent.updated_at = updated_at + @staticmethod + def _commit_app_identity_update(app: App) -> None: + try: + db.session.commit() + except IntegrityError as exc: + db.session.rollback() + if app.mode == AppMode.AGENT: + raise AgentNameConflictError() from exc + raise + def update_app(self, app: App, args: ArgsDict) -> App: """ Update app @@ -629,7 +641,7 @@ class AppService: account_id=current_user.id, updated_at=app.updated_at, ) - db.session.commit() + self._commit_app_identity_update(app) app_was_updated.send(app) @@ -652,7 +664,7 @@ class AppService: account_id=current_user.id, updated_at=app.updated_at, ) - db.session.commit() + self._commit_app_identity_update(app) app_was_updated.send(app) diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py index e595721e169..c57fb6ed775 100644 --- a/api/tests/unit_tests/services/test_app_service.py +++ b/api/tests/unit_tests/services/test_app_service.py @@ -3,7 +3,11 @@ from __future__ import annotations from types import SimpleNamespace from unittest.mock import MagicMock, patch +import pytest +from sqlalchemy.exc import IntegrityError + from models.model import App +from services.agent.errors import AgentNameConflictError from services.app_service import AppService @@ -317,6 +321,59 @@ class TestAgentAppType: assert backing_agent.role == "" + def test_update_agent_app_duplicate_name_rolls_back_and_raises_conflict(self): + from models.agent import AgentIconType + from models.model import AppMode, IconType + from services.app_service import AppService + + app = SimpleNamespace( + id="app-1", + tenant_id="tenant-1", + mode=AppMode.AGENT, + name="Old", + description="old", + role="draft", + icon_type=IconType.EMOJI, + icon="robot", + icon_background="#fff", + use_icon_as_answer_icon=False, + max_active_requests=None, + created_by="account-1", + ) + backing_agent = SimpleNamespace( + name="Old", + description="old", + role="research assistant", + icon_type=AgentIconType.EMOJI, + icon="robot", + icon_background="#fff", + updated_by=None, + updated_at=None, + ) + + with ( + patch("services.app_service.db") as mock_db, + patch("services.app_service.current_user", SimpleNamespace(id="account-2")), + ): + mock_db.session.scalar.return_value = backing_agent + mock_db.session.commit.side_effect = IntegrityError("duplicate", None, None) + with pytest.raises(AgentNameConflictError): + AppService().update_app( + app, # type: ignore[arg-type] + { + "name": "Existing Agent", + "description": "agent app", + "role": "research assistant", + "icon_type": "emoji", + "icon": "robot", + "icon_background": "#fff", + "use_icon_as_answer_icon": False, + "max_active_requests": 0, + }, + ) + + mock_db.session.rollback.assert_called_once() + def test_delete_agent_app_archives_backing_agent(self): from models.agent import AgentStatus from models.model import AppMode From dcff1870d5a78246a939e9bf26e1c811cbaa7be5 Mon Sep 17 00:00:00 2001 From: Rohit Gahlawat <283466839+Rohit-Gahlawat@users.noreply.github.com> Date: Sat, 20 Jun 2026 18:05:06 +0530 Subject: [PATCH 04/70] refactor: accept db.session explicitly in SavedMessageService (#37682) --- .../console/explore/saved_message.py | 6 ++-- api/controllers/web/saved_message.py | 9 ++++-- api/services/saved_message_service.py | 22 +++++++------- .../services/test_saved_message_service.py | 30 +++++++++++-------- .../console/explore/test_saved_message.py | 6 ++-- 5 files changed, 41 insertions(+), 32 deletions(-) diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 3e8f1ce9083..ce43ff18c93 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -11,6 +11,7 @@ from controllers.console.app.error import AppUnavailableError from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import with_current_user +from extensions.ext_database import db from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem from models import Account @@ -37,6 +38,7 @@ class SavedMessageListApi(InstalledAppResource): args = SavedMessageListQuery.model_validate(request.args.to_dict()) pagination = SavedMessageService.pagination_by_last_id( + db.session(), app_model, current_user, str(args.last_id) if args.last_id else None, @@ -63,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource): payload = SavedMessageCreatePayload.model_validate(console_ns.payload or {}) try: - SavedMessageService.save(app_model, current_user, str(payload.message_id)) + SavedMessageService.save(db.session(), app_model, current_user, str(payload.message_id)) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -86,6 +88,6 @@ class SavedMessageApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() - SavedMessageService.delete(app_model, current_user, message_id_str) + SavedMessageService.delete(db.session(), app_model, current_user, message_id_str) return "", 204 diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index e3baa028e50..6e59a85e2b0 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -9,6 +9,7 @@ from controllers.common.schema import query_params_from_model, register_response from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource +from extensions.ext_database import db from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem from models.model import App, EndUser @@ -42,7 +43,9 @@ class SavedMessageListApi(WebApiResource): raw_args = request.args.to_dict() query = SavedMessageListQuery.model_validate(raw_args) - pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit) + pagination = SavedMessageService.pagination_by_last_id( + db.session(), app_model, end_user, query.last_id, query.limit + ) adapter = TypeAdapter(SavedMessageItem) items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] return SavedMessageInfiniteScrollPagination( @@ -77,7 +80,7 @@ class SavedMessageListApi(WebApiResource): payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {}) try: - SavedMessageService.save(app_model, end_user, payload.message_id) + SavedMessageService.save(db.session(), app_model, end_user, payload.message_id) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -105,6 +108,6 @@ class SavedMessageApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - SavedMessageService.delete(app_model, end_user, message_id_str) + SavedMessageService.delete(db.session(), app_model, end_user, message_id_str) return "", 204 diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 90f01377123..9a65429748e 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -1,6 +1,6 @@ from sqlalchemy import select +from sqlalchemy.orm import Session -from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import CreatorUserRole @@ -12,11 +12,11 @@ from services.message_service import MessageService class SavedMessageService: @classmethod def pagination_by_last_id( - cls, app_model: App, user: Account | EndUser | None, last_id: str | None, limit: int + cls, session: Session, app_model: App, user: Account | EndUser | None, last_id: str | None, limit: int ) -> InfiniteScrollPagination: if not user: raise ValueError("User is required") - saved_messages = db.session.scalars( + saved_messages = session.scalars( select(SavedMessage) .where( SavedMessage.app_id == app_model.id, @@ -32,10 +32,10 @@ class SavedMessageService: ) @classmethod - def save(cls, app_model: App, user: Account | EndUser | None, message_id: str): + def save(cls, session: Session, app_model: App, user: Account | EndUser | None, message_id: str): if not user: return - saved_message = db.session.scalar( + saved_message = session.scalar( select(SavedMessage) .where( SavedMessage.app_id == app_model.id, @@ -58,14 +58,14 @@ class SavedMessageService: created_by=user.id, ) - db.session.add(saved_message) - db.session.commit() + session.add(saved_message) + session.commit() @classmethod - def delete(cls, app_model: App, user: Account | EndUser | None, message_id: str): + def delete(cls, session: Session, app_model: App, user: Account | EndUser | None, message_id: str): if not user: return - saved_message = db.session.scalar( + saved_message = session.scalar( select(SavedMessage) .where( SavedMessage.app_id == app_model.id, @@ -79,5 +79,5 @@ class SavedMessageService: if not saved_message: return - db.session.delete(saved_message) - db.session.commit() + session.delete(saved_message) + session.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index ac434021fc8..ad85ac67bc5 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -220,7 +220,9 @@ class TestSavedMessageService: mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination # Act: Execute the method under test - result = SavedMessageService.pagination_by_last_id(app_model=app, user=account, last_id=None, limit=10) + result = SavedMessageService.pagination_by_last_id( + db_session_with_containers, app_model=app, user=account, last_id=None, limit=10 + ) # Assert: Verify the expected outcomes assert result is not None @@ -294,7 +296,7 @@ class TestSavedMessageService: # Act: Execute the method under test result = SavedMessageService.pagination_by_last_id( - app_model=app, user=end_user, last_id="test_last_id", limit=5 + db_session_with_containers, app_model=app, user=end_user, last_id="test_last_id", limit=5 ) # Assert: Verify the expected outcomes @@ -344,7 +346,7 @@ class TestSavedMessageService: mock_external_service_dependencies["message_service"].get_message.return_value = message # Act: Execute the method under test - SavedMessageService.save(app_model=app, user=account, message_id=message.id) + SavedMessageService.save(db_session_with_containers, app_model=app, user=account, message_id=message.id) # Assert: Verify the expected outcomes # Check if saved message was created in database @@ -393,7 +395,9 @@ class TestSavedMessageService: # Act & Assert: Verify proper error handling with pytest.raises(ValueError) as exc_info: - SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10) + SavedMessageService.pagination_by_last_id( + db_session_with_containers, app_model=app, user=None, last_id=None, limit=10 + ) assert "User is required" in str(exc_info.value) @@ -412,7 +416,7 @@ class TestSavedMessageService: message = self._create_test_message(db_session_with_containers, app, account) # Act: Execute the method under test with None user - result = SavedMessageService.save(app_model=app, user=None, message_id=message.id) + result = SavedMessageService.save(db_session_with_containers, app_model=app, user=None, message_id=message.id) # Assert: Verify the expected outcomes assert result is None @@ -471,7 +475,7 @@ class TestSavedMessageService: ) # Act: Execute the method under test - SavedMessageService.delete(app_model=app, user=account, message_id=message.id) + SavedMessageService.delete(db_session_with_containers, app_model=app, user=account, message_id=message.id) # Assert: Verify the expected outcomes # Check if saved message was deleted from database @@ -501,7 +505,7 @@ class TestSavedMessageService: mock_external_service_dependencies["message_service"].get_message.return_value = message - SavedMessageService.save(app_model=app, user=end_user, message_id=message.id) + SavedMessageService.save(db_session_with_containers, app_model=app, user=end_user, message_id=message.id) saved = ( db_session_with_containers.query(SavedMessage) @@ -522,9 +526,9 @@ class TestSavedMessageService: mock_external_service_dependencies["message_service"].get_message.return_value = message # Save once - SavedMessageService.save(app_model=app, user=account, message_id=message.id) + SavedMessageService.save(db_session_with_containers, app_model=app, user=account, message_id=message.id) # Save again - SavedMessageService.save(app_model=app, user=account, message_id=message.id) + SavedMessageService.save(db_session_with_containers, app_model=app, user=account, message_id=message.id) count = ( db_session_with_containers.query(SavedMessage) @@ -547,7 +551,7 @@ class TestSavedMessageService: db_session_with_containers.add(saved) db_session_with_containers.commit() - SavedMessageService.delete(app_model=app, user=None, message_id=message.id) + SavedMessageService.delete(db_session_with_containers, app_model=app, user=None, message_id=message.id) # Should still exist assert ( @@ -566,7 +570,7 @@ class TestSavedMessageService: # Should not raise — use a valid UUID that doesn't exist in DB from uuid import uuid4 - SavedMessageService.delete(app_model=app, user=account, message_id=str(uuid4())) + SavedMessageService.delete(db_session_with_containers, app_model=app, user=account, message_id=str(uuid4())) def test_delete_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """Test deleting a saved message for an EndUser.""" @@ -580,7 +584,7 @@ class TestSavedMessageService: db_session_with_containers.add(saved) db_session_with_containers.commit() - SavedMessageService.delete(app_model=app, user=end_user, message_id=message.id) + SavedMessageService.delete(db_session_with_containers, app_model=app, user=end_user, message_id=message.id) assert ( db_session_with_containers.query(SavedMessage) @@ -610,7 +614,7 @@ class TestSavedMessageService: db_session_with_containers.commit() # Delete only account1's saved message - SavedMessageService.delete(app_model=app, user=account1, message_id=message.id) + SavedMessageService.delete(db_session_with_containers, app_model=app, user=account1, message_id=message.id) # Account's saved message should be gone assert ( diff --git a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py index f210d0d5d04..ae05b8f6a0e 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py @@ -63,7 +63,7 @@ class TestSavedMessageListApi: result = method(api, current_user, installed_app) pagination_mock.assert_called_once() - assert pagination_mock.call_args.args[1] is current_user + assert pagination_mock.call_args.args[2] is current_user assert result["limit"] == 20 assert result["has_more"] is False assert len(result["data"]) == 2 @@ -96,7 +96,7 @@ class TestSavedMessageListApi: result = method(api, current_user, installed_app) save_mock.assert_called_once() - assert save_mock.call_args.args[1] is current_user + assert save_mock.call_args.args[2] is current_user assert result == {"result": "success"} def test_post_message_not_exists(self, app: Flask, payload_patch): @@ -136,7 +136,7 @@ class TestSavedMessageApi: result, status = method(api, current_user, installed_app, str(uuid4())) delete_mock.assert_called_once() - assert delete_mock.call_args.args[1] is current_user + assert delete_mock.call_args.args[2] is current_user assert status == 204 assert result == "" From adfd8202204edab820924e78b2a1c2a2dd0e3172 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=BD=E5=B0=98?= Date: Sun, 21 Jun 2026 08:47:25 +0800 Subject: [PATCH 05/70] fix(watercrawl): bound client request timeouts (#37515) --- api/core/rag/extractor/watercrawl/client.py | 4 +++- .../core/rag/extractor/watercrawl/test_watercrawl.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/api/core/rag/extractor/watercrawl/client.py b/api/core/rag/extractor/watercrawl/client.py index 1f4adc0d418..b37bd38ec3e 100644 --- a/api/core/rag/extractor/watercrawl/client.py +++ b/api/core/rag/extractor/watercrawl/client.py @@ -12,6 +12,8 @@ from core.rag.extractor.watercrawl.exceptions import ( WaterCrawlPermissionError, ) +WATERCRAWL_REQUEST_TIMEOUT: httpx.Timeout = httpx.Timeout(30.0, connect=5.0) + class SpiderOptions(TypedDict): max_depth: int @@ -48,7 +50,7 @@ class BaseAPIClient: "User-Agent": "WaterCrawl-Plugin", "Accept-Language": "en-US", } - return httpx.Client(headers=headers, timeout=None) + return httpx.Client(headers=headers, timeout=WATERCRAWL_REQUEST_TIMEOUT) def _request( self, diff --git a/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py b/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py index 35e581ccc15..05985d30985 100644 --- a/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py @@ -73,6 +73,9 @@ class TestBaseAPIClient: assert client.session == "session" assert captured["headers"]["X-API-Key"] == "k" assert captured["headers"]["User-Agent"] == "WaterCrawl-Plugin" + assert captured["timeout"] is not None + assert captured["timeout"].connect is not None + assert captured["timeout"].read is not None def test_request_stream_and_non_stream_paths(self, monkeypatch: pytest.MonkeyPatch): class FakeSession: From 75d50455d61a21b8cb7c0fb78b8015d2e829f681 Mon Sep 17 00:00:00 2001 From: Rohit Gahlawat <283466839+Rohit-Gahlawat@users.noreply.github.com> Date: Sun, 21 Jun 2026 06:22:36 +0530 Subject: [PATCH 06/70] refactor: accept db.session explicitly in FeedbackService (#37694) --- api/controllers/console/app/message.py | 1 + api/services/feedback_service.py | 5 +-- .../console/app/test_feedback_export_api.py | 1 + .../services/test_feedback_service.py | 31 ++++++++++--------- 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 1406fbc634b..9944f02207f 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -338,6 +338,7 @@ class MessageFeedbackExportApi(Resource): try: export_data = FeedbackService.export_feedbacks( + db.session(), app_id=app_model.id, from_source=args.from_source, rating=args.rating, diff --git a/api/services/feedback_service.py b/api/services/feedback_service.py index d6c338a830d..24cfb8aa852 100644 --- a/api/services/feedback_service.py +++ b/api/services/feedback_service.py @@ -5,8 +5,8 @@ from datetime import datetime from flask import Response from sqlalchemy import or_, select +from sqlalchemy.orm import Session -from extensions.ext_database import db from models.enums import FeedbackRating from models.model import Account, App, Conversation, Message, MessageFeedback @@ -14,6 +14,7 @@ from models.model import Account, App, Conversation, Message, MessageFeedback class FeedbackService: @staticmethod def export_feedbacks( + session: Session, app_id: str, from_source: str | None = None, rating: str | None = None, @@ -81,7 +82,7 @@ class FeedbackService: stmt = stmt.order_by(MessageFeedback.created_at.desc()) # Execute query - results = db.session.execute(stmt).all() + results = session.execute(stmt).all() # Prepare data for export export_data = [] diff --git a/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py b/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py index 93310ad3805..f30abb4ed05 100644 --- a/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py +++ b/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py @@ -289,6 +289,7 @@ class TestFeedbackExportApi: # Verify service was called with correct parameters mock_export_feedbacks.assert_called_once_with( + mock.ANY, app_id=mock_app_model.id, from_source=FeedbackFromSource.USER, rating=FeedbackRating.DISLIKE, diff --git a/api/tests/test_containers_integration_tests/services/test_feedback_service.py b/api/tests/test_containers_integration_tests/services/test_feedback_service.py index a4663450d49..e4fd81b53e7 100644 --- a/api/tests/test_containers_integration_tests/services/test_feedback_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feedback_service.py @@ -7,7 +7,6 @@ from unittest import mock import pytest -from extensions.ext_database import db from models.enums import FeedbackFromSource, FeedbackRating from models.model import App, Conversation, Message from services.feedback_service import FeedbackService @@ -23,11 +22,9 @@ class TestFeedbackService: """Test FeedbackService methods.""" @pytest.fixture - def mock_db_session(self, monkeypatch: pytest.MonkeyPatch): - """Mock database session.""" - mock_session = mock.Mock() - monkeypatch.setattr(db, "session", mock_session) - return mock_session + def mock_db_session(self): + """Mock database session passed explicitly to the service.""" + return mock.Mock() @pytest.fixture def sample_data(self): @@ -100,7 +97,7 @@ class TestFeedbackService: ) # Test CSV export - result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") + result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="csv") # Verify response structure assert hasattr(result, "headers") @@ -131,7 +128,7 @@ class TestFeedbackService: ) # Test JSON export - result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") + result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="json") # Verify response structure assert hasattr(result, "headers") @@ -161,6 +158,7 @@ class TestFeedbackService: # Test with filters result = FeedbackService.export_feedbacks( + mock_db_session, app_id=sample_data["app"].id, from_source=FeedbackFromSource.ADMIN, rating=FeedbackRating.DISLIKE, @@ -177,7 +175,7 @@ class TestFeedbackService: """Test exporting feedback when no data exists.""" mock_db_session.execute.return_value = _execute_result([]) - result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") + result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="csv") # Should return an empty CSV with headers only assert hasattr(result, "headers") @@ -195,17 +193,22 @@ class TestFeedbackService: # Test with invalid start_date with pytest.raises(ValueError, match="Invalid start_date format"): - FeedbackService.export_feedbacks(app_id=sample_data["app"].id, start_date="invalid-date-format") + FeedbackService.export_feedbacks( + mock_db_session, app_id=sample_data["app"].id, start_date="invalid-date-format" + ) # Test with invalid end_date with pytest.raises(ValueError, match="Invalid end_date format"): - FeedbackService.export_feedbacks(app_id=sample_data["app"].id, end_date="invalid-date-format") + FeedbackService.export_feedbacks( + mock_db_session, app_id=sample_data["app"].id, end_date="invalid-date-format" + ) def test_export_feedbacks_invalid_format(self, mock_db_session, sample_data): """Test exporting feedback with unsupported format.""" with pytest.raises(ValueError, match="Unsupported format"): FeedbackService.export_feedbacks( + mock_db_session, app_id=sample_data["app"].id, format_type="xml", # Unsupported format ) @@ -236,7 +239,7 @@ class TestFeedbackService: ) # Test export - result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") + result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="json") # Check JSON content json_content = json.loads(result.get_data(as_text=True)) @@ -287,7 +290,7 @@ class TestFeedbackService: ) # Test export - result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") + result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="csv") # Check that unicode content is preserved csv_content = result.get_data(as_text=True) @@ -317,7 +320,7 @@ class TestFeedbackService: ) # Test export - result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") + result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="json") # Check JSON content for emoji ratings json_content = json.loads(result.get_data(as_text=True)) From 9b4dd9d4e8d82b57e5837d7ca7224ee902ccb559 Mon Sep 17 00:00:00 2001 From: Rohit Gahlawat <283466839+Rohit-Gahlawat@users.noreply.github.com> Date: Sun, 21 Jun 2026 06:23:36 +0530 Subject: [PATCH 07/70] refactor: accept db.session explicitly in APIBasedExtensionService (#37693) --- api/controllers/console/extension.py | 24 ++++-- api/services/api_based_extension_service.py | 30 +++---- .../console/test_api_based_extension.py | 3 +- .../test_api_based_extension_service.py | 79 +++++++++++-------- .../controllers/console/test_extension.py | 14 ++-- 5 files changed, 85 insertions(+), 65 deletions(-) diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 6d9362ae0b1..ec1e01dc460 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -7,6 +7,7 @@ from flask_restx import Resource from pydantic import BaseModel, Field, TypeAdapter, field_validator from constants import HIDDEN_VALUE +from extensions.ext_database import db from fields.base import ResponseModel from libs.helper import to_timestamp from libs.login import login_required @@ -126,7 +127,7 @@ class APIBasedExtensionAPI(Resource): def get(self, current_tenant_id: str): return [ _serialize_api_based_extension(extension) - for extension in APIBasedExtensionService.get_all_by_tenant_id(current_tenant_id) + for extension in APIBasedExtensionService.get_all_by_tenant_id(db.session(), current_tenant_id) ] @console_ns.doc("create_api_based_extension") @@ -147,7 +148,12 @@ class APIBasedExtensionAPI(Resource): api_key=payload.api_key, ) - return _serialize_saved_api_based_extension(APIBasedExtensionService.save(extension_data), payload.api_key), 201 + return ( + _serialize_saved_api_based_extension( + APIBasedExtensionService.save(db.session(), extension_data), payload.api_key + ), + 201, + ) @console_ns.route("/api-based-extension/") @@ -164,7 +170,7 @@ class APIBasedExtensionDetailAPI(Resource): api_based_extension_id = str(id) return _serialize_api_based_extension( - APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) + APIBasedExtensionService.get_with_tenant_id(db.session(), current_tenant_id, api_based_extension_id) ) @console_ns.doc("update_api_based_extension") @@ -179,7 +185,9 @@ class APIBasedExtensionDetailAPI(Resource): def post(self, current_tenant_id: str, id: UUID): api_based_extension_id = str(id) - extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) + extension_data_from_db = APIBasedExtensionService.get_with_tenant_id( + db.session(), current_tenant_id, api_based_extension_id + ) payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {}) api_key_for_response = extension_data_from_db.api_key @@ -192,7 +200,7 @@ class APIBasedExtensionDetailAPI(Resource): api_key_for_response = payload.api_key return _serialize_saved_api_based_extension( - APIBasedExtensionService.save(extension_data_from_db), + APIBasedExtensionService.save(db.session(), extension_data_from_db), api_key_for_response, ) @@ -207,8 +215,10 @@ class APIBasedExtensionDetailAPI(Resource): def delete(self, current_tenant_id: str, id: UUID): api_based_extension_id = str(id) - extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) + extension_data_from_db = APIBasedExtensionService.get_with_tenant_id( + db.session(), current_tenant_id, api_based_extension_id + ) - APIBasedExtensionService.delete(extension_data_from_db) + APIBasedExtensionService.delete(db.session(), extension_data_from_db) return "", 204 diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index fdb377694bb..25f554b6bdc 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -1,16 +1,16 @@ from sqlalchemy import select +from sqlalchemy.orm import Session from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor from core.helper.encrypter import decrypt_token, encrypt_token -from extensions.ext_database import db from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint class APIBasedExtensionService: @staticmethod - def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: + def get_all_by_tenant_id(session: Session, tenant_id: str) -> list[APIBasedExtension]: extension_list = list( - db.session.scalars( + session.scalars( select(APIBasedExtension) .where(APIBasedExtension.tenant_id == tenant_id) .order_by(APIBasedExtension.created_at.desc()) @@ -23,23 +23,23 @@ class APIBasedExtensionService: return extension_list @classmethod - def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension: - cls._validation(extension_data) + def save(cls, session: Session, extension_data: APIBasedExtension) -> APIBasedExtension: + cls._validation(session, extension_data) extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key) - db.session.add(extension_data) - db.session.commit() + session.add(extension_data) + session.commit() return extension_data @staticmethod - def delete(extension_data: APIBasedExtension): - db.session.delete(extension_data) - db.session.commit() + def delete(session: Session, extension_data: APIBasedExtension): + session.delete(extension_data) + session.commit() @staticmethod - def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.scalar( + def get_with_tenant_id(session: Session, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + extension = session.scalar( select(APIBasedExtension) .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .limit(1) @@ -53,14 +53,14 @@ class APIBasedExtensionService: return extension @classmethod - def _validation(cls, extension_data: APIBasedExtension): + def _validation(cls, session: Session, extension_data: APIBasedExtension): # name if not extension_data.name: raise ValueError("name must not be empty") if not extension_data.id: # case one: check new data, name must be unique - is_name_existed = db.session.scalar( + is_name_existed = session.scalar( select(APIBasedExtension) .where( APIBasedExtension.tenant_id == extension_data.tenant_id, @@ -73,7 +73,7 @@ class APIBasedExtensionService: raise ValueError("name must be unique, it is already existed") else: # case two: check existing data, name must be unique - is_name_existed = db.session.scalar( + is_name_existed = session.scalar( select(APIBasedExtension) .where( APIBasedExtension.tenant_id == extension_data.tenant_id, diff --git a/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py b/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py index 058f4e5fa34..e60558040a5 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py +++ b/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py @@ -97,12 +97,13 @@ def test_list_scopes_api_based_extensions_to_authenticated_tenant( assert account_create_response.status_code == 201 APIBasedExtensionService.save( + db_session_with_containers, APIBasedExtension( tenant_id=foreign_tenant_id, name="Foreign API", api_endpoint="https://foreign.example.com/hook", api_key="foreign-secret-12345", - ) + ), ) response = test_client_with_containers.get( diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index b8e022503fd..8bd4069639f 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -81,7 +81,7 @@ class TestAPIBasedExtensionService: ) # Save extension - saved_extension = APIBasedExtensionService.save(extension_data) + saved_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data) # Verify extension was saved correctly assert saved_extension.id is not None @@ -119,21 +119,21 @@ class TestAPIBasedExtensionService: ) with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) # Test empty api_endpoint extension_data.name = fake.company() extension_data.api_endpoint = "" with pytest.raises(ValueError, match="api_endpoint must not be empty"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) # Test empty api_key extension_data.api_endpoint = f"https://{fake.domain_name()}/api" extension_data.api_key = "" with pytest.raises(ValueError, match="api_key must not be empty"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_get_all_by_tenant_id_success( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -157,11 +157,11 @@ class TestAPIBasedExtensionService: api_key=fake.password(length=20), ) - saved_extension = APIBasedExtensionService.save(extension_data) + saved_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data) extensions.append(saved_extension) # Get all extensions for tenant - extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id) + extension_list = APIBasedExtensionService.get_all_by_tenant_id(db_session_with_containers, tenant.id) # Verify results assert len(extension_list) == 3 @@ -191,10 +191,12 @@ class TestAPIBasedExtensionService: api_key=fake.password(length=20), ) - created_extension = APIBasedExtensionService.save(extension_data) + created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data) # Get extension by ID - retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id) + retrieved_extension = APIBasedExtensionService.get_with_tenant_id( + db_session_with_containers, tenant.id, created_extension.id + ) # Verify extension was retrieved correctly assert retrieved_extension is not None @@ -219,7 +221,9 @@ class TestAPIBasedExtensionService: # Try to get non-existent extension with pytest.raises(ValueError, match="API based extension is not found"): - APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id) + APIBasedExtensionService.get_with_tenant_id( + db_session_with_containers, tenant.id, non_existent_extension_id + ) def test_delete_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -238,11 +242,11 @@ class TestAPIBasedExtensionService: api_key=fake.password(length=20), ) - created_extension = APIBasedExtensionService.save(extension_data) + created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data) extension_id = created_extension.id # Delete the extension - APIBasedExtensionService.delete(created_extension) + APIBasedExtensionService.delete(db_session_with_containers, created_extension) # Verify extension was deleted @@ -270,7 +274,7 @@ class TestAPIBasedExtensionService: api_key=fake.password(length=20), ) - APIBasedExtensionService.save(extension_data1) + APIBasedExtensionService.save(db_session_with_containers, extension_data1) # Try to create second extension with same name extension_data2 = APIBasedExtension( tenant_id=tenant.id, @@ -280,7 +284,7 @@ class TestAPIBasedExtensionService: ) with pytest.raises(ValueError, match="name must be unique, it is already existed"): - APIBasedExtensionService.save(extension_data2) + APIBasedExtensionService.save(db_session_with_containers, extension_data2) def test_save_extension_update_existing( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -301,7 +305,7 @@ class TestAPIBasedExtensionService: api_key=fake.password(length=20), ) - created_extension = APIBasedExtensionService.save(extension_data) + created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data) # Save original values for later comparison original_name = created_extension.name @@ -320,7 +324,7 @@ class TestAPIBasedExtensionService: created_extension.api_endpoint = new_endpoint created_extension.api_key = new_api_key - updated_extension = APIBasedExtensionService.save(created_extension) + updated_extension = APIBasedExtensionService.save(db_session_with_containers, created_extension) # Verify extension was updated correctly assert updated_extension.id == created_extension.id @@ -336,7 +340,9 @@ class TestAPIBasedExtensionService: assert mock_external_service_dependencies["requestor_instance"].request.call_count == 2 # Verify the update by retrieving the extension again - retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id) + retrieved_extension = APIBasedExtensionService.get_with_tenant_id( + db_session_with_containers, tenant.id, created_extension.id + ) assert retrieved_extension.name == new_name assert retrieved_extension.api_endpoint == new_endpoint assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved @@ -367,7 +373,7 @@ class TestAPIBasedExtensionService: # Try to save extension with connection error with pytest.raises(ValueError, match="connection error: request timeout"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_save_extension_invalid_api_key_length( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -390,7 +396,7 @@ class TestAPIBasedExtensionService: # Try to save extension with short API key with pytest.raises(ValueError, match="api_key must be at least 5 characters"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_save_extension_empty_fields(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -410,21 +416,21 @@ class TestAPIBasedExtensionService: ) with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) # Test with None api_endpoint extension_data.name = fake.company() extension_data.api_endpoint = None with pytest.raises(ValueError, match="api_endpoint must not be empty"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) # Test with None api_key extension_data.api_endpoint = f"https://{fake.domain_name()}/api" extension_data.api_key = None with pytest.raises(ValueError, match="api_key must not be empty"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_get_all_by_tenant_id_empty_list( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -438,7 +444,7 @@ class TestAPIBasedExtensionService: ) # Get all extensions for tenant (none exist) - extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id) + extension_list = APIBasedExtensionService.get_all_by_tenant_id(db_session_with_containers, tenant.id) # Verify empty list is returned assert len(extension_list) == 0 @@ -468,7 +474,7 @@ class TestAPIBasedExtensionService: # Try to save extension with invalid ping response with pytest.raises(ValueError, match="{'result': 'invalid'}"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_save_extension_missing_ping_result( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -494,7 +500,7 @@ class TestAPIBasedExtensionService: # Try to save extension with missing ping result with pytest.raises(ValueError, match="{'status': 'ok'}"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_get_with_tenant_id_wrong_tenant( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -520,11 +526,11 @@ class TestAPIBasedExtensionService: api_key=fake.password(length=20), ) - created_extension = APIBasedExtensionService.save(extension_data) + created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data) # Try to get extension with wrong tenant ID with pytest.raises(ValueError, match="API based extension is not found"): - APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id) + APIBasedExtensionService.get_with_tenant_id(db_session_with_containers, tenant2.id, created_extension.id) def test_save_extension_api_key_exactly_four_chars_rejected( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -544,7 +550,7 @@ class TestAPIBasedExtensionService: ) with pytest.raises(ValueError, match="api_key must be at least 5 characters"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_save_extension_api_key_exactly_five_chars_accepted( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -563,7 +569,7 @@ class TestAPIBasedExtensionService: api_key="12345", ) - saved = APIBasedExtensionService.save(extension_data) + saved = APIBasedExtensionService.save(db_session_with_containers, extension_data) assert saved.id is not None def test_save_extension_requestor_constructor_error( @@ -586,7 +592,7 @@ class TestAPIBasedExtensionService: ) with pytest.raises(ValueError, match="connection error: bad config"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_save_extension_network_exception( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -610,7 +616,7 @@ class TestAPIBasedExtensionService: ) with pytest.raises(ValueError, match="connection error: network failure"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_save_extension_update_duplicate_name_rejected( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -623,26 +629,28 @@ class TestAPIBasedExtensionService: assert tenant is not None ext1 = APIBasedExtensionService.save( + db_session_with_containers, APIBasedExtension( tenant_id=tenant.id, name="Extension Alpha", api_endpoint=f"https://{fake.domain_name()}/api", api_key=fake.password(length=20), - ) + ), ) ext2 = APIBasedExtensionService.save( + db_session_with_containers, APIBasedExtension( tenant_id=tenant.id, name="Extension Beta", api_endpoint=f"https://{fake.domain_name()}/api", api_key=fake.password(length=20), - ) + ), ) # Try to rename ext2 to ext1's name ext2.name = "Extension Alpha" with pytest.raises(ValueError, match="name must be unique, it is already existed"): - APIBasedExtensionService.save(ext2) + APIBasedExtensionService.save(db_session_with_containers, ext2) def test_get_all_returns_empty_for_different_tenant( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -658,14 +666,15 @@ class TestAPIBasedExtensionService: assert tenant1 is not None APIBasedExtensionService.save( + db_session_with_containers, APIBasedExtension( tenant_id=tenant1.id, name=fake.company(), api_endpoint=f"https://{fake.domain_name()}/api", api_key=fake.password(length=20), - ) + ), ) assert tenant2 is not None - result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id) + result = APIBasedExtensionService.get_all_by_tenant_id(db_session_with_containers, tenant2.id) assert result == [] diff --git a/api/tests/unit_tests/controllers/console/test_extension.py b/api/tests/unit_tests/controllers/console/test_extension.py index 487cf8f54fd..bab825ca6f0 100644 --- a/api/tests/unit_tests/controllers/console/test_extension.py +++ b/api/tests/unit_tests/controllers/console/test_extension.py @@ -3,7 +3,7 @@ from __future__ import annotations import builtins import uuid from datetime import UTC, datetime -from unittest.mock import MagicMock +from unittest.mock import ANY, MagicMock import pytest from flask import Flask @@ -114,7 +114,7 @@ def test_api_based_extension_get_returns_tenant_extensions(app: Flask, monkeypat assert response[0]["name"] == "Weather API" assert response[0]["api_endpoint"] == extension.api_endpoint assert response[0]["api_key"].startswith(extension.api_key[:3]) - service_mock.assert_called_once_with("tenant-123") + service_mock.assert_called_once_with(ANY, "tenant-123") def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pytest.MonkeyPatch): @@ -132,7 +132,7 @@ def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pyt response, status = APIBasedExtensionAPI().post() args, _ = save_mock.call_args - created_extension: APIBasedExtension = args[0] + created_extension: APIBasedExtension = args[1] assert created_extension.tenant_id == "tenant-123" assert created_extension.name == payload["name"] assert created_extension.api_endpoint == payload["api_endpoint"] @@ -157,7 +157,7 @@ def test_api_based_extension_detail_get_fetches_extension(app: Flask, monkeypatc assert response["id"] == extension.id assert response["name"] == extension.name - service_mock.assert_called_once_with("tenant-123", str(extension_id)) + service_mock.assert_called_once_with(ANY, "tenant-123", str(extension_id)) def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkeypatch: pytest.MonkeyPatch): @@ -187,7 +187,7 @@ def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkey assert existing_extension.name == payload["name"] assert existing_extension.api_endpoint == payload["api_endpoint"] assert existing_extension.api_key == "keep-me" - save_mock.assert_called_once_with(existing_extension) + save_mock.assert_called_once_with(ANY, existing_extension) assert response["name"] == payload["name"] assert response["api_key"] == _masked_api_key("keep-me") @@ -217,7 +217,7 @@ def test_api_based_extension_detail_post_updates_api_key_when_provided(app: Flas response = APIBasedExtensionDetailAPI().post(extension_id) assert existing_extension.api_key == "new-secret" - save_mock.assert_called_once_with(existing_extension) + save_mock.assert_called_once_with(ANY, existing_extension) assert response["name"] == payload["name"] assert response["api_key"] == _masked_api_key(payload["api_key"]) @@ -239,6 +239,6 @@ def test_api_based_extension_detail_delete_removes_extension(app: Flask, monkeyp ): response, status = APIBasedExtensionDetailAPI().delete(extension_id) - delete_mock.assert_called_once_with(existing_extension) + delete_mock.assert_called_once_with(ANY, existing_extension) assert status == 204 assert response == "" From a8e3257f43039ef212c50e7569f0e8b8a86ca5a7 Mon Sep 17 00:00:00 2001 From: Rohit Gahlawat <283466839+Rohit-Gahlawat@users.noreply.github.com> Date: Sun, 21 Jun 2026 10:48:28 +0530 Subject: [PATCH 08/70] refactor: accept db.session explicitly in FileService.get_upload_files_by_ids (#37695) --- api/services/dataset_service.py | 4 ++-- api/services/file_service.py | 7 ++++--- .../services/test_file_service_zip_and_lookup.py | 6 +++--- api/tests/unit_tests/services/test_file_service.py | 14 +++++++------- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 125f3a8e6b8..a8f341fdd04 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1778,7 +1778,7 @@ class DocumentService: invalid_source_message="Document does not have an uploaded file to download.", missing_file_message="Uploaded file not found.", ) - upload_files_by_id = FileService.get_upload_files_by_ids(document.tenant_id, [upload_file_id]) + upload_files_by_id = FileService.get_upload_files_by_ids(db.session(), document.tenant_id, [upload_file_id]) upload_file = upload_files_by_id.get(upload_file_id) if not upload_file: raise NotFound("Uploaded file not found.") @@ -1817,7 +1817,7 @@ class DocumentService: upload_file_ids.append(upload_file_id) upload_file_ids_by_document_id[document_id] = upload_file_id - upload_files_by_id = FileService.get_upload_files_by_ids(tenant_id, upload_file_ids) + upload_files_by_id = FileService.get_upload_files_by_ids(db.session(), tenant_id, upload_file_ids) missing_upload_file_ids: set[str] = set(upload_file_ids) - set(upload_files_by_id.keys()) if missing_upload_file_ids: raise NotFound("Only uploaded-file documents can be downloaded as ZIP.") diff --git a/api/services/file_service.py b/api/services/file_service.py index 1781f0c9727..e41d74ad3eb 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -20,7 +20,6 @@ from constants import ( VIDEO_EXTENSIONS, ) from core.rag.extractor.extract_processor import ExtractProcessor -from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType from graphon.file import helpers as file_helpers @@ -268,7 +267,9 @@ class FileService: session.delete(upload_file) @staticmethod - def get_upload_files_by_ids(tenant_id: str, upload_file_ids: Sequence[str]) -> dict[str, UploadFile]: + def get_upload_files_by_ids( + session: Session, tenant_id: str, upload_file_ids: Sequence[str] + ) -> dict[str, UploadFile]: """ Fetch `UploadFile` rows for a tenant in a single batch query. @@ -282,7 +283,7 @@ class FileService: unique_upload_file_ids: list[str] = list(set(upload_file_id_list)) # Fetch upload files in one query for efficient batch access. - upload_files: Sequence[UploadFile] = db.session.scalars( + upload_files: Sequence[UploadFile] = session.scalars( select(UploadFile).where( UploadFile.tenant_id == tenant_id, UploadFile.id.in_(unique_upload_file_ids), diff --git a/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py index 1101d834a0d..5eb84f805aa 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py @@ -69,7 +69,7 @@ def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch def test_get_upload_files_by_ids_returns_empty_when_no_ids(db_session_with_containers: Session) -> None: """Ensure empty input returns an empty mapping without hitting the database.""" - assert FileService.get_upload_files_by_ids(str(uuid4()), []) == {} + assert FileService.get_upload_files_by_ids(db_session_with_containers, str(uuid4()), []) == {} def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_containers: Session) -> None: @@ -78,7 +78,7 @@ def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_contai file1 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k1", name="file1.txt") file2 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k2", name="file2.txt") - result = FileService.get_upload_files_by_ids(tenant_id, [file1.id, file1.id, file2.id]) + result = FileService.get_upload_files_by_ids(db_session_with_containers, tenant_id, [file1.id, file1.id, file2.id]) assert set(result.keys()) == {file1.id, file2.id} assert result[file1.id].id == file1.id @@ -92,6 +92,6 @@ def test_get_upload_files_by_ids_filters_by_tenant(db_session_with_containers: S file_a = _create_upload_file(db_session_with_containers, tenant_id=tenant_a, key="ka", name="a.txt") _create_upload_file(db_session_with_containers, tenant_id=tenant_b, key="kb", name="b.txt") - result = FileService.get_upload_files_by_ids(tenant_a, [file_a.id]) + result = FileService.get_upload_files_by_ids(db_session_with_containers, tenant_a, [file_a.id]) assert set(result.keys()) == {file_a.id} diff --git a/api/tests/unit_tests/services/test_file_service.py b/api/tests/unit_tests/services/test_file_service.py index 2e6ca7dbb9c..b81fb823949 100644 --- a/api/tests/unit_tests/services/test_file_service.py +++ b/api/tests/unit_tests/services/test_file_service.py @@ -375,19 +375,19 @@ class TestFileService: file_service.delete_file("file_id") # Should return without doing anything - @patch("services.file_service.db") - def test_get_upload_files_by_ids_empty(self, mock_db): - result = FileService.get_upload_files_by_ids("tenant_id", []) + def test_get_upload_files_by_ids_empty(self): + session = MagicMock() + result = FileService.get_upload_files_by_ids(session, "tenant_id", []) assert result == {} - @patch("services.file_service.db") - def test_get_upload_files_by_ids(self, mock_db): + def test_get_upload_files_by_ids(self): upload_file = MagicMock(spec=UploadFile) upload_file.id = "550e8400-e29b-41d4-a716-446655440000" upload_file.tenant_id = "tenant_id" - mock_db.session.scalars().all.return_value = [upload_file] + session = MagicMock() + session.scalars().all.return_value = [upload_file] - result = FileService.get_upload_files_by_ids("tenant_id", ["550e8400-e29b-41d4-a716-446655440000"]) + result = FileService.get_upload_files_by_ids(session, "tenant_id", ["550e8400-e29b-41d4-a716-446655440000"]) assert result["550e8400-e29b-41d4-a716-446655440000"] == upload_file def test_sanitize_zip_entry_name(self): From b60f83e308e94449fff95b522c9da3d81f2d302d Mon Sep 17 00:00:00 2001 From: MeloMei Date: Sun, 21 Jun 2026 13:31:52 +0800 Subject: [PATCH 09/70] refactor(test): replace logger mock with caplog in billing and vector service tests (#37697) Signed-off-by: MeloMei --- .../services/test_billing_service.py | 34 +++++++-------- .../services/test_vector_service.py | 41 +++++++++---------- 2 files changed, 38 insertions(+), 37 deletions(-) diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 67d1cc02913..f244b69407f 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -14,6 +14,7 @@ Tests follow the Arrange-Act-Assert pattern for clarity. """ import json +import logging from unittest.mock import MagicMock, patch import httpx @@ -170,7 +171,9 @@ class TestBillingServiceSendRequest: @pytest.mark.parametrize( "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] ) - def test_delete_request_non_200_with_valid_json(self, mock_httpx_request, mock_billing_config, status_code): + def test_delete_request_non_200_with_valid_json( + self, mock_httpx_request, mock_billing_config, status_code, caplog: pytest.LogCaptureFixture + ): """Test DELETE request with non-200 status code raises ValueError. DELETE now checks status code and raises ValueError for non-200 responses. @@ -184,13 +187,11 @@ class TestBillingServiceSendRequest: mock_httpx_request.return_value = mock_response # Act & Assert - with patch("services.billing_service.logger") as mock_logger: + with caplog.at_level(logging.ERROR, logger="services.billing_service"): with pytest.raises(ValueError) as exc_info: BillingService._send_request("DELETE", "/test", json={"key": "value"}) assert "Unable to process delete request" in str(exc_info.value) - # Verify error logging - mock_logger.error.assert_called_once() - assert "DELETE response" in str(mock_logger.error.call_args) + assert "DELETE response" in caplog.text @pytest.mark.parametrize( "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] @@ -213,7 +214,9 @@ class TestBillingServiceSendRequest: @pytest.mark.parametrize( "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] ) - def test_delete_request_non_200_with_invalid_json(self, mock_httpx_request, mock_billing_config, status_code): + def test_delete_request_non_200_with_invalid_json( + self, mock_httpx_request, mock_billing_config, status_code, caplog: pytest.LogCaptureFixture + ): """Test DELETE request with non-200 status code raises ValueError before JSON parsing. DELETE now checks status code before calling response.json(), so ValueError is raised @@ -227,13 +230,11 @@ class TestBillingServiceSendRequest: mock_httpx_request.return_value = mock_response # Act & Assert - with patch("services.billing_service.logger") as mock_logger: + with caplog.at_level(logging.ERROR, logger="services.billing_service"): with pytest.raises(ValueError) as exc_info: BillingService._send_request("DELETE", "/test", json={"key": "value"}) assert "Unable to process delete request" in str(exc_info.value) - # Verify error logging - mock_logger.error.assert_called_once() - assert "DELETE response" in str(mock_logger.error.call_args) + assert "DELETE response" in caplog.text def test_retry_on_request_error(self, mock_httpx_request, mock_billing_config): """Test that _send_request retries on httpx.RequestError.""" @@ -1511,7 +1512,7 @@ class TestBillingServiceSubscriptionOperations: assert isinstance(result["tenant-1"]["expiration_date"], int) assert result["tenant-1"]["expiration_date"] == 1735689600 - def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request): + def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request, caplog: pytest.LogCaptureFixture): """Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant).""" # Arrange tenant_ids = ["tenant-valid-1", "tenant-invalid", "tenant-valid-2"] @@ -1526,7 +1527,7 @@ class TestBillingServiceSubscriptionOperations: } # Act - with patch("services.billing_service.logger") as mock_logger: + with caplog.at_level(logging.ERROR, logger="services.billing_service"): result = BillingService.get_plan_bulk(tenant_ids) # Assert - should only contain valid tenants @@ -1542,10 +1543,11 @@ class TestBillingServiceSubscriptionOperations: assert result["tenant-valid-2"]["expiration_date"] == 1767225600 # Verify exception was logged for the invalid tenant - mock_logger.exception.assert_called_once() - log_call_args = mock_logger.exception.call_args[0] - assert "get_plan_bulk: failed to validate subscription plan for tenant" in log_call_args[0] - assert "tenant-invalid" in log_call_args[1] + exception_records = [r for r in caplog.records if r.levelname == "ERROR"] + assert len(exception_records) == 1 + formatted = exception_records[0].getMessage() + assert "get_plan_bulk: failed to validate subscription plan for tenant" in formatted + assert "tenant-invalid" in formatted def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request): """Test successful retrieval of expired subscription cleanup whitelist.""" diff --git a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py index a78a033f4d3..e6cc59144b3 100644 --- a/api/tests/unit_tests/services/test_vector_service.py +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging from dataclasses import dataclass from typing import Any from unittest.mock import MagicMock @@ -268,7 +269,7 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p def test_create_segments_vector_parent_child_missing_document_logs_warning_and_continues( - monkeypatch: pytest.MonkeyPatch, + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture ) -> None: dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX) segment = _make_segment() @@ -280,18 +281,16 @@ def test_create_segments_vector_parent_child_missing_document_logs_warning_and_c _mock_parent_child_queries(dataset_document=None, processing_rule=processing_rule), ) - logger_mock = MagicMock() - monkeypatch.setattr(vector_service_module, "logger", logger_mock) - index_processor = MagicMock() factory_instance = MagicMock() factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector( - None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX - ) - logger_mock.warning.assert_called_once() + with caplog.at_level(logging.WARNING, logger="services.vector_service"): + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + assert "Expected DatasetDocument record to exist, but none was found" in caplog.text index_processor.load.assert_not_called() @@ -615,7 +614,7 @@ def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files( - monkeypatch: pytest.MonkeyPatch, + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture ) -> None: dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) @@ -630,12 +629,10 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up monkeypatch.setattr(vector_service_module, "delete", MagicMock()) monkeypatch.setattr(vector_service_module, "select", MagicMock()) - logger_mock = MagicMock() - monkeypatch.setattr(vector_service_module, "logger", logger_mock) + with caplog.at_level(logging.WARNING, logger="services.vector_service"): + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1", "missing"], dataset=dataset) - VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1", "missing"], dataset=dataset) - - logger_mock.warning.assert_called_once() + assert "Upload file not found for attachment_id" in caplog.text db_mock.session.add_all.assert_called_once() bindings = db_mock.session.add_all.call_args.args[0] assert len(bindings) == 1 @@ -673,7 +670,9 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops db_mock.session.commit.assert_called_once() -def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None: +def test_update_multimodel_vector_rolls_back_and_reraises_on_error( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) @@ -688,11 +687,11 @@ def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: monkeypatch.setattr(vector_service_module, "delete", MagicMock()) monkeypatch.setattr(vector_service_module, "select", MagicMock()) - logger_mock = MagicMock() - monkeypatch.setattr(vector_service_module, "logger", logger_mock) + with caplog.at_level(logging.ERROR, logger="services.vector_service"): + with pytest.raises(RuntimeError, match="boom"): + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset) - with pytest.raises(RuntimeError, match="boom"): - VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset) - - logger_mock.exception.assert_called_once() + exception_records = [r for r in caplog.records if r.levelname == "ERROR"] + assert len(exception_records) == 1 + assert "Failed to update multimodal vector for segment" in exception_records[0].getMessage() db_mock.session.rollback.assert_called_once() From 44464c8c6395a04a17f645804d3c779d8df2b7e7 Mon Sep 17 00:00:00 2001 From: frank Date: Sun, 21 Jun 2026 15:30:40 +0800 Subject: [PATCH 10/70] test: replace patch logger with caplog in core/rag tests (#37468) (#37621) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../rag/embedding/test_embedding_service.py | 9 +-- .../test_paragraph_index_processor.py | 63 ++++++++++--------- .../processor/test_qa_index_processor.py | 9 ++- 3 files changed, 45 insertions(+), 36 deletions(-) diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 4b8175b0b42..42d5ea4a393 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -44,6 +44,7 @@ Tests follow the Arrange-Act-Assert pattern for clarity. """ import base64 +import logging from decimal import Decimal from unittest.mock import Mock, patch @@ -406,7 +407,7 @@ class TestCacheEmbeddingDocuments: assert len(calls[1].kwargs["texts"]) == 10 assert len(calls[2].kwargs["texts"]) == 5 - def test_embed_documents_nan_handling(self, mock_model_instance): + def test_embed_documents_nan_handling(self, mock_model_instance, caplog): """Test handling of NaN values in embeddings. Verifies: @@ -446,7 +447,7 @@ class TestCacheEmbeddingDocuments: mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result - with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: + with caplog.at_level(logging.WARNING, logger="core.rag.embedding.cached_embedding"): # Act result = cache_embedding.embed_documents(texts) @@ -461,8 +462,8 @@ class TestCacheEmbeddingDocuments: assert result[1] is None # Verify warning was logged - mock_logger.warning.assert_called_once() - assert "Normalized embedding is nan" in str(mock_logger.warning.call_args) + assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) >= 1 + assert any("Normalized embedding is nan" in record.message for record in caplog.records) def test_embed_documents_api_connection_error(self, mock_model_instance): """Test handling of API connection errors during embedding. diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index 182930b19d1..d2154f138a7 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -1,3 +1,4 @@ +import logging from types import SimpleNamespace from typing import Any from unittest.mock import Mock, patch @@ -384,7 +385,7 @@ class TestParagraphIndexProcessor: with pytest.raises(ValueError, match="model_name and model_provider_name"): ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": True}) - def test_generate_summary_text_only_flow(self) -> None: + def test_generate_summary_text_only_flow(self, caplog) -> None: model_instance = Mock() model_instance.credentials = {"k": "v"} model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[]) @@ -402,19 +403,22 @@ class TestParagraphIndexProcessor: "core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota", side_effect=RuntimeError("quota"), ), - patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() - summary, usage = ParagraphIndexProcessor.generate_summary( - "tenant-1", - "text content", - {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, - document_language="English", - ) + with caplog.at_level( + logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor" + ): + summary, usage = ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + document_language="English", + ) assert summary == "text summary" assert isinstance(usage, LLMUsage) - mock_logger.warning.assert_called_with("Failed to deduct quota for summary generation: %s", "quota") + assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1 + assert any("Failed to deduct quota for summary generation" in record.message for record in caplog.records) def test_generate_summary_handles_vision_and_image_conversion(self) -> None: model_instance = Mock() @@ -455,7 +459,7 @@ class TestParagraphIndexProcessor: assert summary == "vision summary" mock_extract_text.assert_not_called() - def test_generate_summary_fallbacks_for_prompt_and_result_types(self) -> None: + def test_generate_summary_fallbacks_for_prompt_and_result_types(self, caplog) -> None: model_instance = Mock() model_instance.credentials = {"k": "v"} model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( @@ -482,21 +486,24 @@ class TestParagraphIndexProcessor: "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content", side_effect=RuntimeError("bad image"), ), - patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() with pytest.raises(ValueError, match="Expected LLMResult"): - ParagraphIndexProcessor.generate_summary( - "tenant-1", - "text content", - {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, - ) + with caplog.at_level( + logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor" + ): + ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + ) - mock_logger.warning.assert_called_with( - "Failed to convert image file to prompt message content: %s", "bad image" + assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1 + assert any( + "Failed to convert image file to prompt message content" in record.message for record in caplog.records ) - def test_extract_images_from_text_handles_patterns_and_build_errors(self) -> None: + def test_extract_images_from_text_handles_patterns_and_build_errors(self, caplog) -> None: text = ( "![img](/files/11111111-1111-1111-1111-111111111111/image-preview) " "![img2](/files/22222222-2222-2222-2222-222222222222/file-preview) " @@ -532,13 +539,13 @@ class TestParagraphIndexProcessor: "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping", return_value=SimpleNamespace(id="file-1"), ) as mock_builder, - patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + caplog.at_level(logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"), ): files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text, session) assert len(files) == 1 assert mock_builder.call_count == 1 - mock_logger.warning.assert_not_called() + assert not any(record.levelno == logging.WARNING for record in caplog.records) def test_extract_images_from_text_returns_empty_when_no_matches(self) -> None: scalars_result = Mock() @@ -547,7 +554,7 @@ class TestParagraphIndexProcessor: session.scalars.return_value = scalars_result assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here", session) == [] - def test_extract_images_from_text_logs_when_build_fails(self) -> None: + def test_extract_images_from_text_logs_when_build_fails(self, caplog) -> None: text = "![img](/files/11111111-1111-1111-1111-111111111111/image-preview)" image_upload = SimpleNamespace( id="11111111-1111-1111-1111-111111111111", @@ -569,14 +576,14 @@ class TestParagraphIndexProcessor: "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping", side_effect=RuntimeError("build failed"), ), - patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + caplog.at_level(logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"), ): files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text, session) assert files == [] - mock_logger.warning.assert_called_once() + assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1 - def test_extract_images_from_segment_attachments(self) -> None: + def test_extract_images_from_segment_attachments(self, caplog) -> None: image_upload = SimpleNamespace( id="file-1", name="image", @@ -609,13 +616,11 @@ class TestParagraphIndexProcessor: session = Mock() session.execute.return_value = execute_result - with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, - ): + with caplog.at_level(logging.WARNING, logger="core.rag.index_processor.processor.paragraph_index_processor"): files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1", session) assert len(files) == 1 - mock_logger.warning.assert_called_once() + assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1 def test_extract_images_from_segment_attachments_empty(self) -> None: execute_result = Mock() diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 30600e64651..4ffd0a76433 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -1,3 +1,4 @@ +import logging from types import SimpleNamespace from typing import Any from unittest.mock import MagicMock, Mock, patch @@ -350,7 +351,7 @@ class TestQAIndexProcessor: assert all_qa_documents[0].metadata["answer"] == "A test." assert all_qa_documents[1].metadata["answer"] == "Coverage." - def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app) -> None: + def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app, caplog) -> None: all_qa_documents: list[Document] = [] source_document = Document(page_content="source text", metadata={"origin": "doc-1"}) @@ -359,12 +360,14 @@ class TestQAIndexProcessor: "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document", side_effect=RuntimeError("llm failure"), ), - patch("core.rag.index_processor.processor.qa_index_processor.logger") as mock_logger, + caplog.at_level(logging.ERROR, logger="core.rag.index_processor.processor.qa_index_processor"), ): processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English") assert all_qa_documents == [] - mock_logger.exception.assert_called_once_with("Failed to format qa document") + assert len(caplog.records) == 1 + assert caplog.records[0].levelname == "ERROR" + assert "Failed to format qa document" in caplog.records[0].message def test_format_split_text_extracts_question_answer_pairs(self, processor: QAIndexProcessor) -> None: parsed = processor._format_split_text("Q1: First?\nA1: One.\nQ2: Second?\nA2: Two.\n") From 3a3ad6ad7ce779dd37cfe846e59f1890e050888a Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Sun, 21 Jun 2026 15:32:15 +0800 Subject: [PATCH 11/70] fix: skip empty tool entries in legacy dataset config extraction (#37669) Signed-off-by: Yufeng He <40085740+he-yufeng@users.noreply.github.com> --- .../easy_ui_based_app/dataset/manager.py | 5 ++++ .../easy_ui_based_app/test_dataset_manager.py | 23 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 3d857a4e9c0..be538455afb 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -213,6 +213,11 @@ class DatasetConfigManager: PlanningStrategy.REACT_ROUTER, }: for tool in config.get("agent_mode", {}).get("tools", []): + if not tool: + # Skip malformed empty tool entries; list(tool.keys())[0] + # would otherwise raise IndexError. The sibling convert() + # already guards this with `if len(tool) == 1`. + continue key = list(tool.keys())[0] if key == "dataset": # old style, use tool name as key diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py index d5305d2fc0b..e4e4f99c6d4 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py @@ -318,3 +318,26 @@ class TestIsDatasetExists: return_value=mock_dataset, ) assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) + + +# ============================== +# extract_dataset_config_for_legacy_compatibility tests +# ============================== + + +class TestExtractDatasetConfigForLegacyCompatibility: + def test_skips_empty_tool_entry(self): + # A malformed empty tool dict in agent_mode.tools must be skipped, not + # crash with `IndexError` on `list(tool.keys())[0]`. The sibling + # convert() already guards this with `if len(tool) == 1`. + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER, + "tools": [{}], + } + } + + result = DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + assert result["agent_mode"]["tools"] == [{}] From 24080010c9a1dcd877492b28994140bd66ddc69c Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Mon, 22 Jun 2026 10:17:07 +0800 Subject: [PATCH 12/70] chore(deps): bump base-ui to v1.6.0 (#37663) --- packages/dify-ui/src/autocomplete/index.tsx | 6 ++--- packages/dify-ui/src/button/index.tsx | 2 +- packages/dify-ui/src/collapsible/index.tsx | 2 +- packages/dify-ui/src/combobox/index.tsx | 8 +++---- packages/dify-ui/src/context-menu/index.tsx | 2 +- packages/dify-ui/src/dialog/index.tsx | 7 +++--- packages/dify-ui/src/dropdown-menu/index.tsx | 2 +- .../src/file-tree/__tests__/index.spec.tsx | 2 +- packages/dify-ui/src/file-tree/index.tsx | 4 ++-- .../src/number-field/__tests__/index.spec.tsx | 4 ++-- .../src/number-field/index.stories.tsx | 5 ---- packages/dify-ui/src/number-field/index.tsx | 4 ++-- packages/dify-ui/src/overlay-shared.ts | 2 +- .../dify-ui/src/pagination/index.stories.tsx | 8 ------- packages/dify-ui/src/pagination/index.tsx | 2 +- .../src/scroll-area/__tests__/index.spec.tsx | 8 +++++++ .../dify-ui/src/scroll-area/index.stories.tsx | 2 +- packages/dify-ui/src/scroll-area/index.tsx | 2 +- .../src/segmented-control/index.stories.tsx | 6 ----- packages/dify-ui/src/select/index.tsx | 6 ++--- packages/dify-ui/src/tooltip/index.tsx | 2 +- pnpm-lock.yaml | 24 +++++++++---------- pnpm-workspace.yaml | 2 +- 23 files changed, 50 insertions(+), 62 deletions(-) diff --git a/packages/dify-ui/src/autocomplete/index.tsx b/packages/dify-ui/src/autocomplete/index.tsx index b362a9450fd..4d115024ab1 100644 --- a/packages/dify-ui/src/autocomplete/index.tsx +++ b/packages/dify-ui/src/autocomplete/index.tsx @@ -145,9 +145,9 @@ const autocompleteControlVariants = cva( { variants: { size: { - small: 'mr-1 size-4', - medium: 'mr-1.5 size-5', - large: 'mr-2 size-5', + small: 'me-1 size-4', + medium: 'me-1.5 size-5', + large: 'me-2 size-5', }, }, defaultVariants: { diff --git a/packages/dify-ui/src/button/index.tsx b/packages/dify-ui/src/button/index.tsx index 2181b880a55..0d36d7d1510 100644 --- a/packages/dify-ui/src/button/index.tsx +++ b/packages/dify-ui/src/button/index.tsx @@ -131,7 +131,7 @@ export function Button({ {children} {loading && (