diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 8ebb745a60..39fc7dec6b 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,9 +1,9 @@ -from flask_login import current_user from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required -from libs.login import login_required +from libs.login import current_user, login_required +from models.model import Account from services.billing_service import BillingService @@ -17,9 +17,10 @@ class Subscription(Resource): parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) args = parser.parse_args() + assert isinstance(current_user, Account) BillingService.is_tenant_owner_or_admin(current_user) - + assert current_user.current_tenant_id is not None return BillingService.get_subscription( args["plan"], args["interval"], current_user.email, current_user.current_tenant_id ) @@ -31,7 +32,9 @@ class Invoices(Resource): @account_initialization_required @only_edition_cloud def get(self): + assert isinstance(current_user, Account) BillingService.is_tenant_owner_or_admin(current_user) + assert current_user.current_tenant_id is not None return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 76267a2fe1..8578f38a0d 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -2,7 +2,6 @@ import threading from typing import Any, Optional import pytz -from flask_login import current_user import contexts from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager @@ -10,6 +9,7 @@ from core.plugin.impl.agent import PluginAgentClient from core.plugin.impl.exc import PluginDaemonClientSideError from core.tools.tool_manager import ToolManager from extensions.ext_database import db +from libs.login import current_user from models.account import Account from models.model import App, Conversation, EndUser, Message, MessageAgentThought @@ -61,7 +61,8 @@ class AgentService: executor = executor.name else: executor = "Unknown" - + assert isinstance(current_user, Account) + assert current_user.timezone is not None timezone = pytz.timezone(current_user.timezone) app_model_config = app_model.app_model_config diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 24567cc34c..ba86a31240 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -2,7 +2,6 @@ import uuid from typing import Optional import pandas as pd -from flask_login import current_user from sqlalchemy import or_, select from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -10,6 +9,8 @@ from werkzeug.exceptions import NotFound from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now +from libs.login import current_user +from models.account import Account from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation from services.feature_service import FeatureService from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task @@ -24,6 +25,7 @@ class AppAnnotationService: @classmethod def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info + assert isinstance(current_user, Account) app = ( db.session.query(App) .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") @@ -62,6 +64,7 @@ class AppAnnotationService: db.session.commit() # if annotation reply is enabled , add annotation to index annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + assert current_user.current_tenant_id is not None if annotation_setting: add_annotation_to_index_task.delay( annotation.id, @@ -84,6 +87,8 @@ class AppAnnotationService: enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" # send batch add segments task redis_client.setnx(enable_app_annotation_job_key, "waiting") + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None enable_annotation_reply_task.delay( str(job_id), app_id, @@ -97,6 +102,8 @@ class AppAnnotationService: @classmethod def disable_app_annotation(cls, app_id: str): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" cache_result = redis_client.get(disable_app_annotation_key) if cache_result is not None: @@ -113,6 +120,8 @@ class AppAnnotationService: @classmethod def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): # get app info + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None app = ( db.session.query(App) .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") @@ -145,6 +154,8 @@ class AppAnnotationService: @classmethod def export_annotation_list_by_app_id(cls, app_id: str): # get app info + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None app = ( db.session.query(App) .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") @@ -164,6 +175,8 @@ class AppAnnotationService: @classmethod def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None app = ( db.session.query(App) .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") @@ -193,6 +206,8 @@ class AppAnnotationService: @classmethod def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): # get app info + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None app = ( db.session.query(App) .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") @@ -230,6 +245,8 @@ class AppAnnotationService: @classmethod def delete_app_annotation(cls, app_id: str, annotation_id: str): # get app info + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None app = ( db.session.query(App) .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") @@ -269,6 +286,8 @@ class AppAnnotationService: @classmethod def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]): # get app info + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None app = ( db.session.query(App) .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") @@ -317,6 +336,8 @@ class AppAnnotationService: @classmethod def batch_import_app_annotations(cls, app_id, file: FileStorage): # get app info + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None app = ( db.session.query(App) .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") @@ -355,6 +376,8 @@ class AppAnnotationService: @classmethod def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None # get app info app = ( db.session.query(App) @@ -425,6 +448,8 @@ class AppAnnotationService: @classmethod def get_app_annotation_setting_by_app_id(cls, app_id: str): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None # get app info app = ( db.session.query(App) @@ -451,6 +476,8 @@ class AppAnnotationService: @classmethod def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None # get app info app = ( db.session.query(App) @@ -491,6 +518,8 @@ class AppAnnotationService: @classmethod def clear_all_annotations(cls, app_id: str): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None app = ( db.session.query(App) .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") diff --git a/api/services/app_service.py b/api/services/app_service.py index 09aab5f0c4..9b200a570d 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -2,7 +2,6 @@ import json import logging from typing import Optional, TypedDict, cast -from flask_login import current_user from flask_sqlalchemy.pagination import Pagination from configs import dify_config @@ -17,6 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now +from libs.login import current_user from models.account import Account from models.model import App, AppMode, AppModelConfig, Site from models.tools import ApiToolProvider @@ -168,6 +168,8 @@ class AppService: """ Get App """ + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None # get original app model config if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: model_config = app.app_model_config @@ -242,6 +244,7 @@ class AppService: :param args: request args :return: App instance """ + assert current_user is not None app.name = args["name"] app.description = args["description"] app.icon_type = args["icon_type"] @@ -262,6 +265,7 @@ class AppService: :param name: new name :return: App instance """ + assert current_user is not None app.name = name app.updated_by = current_user.id app.updated_at = naive_utc_now() @@ -277,6 +281,7 @@ class AppService: :param icon_background: new icon_background :return: App instance """ + assert current_user is not None app.icon = icon app.icon_background = icon_background app.updated_by = current_user.id @@ -294,7 +299,7 @@ class AppService: """ if enable_site == app.enable_site: return app - + assert current_user is not None app.enable_site = enable_site app.updated_by = current_user.id app.updated_at = naive_utc_now() @@ -311,6 +316,7 @@ class AppService: """ if enable_api == app.enable_api: return app + assert current_user is not None app.enable_api = enable_api app.updated_by = current_user.id diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 40d45af376..066bed3234 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -70,7 +70,7 @@ class BillingService: return response.json() @staticmethod - def is_tenant_owner_or_admin(current_user): + def is_tenant_owner_or_admin(current_user: Account): tenant_id = current_user.current_tenant_id join: Optional[TenantAccountJoin] = ( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index c0c97fbd77..2b151f9a8e 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -8,7 +8,7 @@ import uuid from collections import Counter from typing import Any, Literal, Optional -from flask_login import current_user +import sqlalchemy as sa from sqlalchemy import exists, func, select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -27,6 +27,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper from libs.datetime_utils import naive_utc_now +from libs.login import current_user from models.account import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -498,8 +499,11 @@ class DatasetService: data: Update data dictionary filtered_data: Filtered update data to modify """ + # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None try: model_manager = ModelManager() + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=data["embedding_model_provider"], @@ -611,8 +615,12 @@ class DatasetService: data: Update data dictionary filtered_data: Filtered update data to modify """ + # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None + model_manager = ModelManager() try: + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=data["embedding_model_provider"], @@ -720,6 +728,8 @@ class DatasetService: @staticmethod def get_dataset_auto_disable_logs(dataset_id: str): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None features = FeatureService.get_features(current_user.current_tenant_id) if not features.billing.enabled or features.billing.subscription.plan == "sandbox": return { @@ -924,6 +934,8 @@ class DocumentService: @staticmethod def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: + assert isinstance(current_user, Account) + documents = ( db.session.query(Document) .where( @@ -983,6 +995,8 @@ class DocumentService: @staticmethod def rename_document(dataset_id: str, document_id: str, name: str) -> Document: + assert isinstance(current_user, Account) + dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise ValueError("Dataset not found.") @@ -1012,6 +1026,7 @@ class DocumentService: if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: raise DocumentIndexingError() # update document to be paused + assert current_user is not None document.is_paused = True document.paused_by = current_user.id document.paused_at = naive_utc_now() @@ -1098,6 +1113,9 @@ class DocumentService: # check doc_form DatasetService.check_doc_form(dataset, knowledge_config.doc_form) # check document limit + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None + features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: @@ -1434,6 +1452,8 @@ class DocumentService: @staticmethod def get_tenant_documents_count(): + assert isinstance(current_user, Account) + documents_count = ( db.session.query(Document) .where( @@ -1454,6 +1474,8 @@ class DocumentService: dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", ): + assert isinstance(current_user, Account) + DatasetService.check_dataset_model_setting(dataset) document = DocumentService.get_document(dataset.id, document_data.original_document_id) if document is None: @@ -1513,7 +1535,7 @@ class DocumentService: data_source_binding = ( db.session.query(DataSourceOauthBinding) .where( - db.and_( + sa.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, @@ -1574,6 +1596,9 @@ class DocumentService: @staticmethod def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None + features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: @@ -2013,6 +2038,9 @@ class SegmentService: @classmethod def create_segment(cls, args: dict, document: Document, dataset: Dataset): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None + content = args["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) @@ -2075,6 +2103,9 @@ class SegmentService: @classmethod def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None + lock_name = f"multi_add_segment_lock_document_id_{document.id}" increment_word_count = 0 with redis_client.lock(lock_name, timeout=600): @@ -2158,6 +2189,9 @@ class SegmentService: @classmethod def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None + indexing_cache_key = f"segment_{segment.id}_indexing" cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: @@ -2349,6 +2383,7 @@ class SegmentService: @classmethod def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): + assert isinstance(current_user, Account) segments = ( db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count) .where( @@ -2379,6 +2414,8 @@ class SegmentService: def update_segments_status( cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document ): + assert current_user is not None + # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return @@ -2441,6 +2478,8 @@ class SegmentService: def create_child_chunk( cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset ) -> ChildChunk: + assert isinstance(current_user, Account) + lock_name = f"add_child_lock_{segment.id}" with redis_client.lock(lock_name, timeout=20): index_node_id = str(uuid.uuid4()) @@ -2488,6 +2527,8 @@ class SegmentService: document: Document, dataset: Dataset, ) -> list[ChildChunk]: + assert isinstance(current_user, Account) + child_chunks = ( db.session.query(ChildChunk) .where( @@ -2562,6 +2603,8 @@ class SegmentService: document: Document, dataset: Dataset, ) -> ChildChunk: + assert current_user is not None + try: child_chunk.content = content child_chunk.word_count = len(content) @@ -2592,6 +2635,8 @@ class SegmentService: def get_child_chunks( cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None ): + assert isinstance(current_user, Account) + query = ( select(ChildChunk) .filter_by( diff --git a/api/services/file_service.py b/api/services/file_service.py index 4c0a0f451c..8a4655d25e 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -3,7 +3,6 @@ import os import uuid from typing import Any, Literal, Union -from flask_login import current_user from werkzeug.exceptions import NotFound from configs import dify_config @@ -19,6 +18,7 @@ from extensions.ext_database import db from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id +from libs.login import current_user from models.account import Account from models.enums import CreatorUserRole from models.model import EndUser, UploadFile @@ -111,6 +111,9 @@ class FileService: @staticmethod def upload_text(text: str, text_name: str) -> UploadFile: + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None + if len(text_name) > 200: text_name = text_name[:200] # user uuid as file name diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index d63b188b12..c572ddc925 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -1,10 +1,11 @@ import json -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch import pytest from faker import Faker from core.plugin.impl.exc import PluginDaemonClientSideError +from models.account import Account from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought from services.account_service import AccountService, TenantService from services.agent_service import AgentService @@ -21,7 +22,7 @@ class TestAgentService: patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client, patch("services.agent_service.ToolManager") as mock_tool_manager, patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager, - patch("services.agent_service.current_user") as mock_current_user, + patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, patch("services.app_service.ModelManager") as mock_model_manager, diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 4184420880..3cb7424df8 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -1,9 +1,10 @@ -from unittest.mock import patch +from unittest.mock import create_autospec, patch import pytest from faker import Faker from werkzeug.exceptions import NotFound +from models.account import Account from models.model import MessageAnnotation from services.annotation_service import AppAnnotationService from services.app_service import AppService @@ -24,7 +25,9 @@ class TestAnnotationService: patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, - patch("services.annotation_service.current_user") as mock_current_user, + patch( + "services.annotation_service.current_user", create_autospec(Account, instance=True) + ) as mock_current_user, ): # Setup default mock returns mock_account_feature_service.get_features.return_value.billing.enabled = False diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index 69cd9fafee..cbbbbddb21 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -1,9 +1,10 @@ -from unittest.mock import patch +from unittest.mock import create_autospec, patch import pytest from faker import Faker from constants.model_template import default_app_templates +from models.account import Account from models.model import App, Site from services.account_service import AccountService, TenantService from services.app_service import AppService @@ -161,8 +162,13 @@ class TestAppService: app_service = AppService() created_app = app_service.create_app(tenant.id, app_args, account) - # Get app using the service - retrieved_app = app_service.get_app(created_app) + # Get app using the service - needs current_user mock + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + retrieved_app = app_service.get_app(created_app) # Verify retrieved app matches created app assert retrieved_app.id == created_app.id @@ -406,7 +412,11 @@ class TestAppService: "use_icon_as_answer_icon": True, } - with patch("flask_login.utils._get_user", return_value=account): + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): updated_app = app_service.update_app(app, update_args) # Verify updated fields @@ -456,7 +466,11 @@ class TestAppService: # Update app name new_name = "New App Name" - with patch("flask_login.utils._get_user", return_value=account): + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): updated_app = app_service.update_app_name(app, new_name) assert updated_app.name == new_name @@ -504,7 +518,11 @@ class TestAppService: # Update app icon new_icon = "🌟" new_icon_background = "#FFD93D" - with patch("flask_login.utils._get_user", return_value=account): + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) assert updated_app.icon == new_icon @@ -551,13 +569,17 @@ class TestAppService: original_site_status = app.enable_site # Update site status to disabled - with patch("flask_login.utils._get_user", return_value=account): + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): updated_app = app_service.update_app_site_status(app, False) assert updated_app.enable_site is False assert updated_app.updated_by == account.id # Update site status back to enabled - with patch("flask_login.utils._get_user", return_value=account): + with patch("services.app_service.current_user", mock_current_user): updated_app = app_service.update_app_site_status(updated_app, True) assert updated_app.enable_site is True assert updated_app.updated_by == account.id @@ -602,13 +624,17 @@ class TestAppService: original_api_status = app.enable_api # Update API status to disabled - with patch("flask_login.utils._get_user", return_value=account): + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): updated_app = app_service.update_app_api_status(app, False) assert updated_app.enable_api is False assert updated_app.updated_by == account.id # Update API status back to enabled - with patch("flask_login.utils._get_user", return_value=account): + with patch("services.app_service.current_user", mock_current_user): updated_app = app_service.update_app_api_status(updated_app, True) assert updated_app.enable_api is True assert updated_app.updated_by == account.id diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py index 965c9c6242..5e5e680a5d 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -1,6 +1,6 @@ import hashlib from io import BytesIO -from unittest.mock import patch +from unittest.mock import create_autospec, patch import pytest from faker import Faker @@ -417,11 +417,12 @@ class TestFileService: text = "This is a test text content" text_name = "test_text.txt" - # Mock current_user - with patch("services.file_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = str(fake.uuid4()) - mock_current_user.id = str(fake.uuid4()) + # Mock current_user using create_autospec + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = str(fake.uuid4()) + mock_current_user.id = str(fake.uuid4()) + with patch("services.file_service.current_user", mock_current_user): upload_file = FileService.upload_text(text=text, text_name=text_name) assert upload_file is not None @@ -443,11 +444,12 @@ class TestFileService: text = "test content" long_name = "a" * 250 # Longer than 200 characters - # Mock current_user - with patch("services.file_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = str(fake.uuid4()) - mock_current_user.id = str(fake.uuid4()) + # Mock current_user using create_autospec + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = str(fake.uuid4()) + mock_current_user.id = str(fake.uuid4()) + with patch("services.file_service.current_user", mock_current_user): upload_file = FileService.upload_text(text=text, text_name=long_name) # Verify name was truncated @@ -846,11 +848,12 @@ class TestFileService: text = "" text_name = "empty.txt" - # Mock current_user - with patch("services.file_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = str(fake.uuid4()) - mock_current_user.id = str(fake.uuid4()) + # Mock current_user using create_autospec + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = str(fake.uuid4()) + mock_current_user.id = str(fake.uuid4()) + with patch("services.file_service.current_user", mock_current_user): upload_file = FileService.upload_text(text=text, text_name=text_name) assert upload_file is not None diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index 7fef572c14..4646531a4e 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import create_autospec, patch import pytest from faker import Faker @@ -17,7 +17,9 @@ class TestMetadataService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.metadata_service.current_user") as mock_current_user, + patch( + "services.metadata_service.current_user", create_autospec(Account, instance=True) + ) as mock_current_user, patch("services.metadata_service.redis_client") as mock_redis_client, patch("services.dataset_service.DocumentService") as mock_document_service, ): diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 2d5cdf426d..d09a4a17ab 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import create_autospec, patch import pytest from faker import Faker @@ -17,7 +17,7 @@ class TestTagService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.tag_service.current_user") as mock_current_user, + patch("services.tag_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, ): # Setup default mock returns mock_current_user.current_tenant_id = "test-tenant-id" diff --git a/api/tests/test_containers_integration_tests/services/test_website_service.py b/api/tests/test_containers_integration_tests/services/test_website_service.py index ec2f1556af..5ac9ce820a 100644 --- a/api/tests/test_containers_integration_tests/services/test_website_service.py +++ b/api/tests/test_containers_integration_tests/services/test_website_service.py @@ -1,5 +1,5 @@ from datetime import datetime -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch import pytest from faker import Faker @@ -231,9 +231,10 @@ class TestWebsiteService: fake = Faker() # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Create API request api_request = WebsiteCrawlApiRequest( provider="firecrawl", @@ -285,9 +286,10 @@ class TestWebsiteService: account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Create API request api_request = WebsiteCrawlApiRequest( provider="watercrawl", @@ -336,9 +338,10 @@ class TestWebsiteService: account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Create API request for single page crawling api_request = WebsiteCrawlApiRequest( provider="jinareader", @@ -389,9 +392,10 @@ class TestWebsiteService: account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Create API request with invalid provider api_request = WebsiteCrawlApiRequest( provider="invalid_provider", @@ -419,9 +423,10 @@ class TestWebsiteService: account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Create API request api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123") @@ -463,9 +468,10 @@ class TestWebsiteService: account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Create API request api_request = WebsiteCrawlStatusApiRequest(provider="watercrawl", job_id="watercrawl_job_123") @@ -502,9 +508,10 @@ class TestWebsiteService: account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Create API request api_request = WebsiteCrawlStatusApiRequest(provider="jinareader", job_id="jina_job_123") @@ -544,9 +551,10 @@ class TestWebsiteService: account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Create API request with invalid provider api_request = WebsiteCrawlStatusApiRequest(provider="invalid_provider", job_id="test_job_id_123") @@ -569,9 +577,10 @@ class TestWebsiteService: account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Mock missing credentials mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = None @@ -597,9 +606,10 @@ class TestWebsiteService: account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Mock missing API key in config mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = { "config": {"base_url": "https://api.example.com"} @@ -995,9 +1005,10 @@ class TestWebsiteService: account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Create API request for sub-page crawling api_request = WebsiteCrawlApiRequest( provider="jinareader", @@ -1054,9 +1065,10 @@ class TestWebsiteService: mock_external_service_dependencies["requests"].get.return_value = mock_failed_response # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Create API request api_request = WebsiteCrawlApiRequest( provider="jinareader", @@ -1096,9 +1108,10 @@ class TestWebsiteService: mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance # Mock current_user for the test - with patch("services.website_service.current_user") as mock_current_user: - mock_current_user.current_tenant_id = account.current_tenant.id + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.current_tenant_id = account.current_tenant.id + with patch("services.website_service.current_user", mock_current_user): # Create API request api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="active_job_123") diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py index 7c40b1e556..fb23863043 100644 --- a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -2,11 +2,12 @@ import datetime from typing import Any, Optional # Mock redis_client before importing dataset_service -from unittest.mock import Mock, patch +from unittest.mock import Mock, create_autospec, patch import pytest from core.model_runtime.entities.model_entities import ModelType +from models.account import Account from models.dataset import Dataset, ExternalKnowledgeBindings from services.dataset_service import DatasetService from services.errors.account import NoPermissionError @@ -78,7 +79,7 @@ class DatasetUpdateTestDataFactory: @staticmethod def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock: """Create a mock current user.""" - current_user = Mock() + current_user = create_autospec(Account, instance=True) current_user.current_tenant_id = tenant_id return current_user @@ -135,7 +136,9 @@ class TestDatasetServiceUpdateDataset: "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, - patch("services.dataset_service.current_user") as mock_current_user, + patch( + "services.dataset_service.current_user", create_autospec(Account, instance=True) + ) as mock_current_user, ): mock_current_user.current_tenant_id = "tenant-123" yield { diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index 0fc36510b9..ad65175e89 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -1,9 +1,10 @@ -from unittest.mock import Mock, patch +from unittest.mock import Mock, create_autospec, patch import pytest from flask_restx import reqparse from werkzeug.exceptions import BadRequest +from models.account import Account from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.metadata_service import MetadataService @@ -35,19 +36,21 @@ class TestMetadataBugCompleteValidation: mock_metadata_args.name = None mock_metadata_args.type = "string" - with patch("services.metadata_service.current_user") as mock_user: - mock_user.current_tenant_id = "tenant-123" - mock_user.id = "user-456" + mock_user = create_autospec(Account, instance=True) + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + with patch("services.metadata_service.current_user", mock_user): # Should crash with TypeError with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.create_metadata("dataset-123", mock_metadata_args) # Test update method as well - with patch("services.metadata_service.current_user") as mock_user: - mock_user.current_tenant_id = "tenant-123" - mock_user.id = "user-456" + mock_user = create_autospec(Account, instance=True) + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + with patch("services.metadata_service.current_user", mock_user): with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.update_metadata_name("dataset-123", "metadata-456", None) diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py index 7f6344f942..d151100cf3 100644 --- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -1,8 +1,9 @@ -from unittest.mock import Mock, patch +from unittest.mock import Mock, create_autospec, patch import pytest from flask_restx import reqparse +from models.account import Account from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.metadata_service import MetadataService @@ -24,20 +25,22 @@ class TestMetadataNullableBug: mock_metadata_args.name = None # This will cause len() to crash mock_metadata_args.type = "string" - with patch("services.metadata_service.current_user") as mock_user: - mock_user.current_tenant_id = "tenant-123" - mock_user.id = "user-456" + mock_user = create_autospec(Account, instance=True) + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + with patch("services.metadata_service.current_user", mock_user): # This should crash with TypeError when calling len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.create_metadata("dataset-123", mock_metadata_args) def test_metadata_service_update_with_none_name_crashes(self): """Test that MetadataService.update_metadata_name crashes when name is None.""" - with patch("services.metadata_service.current_user") as mock_user: - mock_user.current_tenant_id = "tenant-123" - mock_user.id = "user-456" + mock_user = create_autospec(Account, instance=True) + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + with patch("services.metadata_service.current_user", mock_user): # This should crash with TypeError when calling len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.update_metadata_name("dataset-123", "metadata-456", None) @@ -81,10 +84,11 @@ class TestMetadataNullableBug: mock_metadata_args.name = None # From args["name"] mock_metadata_args.type = None # From args["type"] - with patch("services.metadata_service.current_user") as mock_user: - mock_user.current_tenant_id = "tenant-123" - mock_user.id = "user-456" + mock_user = create_autospec(Account, instance=True) + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + with patch("services.metadata_service.current_user", mock_user): # Step 4: Service layer crashes on len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.create_metadata("dataset-123", mock_metadata_args)