diff --git a/.github/ISSUE_TEMPLATE/chore.yaml b/.github/ISSUE_TEMPLATE/chore.yaml new file mode 100644 index 0000000000..43449ef942 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/chore.yaml @@ -0,0 +1,27 @@ +name: "✨ Refactor" +description: Refactor existing code for improved readability and maintainability. +title: "[Chore/Refactor] " +labels: + - refactor +body: + - type: textarea + id: description + attributes: + label: Description + placeholder: "Describe the refactor you are proposing." + validations: + required: true + - type: textarea + id: motivation + attributes: + label: Motivation + placeholder: "Explain why this refactor is necessary." + validations: + required: false + - type: textarea + id: additional-context + attributes: + label: Additional Context + placeholder: "Add any other context or screenshots about the request here." + validations: + required: false diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index b35283e6ec..c2240d03ef 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -7,6 +7,7 @@ on: - "deploy/dev" - "deploy/enterprise" - "build/**" + - "release/e-*" - "deploy/rag-dev" tags: - "*" diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 007b1f6d3d..ee6011cd65 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -225,14 +225,15 @@ class AnnotationBatchImportApi(Resource): raise Forbidden() app_id = str(app_id) - # get file from request - file = request.files["file"] # check file if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() + + # get file from request + file = request.files["file"] # check file type if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 6d9f794307..ad62bd6e08 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -58,21 +58,38 @@ class InstalledAppsListApi(Resource): # filter out apps that user doesn't have access to if FeatureService.get_system_features().webapp_auth.enabled: user_id = current_user.id - res = [] app_ids = [installed_app["app"].id for installed_app in installed_app_list] webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids) + + # Pre-filter out apps without setting or with sso_verified + filtered_installed_apps = [] + app_id_to_app_code = {} + for installed_app in installed_app_list: - webapp_setting = webapp_settings.get(installed_app["app"].id) - if not webapp_setting: + app_id = installed_app["app"].id + webapp_setting = webapp_settings.get(app_id) + if not webapp_setting or webapp_setting.access_mode == "sso_verified": continue - if webapp_setting.access_mode == "sso_verified": - continue - app_code = AppService.get_app_code_by_id(str(installed_app["app"].id)) - if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( - user_id=user_id, - app_code=app_code, - ): + app_code = AppService.get_app_code_by_id(str(app_id)) + app_id_to_app_code[app_id] = app_code + filtered_installed_apps.append(installed_app) + + app_codes = list(app_id_to_app_code.values()) + + # Batch permission check + permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps( + user_id=user_id, + app_codes=app_codes, + ) + + # Keep only allowed apps + res = [] + for installed_app in filtered_installed_apps: + app_id = installed_app["app"].id + app_code = app_id_to_app_code[app_id] + if permissions.get(app_code): res.append(installed_app) + installed_app_list = res logger.debug("installed_app_list: %s, user_id: %s", installed_app_list, user_id) diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py similarity index 100% rename from api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py rename to api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 7375726fa9..6f32498b42 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -208,6 +208,7 @@ class BasePluginClient: except Exception: raise PluginDaemonInnerError(code=rep.code, message=rep.message) + logger.error("Error in stream reponse for plugin %s", rep.__dict__) self._handle_plugin_daemon_error(error.error_type, error.message) raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}") if rep.data is None: diff --git a/api/core/plugin/impl/exc.py b/api/core/plugin/impl/exc.py index 8b660c807d..8ecc2e2147 100644 --- a/api/core/plugin/impl/exc.py +++ b/api/core/plugin/impl/exc.py @@ -2,6 +2,8 @@ from collections.abc import Mapping from pydantic import TypeAdapter +from extensions.ext_logging import get_request_id + class PluginDaemonError(Exception): """Base class for all plugin daemon errors.""" @@ -11,7 +13,7 @@ class PluginDaemonError(Exception): def __str__(self) -> str: # returns the class name and description - return f"{self.__class__.__name__}: {self.description}" + return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}" class PluginDaemonInternalError(PluginDaemonError): diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index bcaf299892..d654463be9 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -5,14 +5,13 @@ from __future__ import annotations from typing import Any, Optional from core.model_manager import ModelInstance -from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer from core.rag.splitter.text_splitter import ( TS, Collection, Literal, RecursiveCharacterTextSplitter, Set, - TokenTextSplitter, Union, ) @@ -45,14 +44,6 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): return [len(text) for text in texts] - if issubclass(cls, TokenTextSplitter): - extra_kwargs = { - "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2", - "allowed_special": allowed_special, - "disallowed_special": disallowed_special, - } - kwargs = {**kwargs, **extra_kwargs} - return cls(length_function=_character_encoder, **kwargs) diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 35e16b5c8f..d6961cdaa4 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -20,9 +20,6 @@ class Tool(ABC): The base class of a tool """ - entity: ToolEntity - runtime: ToolRuntime - def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None: self.entity = entity self.runtime = runtime diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 724a2291c6..84efefba07 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -20,8 +20,6 @@ class BuiltinTool(Tool): :param meta: the meta data of a tool call processing """ - provider: str - def __init__(self, provider: str, **kwargs): super().__init__(**kwargs) self.provider = provider diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 10653b9948..333ef2834c 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -21,9 +21,6 @@ API_TOOL_DEFAULT_TIMEOUT = ( class ApiTool(Tool): - api_bundle: ApiToolBundle - provider_id: str - """ Api tool """ diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index d1bacbc735..8ebbb6b0fe 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -8,23 +8,16 @@ from core.mcp.mcp_client import MCPClient from core.mcp.types import ImageContent, TextContent from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime -from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType class MCPTool(Tool): - tenant_id: str - icon: str - runtime_parameters: Optional[list[ToolParameter]] - server_url: str - provider_id: str - def __init__( self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str ) -> None: super().__init__(entity, runtime) self.tenant_id = tenant_id self.icon = icon - self.runtime_parameters = None self.server_url = server_url self.provider_id = provider_id diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index aef2677c36..db38c10e81 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -9,11 +9,6 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too class PluginTool(Tool): - tenant_id: str - icon: str - plugin_unique_identifier: str - runtime_parameters: Optional[list[ToolParameter]] - def __init__( self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str ) -> None: @@ -21,7 +16,7 @@ class PluginTool(Tool): self.tenant_id = tenant_id self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - self.runtime_parameters = None + self.runtime_parameters: Optional[list[ToolParameter]] = None def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.PLUGIN diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index ec0575f6c3..d58807e29f 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -20,8 +20,6 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas class DatasetRetrieverTool(Tool): - retrieval_tool: DatasetRetrieverBaseTool - def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None: super().__init__(entity, runtime) self.retrieval_tool = retrieval_tool diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index db6b84082f..6824e5e0e8 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -25,15 +25,6 @@ logger = logging.getLogger(__name__) class WorkflowTool(Tool): - workflow_app_id: str - version: str - workflow_entities: dict[str, Any] - workflow_call_depth: int - thread_pool_id: Optional[str] = None - workflow_as_tool_id: str - - label: str - """ Workflow tool. """ diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index b027a165f9..a8f025a750 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -136,6 +136,8 @@ def init_app(app: DifyApp): from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter from opentelemetry.instrumentation.celery import CeleryInstrumentor from opentelemetry.instrumentation.flask import FlaskInstrumentor + from opentelemetry.instrumentation.redis import RedisInstrumentor + from opentelemetry.instrumentation.requests import RequestsInstrumentor from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider from opentelemetry.propagate import set_global_textmap @@ -234,6 +236,8 @@ def init_app(app: DifyApp): CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument() instrument_exception_logging() init_sqlalchemy_instrumentor(app) + RedisInstrumentor().instrument() + RequestsInstrumentor().instrument() atexit.register(shutdown_tracer) diff --git a/api/models/workflow.py b/api/models/workflow.py index a83d7d07c5..ba7396e0a2 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -895,6 +895,19 @@ class WorkflowAppLog(Base): created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None + def to_dict(self): + return { + "id": self.id, + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "workflow_id": self.workflow_id, + "workflow_run_id": self.workflow_run_id, + "created_from": self.created_from, + "created_by_role": self.created_by_role, + "created_by": self.created_by, + "created_at": self.created_at, + } + class ConversationVariable(Base): __tablename__ = "workflow_conversation_variables" diff --git a/api/pyproject.toml b/api/pyproject.toml index d8f663ef8d..9d979eca1c 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -49,6 +49,8 @@ dependencies = [ "opentelemetry-instrumentation==0.48b0", "opentelemetry-instrumentation-celery==0.48b0", "opentelemetry-instrumentation-flask==0.48b0", + "opentelemetry-instrumentation-redis==0.48b0", + "opentelemetry-instrumentation-requests==0.48b0", "opentelemetry-instrumentation-sqlalchemy==0.48b0", "opentelemetry-propagator-b3==1.27.0", # opentelemetry-proto1.28.0 depends on protobuf (>=5.0,<6.0), diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index d057a14afb..b28afcaa41 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -13,7 +13,19 @@ from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant -from models.model import App, Conversation, Message +from models.model import ( + App, + AppAnnotationHitHistory, + Conversation, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from models.workflow import WorkflowAppLog from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService @@ -21,6 +33,85 @@ logger = logging.getLogger(__name__) class ClearFreePlanTenantExpiredLogs: + @classmethod + def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None: + """ + Clean up message-related tables to avoid data redundancy. + This method cleans up tables that have foreign key relationships with Message. + + Args: + session: Database session, the same with the one in process_tenant method + tenant_id: Tenant ID for logging purposes + batch_message_ids: List of message IDs to clean up + """ + if not batch_message_ids: + return + + # Clean up each related table + related_tables = [ + (MessageFeedback, "message_feedbacks"), + (MessageFile, "message_files"), + (MessageAnnotation, "message_annotations"), + (MessageChain, "message_chains"), + (MessageAgentThought, "message_agent_thoughts"), + (AppAnnotationHitHistory, "app_annotation_hit_histories"), + (SavedMessage, "saved_messages"), + ] + + for model, table_name in related_tables: + # Query records related to expired messages + records = ( + session.query(model) + .filter( + model.message_id.in_(batch_message_ids), # type: ignore + ) + .all() + ) + + if len(records) == 0: + continue + + # Save records before deletion + record_ids = [record.id for record in records] + try: + record_data = [] + for record in records: + try: + if hasattr(record, "to_dict"): + record_data.append(record.to_dict()) + else: + # if record doesn't have to_dict method, we need to transform it to dict manually + record_dict = {} + for column in record.__table__.columns: + record_dict[column.name] = getattr(record, column.name) + record_data.append(record_dict) + except Exception: + logger.exception("Failed to transform %s record: %s", table_name, record.id) + continue + + if record_data: + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/{table_name}/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder(record_data), + ).encode("utf-8"), + ) + except Exception: + logger.exception("Failed to save %s records", table_name) + + session.query(model).filter( + model.id.in_(record_ids), # type: ignore + ).delete(synchronize_session=False) + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(record_ids)} " + f"{table_name} records for tenant {tenant_id}" + ) + ) + @classmethod def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int): with flask_app.app_context(): @@ -58,6 +149,7 @@ class ClearFreePlanTenantExpiredLogs: Message.id.in_(message_ids), ).delete(synchronize_session=False) + cls._clear_message_related_tables(session, tenant_id, message_ids) session.commit() click.echo( @@ -199,6 +291,48 @@ class ClearFreePlanTenantExpiredLogs: if len(workflow_runs) < batch: break + while True: + with Session(db.engine).no_autoflush as session: + workflow_app_logs = ( + session.query(WorkflowAppLog) + .filter( + WorkflowAppLog.tenant_id == tenant_id, + WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days), + ) + .limit(batch) + .all() + ) + + if len(workflow_app_logs) == 0: + break + + # save workflow app logs + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_app_logs/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder( + [workflow_app_log.to_dict() for workflow_app_log in workflow_app_logs], + ), + ).encode("utf-8"), + ) + + workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs] + + # delete workflow app logs + session.query(WorkflowAppLog).filter( + WorkflowAppLog.id.in_(workflow_app_log_ids), + ).delete(synchronize_session=False) + session.commit() + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_app_log_ids)}" + f" workflow app logs for tenant {tenant_id}" + ) + ) + @classmethod def process(cls, days: int, batch: int, tenant_ids: list[str]): """ diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 54d45f45ea..f8612456d6 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -52,6 +52,16 @@ class EnterpriseService: return data.get("result", False) + @classmethod + def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_codes: list[str]): + if not app_codes: + return {} + body = {"userId": user_id, "appCodes": app_codes} + data = EnterpriseRequest.send_request("POST", "/webapp/permission/batch", json=body) + if not data: + raise ValueError("No data found.") + return data.get("permissions", {}) + @classmethod def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings: if not app_id: diff --git a/api/tests/test_containers_integration_tests/services/__init__.py b/api/tests/test_containers_integration_tests/services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py new file mode 100644 index 0000000000..3d7be0df7d --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -0,0 +1,3340 @@ +import json +from hashlib import sha256 +from unittest.mock import patch + +import pytest +from faker import Faker +from werkzeug.exceptions import Unauthorized + +from configs import dify_config +from controllers.console.error import AccountNotFound, NotAllowedCreateWorkspace +from models.account import AccountStatus, TenantAccountJoin +from services.account_service import AccountService, RegisterService, TenantService, TokenPair +from services.errors.account import ( + AccountAlreadyInTenantError, + AccountLoginError, + AccountNotFoundError, + AccountPasswordError, + AccountRegisterError, + CurrentPasswordIncorrectError, +) +from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError + + +class TestAccountService: + """Integration tests for AccountService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.account_service.BillingService") as mock_billing_service, + patch("services.account_service.PassportService") as mock_passport_service, + ): + # Setup default mock returns + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True + mock_feature_service.get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_billing_service.is_email_in_freeze.return_value = False + mock_passport_service.return_value.issue.return_value = "mock_jwt_token" + + yield { + "feature_service": mock_feature_service, + "billing_service": mock_billing_service, + "passport_service": mock_passport_service, + } + + def test_create_account_and_login(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation and login with correct password. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + assert account.email == email + assert account.status == AccountStatus.ACTIVE.value + + # Login with correct password + logged_in = AccountService.authenticate(email, password) + assert logged_in.id == account.id + + def test_create_account_without_password(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation without password (for OAuth users). + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=None, + ) + assert account.email == email + assert account.password is None + assert account.password_salt is None + + def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation when registration is disabled. + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks to disable registration + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = False + + with pytest.raises(AccountNotFound): # AccountNotFound exception + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=fake.password(length=12), + ) + + def test_create_account_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation when email is in freeze period. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True + dify_config.BILLING_ENABLED = True + + with pytest.raises(AccountRegisterError): + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + dify_config.BILLING_ENABLED = False # Reset config for other tests + + def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with non-existent account. + """ + fake = Faker() + email = fake.email() + password = fake.password(length=12) + with pytest.raises(AccountNotFoundError): + AccountService.authenticate(email, password) + + def test_authenticate_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with banned account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account first + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Ban the account + account.status = AccountStatus.BANNED.value + from extensions.ext_database import db + + db.session.commit() + + with pytest.raises(AccountLoginError): + AccountService.authenticate(email, password) + + def test_authenticate_wrong_password(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with wrong password. + """ + fake = Faker() + email = fake.email() + name = fake.name() + correct_password = fake.password(length=12) + wrong_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account first + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=correct_password, + ) + + with pytest.raises(AccountPasswordError): + AccountService.authenticate(email, wrong_password) + + def test_authenticate_with_invite_token(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with invite token to set password for account without password. + """ + fake = Faker() + email = fake.email() + name = fake.name() + new_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account without password + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=None, + ) + + # Authenticate with invite token to set password + authenticated_account = AccountService.authenticate( + email, + new_password, + invite_token="valid_invite_token", + ) + + assert authenticated_account.id == account.id + assert authenticated_account.password is not None + assert authenticated_account.password_salt is not None + + def test_authenticate_pending_account_activation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test authentication activates pending account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account with pending status + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + account.status = AccountStatus.PENDING.value + from extensions.ext_database import db + + db.session.commit() + + # Authenticate should activate the account + authenticated_account = AccountService.authenticate(email, password) + assert authenticated_account.status == AccountStatus.ACTIVE.value + assert authenticated_account.initialized_at is not None + + def test_update_account_password_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful password update. + """ + fake = Faker() + email = fake.email() + name = fake.name() + old_password = fake.password(length=12) + new_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=old_password, + ) + + # Update password + updated_account = AccountService.update_account_password(account, old_password, new_password) + + # Verify new password works + authenticated_account = AccountService.authenticate(email, new_password) + assert authenticated_account.id == account.id + + def test_update_account_password_wrong_current_password( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test password update with wrong current password. + """ + fake = Faker() + email = fake.email() + name = fake.name() + old_password = fake.password(length=12) + wrong_password = fake.password(length=12) + new_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=old_password, + ) + + with pytest.raises(CurrentPasswordIncorrectError): + AccountService.update_account_password(account, wrong_password, new_password) + + def test_update_account_password_invalid_new_password( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test password update with invalid new password format. + """ + fake = Faker() + email = fake.email() + name = fake.name() + old_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=old_password, + ) + + # Test with too short password (assuming minimum length validation) + with pytest.raises(ValueError): # Password validation error + AccountService.update_account_password(account, old_password, "123") + + def test_create_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation with automatic tenant creation. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + account = AccountService.create_account_and_tenant( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + assert account.email == email + + # Verify tenant was created and linked + from extensions.ext_database import db + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + + def test_create_account_and_tenant_workspace_creation_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account creation when workspace creation is disabled. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + with pytest.raises(WorkSpaceNotAllowedCreateError): + AccountService.create_account_and_tenant( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + def test_create_account_and_tenant_workspace_limit_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account creation when workspace limit is exceeded. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = False + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + with pytest.raises(WorkspacesLimitExceededError): + AccountService.create_account_and_tenant( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + def test_link_account_integrate_new_provider(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test linking account with new OAuth provider. + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=None, + ) + + # Link with new provider + AccountService.link_account_integrate("new-google", "google_open_id_123", account) + + # Verify integration was created + from extensions.ext_database import db + from models.account import AccountIntegrate + + integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="new-google").first() + assert integration is not None + assert integration.open_id == "google_open_id_123" + + def test_link_account_integrate_existing_provider( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test linking account with existing provider (should update). + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=None, + ) + + # Link with provider first time + AccountService.link_account_integrate("exists-google", "google_open_id_123", account) + + # Link with same provider but different open_id (should update) + AccountService.link_account_integrate("exists-google", "google_open_id_456", account) + + # Verify integration was updated + from extensions.ext_database import db + from models.account import AccountIntegrate + + integration = ( + db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="exists-google").first() + ) + assert integration.open_id == "google_open_id_456" + + def test_close_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test closing an account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Close account + AccountService.close_account(account) + + # Verify account status changed + from extensions.ext_database import db + + db.session.refresh(account) + assert account.status == AccountStatus.CLOSED.value + + def test_update_account_fields(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating account fields. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + updated_name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Update account fields + updated_account = AccountService.update_account(account, name=updated_name, interface_theme="dark") + + assert updated_account.name == updated_name + assert updated_account.interface_theme == "dark" + + def test_update_account_invalid_field(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating account with invalid field. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + with pytest.raises(AttributeError): + AccountService.update_account(account, invalid_field="value") + + def test_update_login_info(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating login information. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + ip_address = fake.ipv4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Update login info + AccountService.update_login_info(account, ip_address=ip_address) + + # Verify login info was updated + from extensions.ext_database import db + + db.session.refresh(account) + assert account.last_login_ip == ip_address + assert account.last_login_at is not None + + def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful login with token generation. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + ip_address = fake.ipv4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_access_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Login + token_pair = AccountService.login(account, ip_address=ip_address) + + assert isinstance(token_pair, TokenPair) + assert token_pair.access_token == "mock_access_token" + assert token_pair.refresh_token is not None + + # Verify passport service was called with correct parameters + mock_passport = mock_external_service_dependencies["passport_service"].return_value + mock_passport.issue.assert_called_once() + call_args = mock_passport.issue.call_args[0][0] + assert call_args["user_id"] == account.id + assert call_args["iss"] is not None + assert call_args["sub"] == "Console API Passport" + + def test_login_pending_account_activation(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test login activates pending account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_access_token" + + # Create account with pending status + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + account.status = AccountStatus.PENDING.value + from extensions.ext_database import db + + db.session.commit() + + # Login should activate the account + token_pair = AccountService.login(account) + + db.session.refresh(account) + assert account.status == AccountStatus.ACTIVE.value + + def test_logout(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test logout functionality. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_access_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Login first to get refresh token + token_pair = AccountService.login(account) + + # Logout + AccountService.logout(account=account) + + # Verify refresh token was deleted from Redis + from extensions.ext_redis import redis_client + + refresh_token_key = f"account_refresh_token:{account.id}" + assert redis_client.get(refresh_token_key) is None + + def test_refresh_token_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful token refresh. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "new_mock_access_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + # Create associated Tenant + TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + + # Login to get initial tokens + initial_token_pair = AccountService.login(account) + + # Refresh token + new_token_pair = AccountService.refresh_token(initial_token_pair.refresh_token) + + assert isinstance(new_token_pair, TokenPair) + assert new_token_pair.access_token == "new_mock_access_token" + assert new_token_pair.refresh_token != initial_token_pair.refresh_token + + def test_refresh_token_invalid_token(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test refresh token with invalid token. + """ + fake = Faker() + invalid_token = fake.uuid4() + with pytest.raises(ValueError, match="Invalid refresh token"): + AccountService.refresh_token(invalid_token) + + def test_refresh_token_invalid_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test refresh token with valid token but invalid account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_access_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Login to get tokens + token_pair = AccountService.login(account) + + # Delete account + from extensions.ext_database import db + + db.session.delete(account) + db.session.commit() + + # Try to refresh token with deleted account + with pytest.raises(ValueError, match="Invalid account"): + AccountService.refresh_token(token_pair.refresh_token) + + def test_load_user_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test loading user by ID successfully. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + # Create associated Tenant + TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + + # Load user + loaded_user = AccountService.load_user(account.id) + + assert loaded_user is not None + assert loaded_user.id == account.id + assert loaded_user.email == account.email + + def test_load_user_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test loading non-existent user. + """ + fake = Faker() + non_existent_user_id = fake.uuid4() + loaded_user = AccountService.load_user(non_existent_user_id) + assert loaded_user is None + + def test_load_user_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test loading banned user raises Unauthorized. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Ban the account + account.status = AccountStatus.BANNED.value + from extensions.ext_database import db + + db.session.commit() + + with pytest.raises(Unauthorized): # Unauthorized exception + AccountService.load_user(account.id) + + def test_get_account_jwt_token(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test JWT token generation for account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_jwt_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate JWT token + token = AccountService.get_account_jwt_token(account) + + assert token == "mock_jwt_token" + + # Verify passport service was called with correct parameters + mock_passport = mock_external_service_dependencies["passport_service"].return_value + mock_passport.issue.assert_called_once() + call_args = mock_passport.issue.call_args[0][0] + assert call_args["user_id"] == account.id + assert call_args["iss"] is not None + assert call_args["sub"] == "Console API Passport" + + def test_load_logged_in_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test loading logged in account by ID. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + # Create associated Tenant + TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + + # Load logged in account + loaded_account = AccountService.load_logged_in_account(account_id=account.id) + + assert loaded_account is not None + assert loaded_account.id == account.id + + def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting user through email successfully. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Get user through email + found_user = AccountService.get_user_through_email(email) + + assert found_user is not None + assert found_user.id == account.id + + def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting user through non-existent email. + """ + fake = Faker() + non_existent_email = fake.email() + found_user = AccountService.get_user_through_email(non_existent_email) + assert found_user is None + + def test_get_user_through_email_banned_account( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting banned user through email raises Unauthorized. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Ban the account + account.status = AccountStatus.BANNED.value + from extensions.ext_database import db + + db.session.commit() + + with pytest.raises(Unauthorized): # Unauthorized exception + AccountService.get_user_through_email(email) + + def test_get_user_through_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting user through email that is in freeze period. + """ + fake = Faker() + email_in_freeze = fake.email() + # Setup mocks + dify_config.BILLING_ENABLED = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True + + with pytest.raises(AccountRegisterError): + AccountService.get_user_through_email(email_in_freeze) + + # Reset config + dify_config.BILLING_ENABLED = False + + def test_delete_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account deletion (should add task to queue). + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + with patch("services.account_service.delete_account_task") as mock_delete_task: + # Delete account + AccountService.delete_account(account) + + # Verify task was added to queue + mock_delete_task.delay.assert_called_once_with(account.id) + + def test_generate_account_deletion_verification_code( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test generating account deletion verification code. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate verification code + token, code = AccountService.generate_account_deletion_verification_code(account) + + assert token is not None + assert code is not None + assert len(code) == 6 + assert code.isdigit() + + def test_verify_account_deletion_code_valid(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test verifying valid account deletion code. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate verification code + token, code = AccountService.generate_account_deletion_verification_code(account) + + # Verify code + is_valid = AccountService.verify_account_deletion_code(token, code) + assert is_valid is True + + def test_verify_account_deletion_code_invalid(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test verifying invalid account deletion code. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + wrong_code = fake.numerify(text="######") + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate verification code + token, code = AccountService.generate_account_deletion_verification_code(account) + + # Verify with wrong code + is_valid = AccountService.verify_account_deletion_code(token, wrong_code) + assert is_valid is False + + def test_verify_account_deletion_code_invalid_token( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test verifying account deletion code with invalid token. + """ + fake = Faker() + invalid_token = fake.uuid4() + invalid_code = fake.numerify(text="######") + is_valid = AccountService.verify_account_deletion_code(invalid_token, invalid_code) + assert is_valid is False + + +class TestTenantService: + """Integration tests for TenantService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.account_service.BillingService") as mock_billing_service, + ): + # Setup default mock returns + mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True + mock_feature_service.get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_billing_service.is_email_in_freeze.return_value = False + + yield { + "feature_service": mock_feature_service, + "billing_service": mock_billing_service, + } + + def test_create_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tenant creation with default settings. + """ + fake = Faker() + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant + tenant = TenantService.create_tenant(name=tenant_name) + + assert tenant.name == tenant_name + assert tenant.plan == "basic" + assert tenant.status == "normal" + assert tenant.encrypt_public_key is not None + + def test_create_tenant_workspace_creation_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tenant creation when workspace creation is disabled. + """ + fake = Faker() + tenant_name = fake.company() + # Setup mocks to disable workspace creation + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + + with pytest.raises(NotAllowedCreateWorkspace): # NotAllowedCreateWorkspace exception + TenantService.create_tenant(name=tenant_name) + + def test_create_tenant_with_custom_name(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tenant creation with custom name and setup flag. + """ + fake = Faker() + custom_tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + + # Create tenant with setup flag (should bypass workspace creation restriction) + tenant = TenantService.create_tenant(name=custom_tenant_name, is_setup=True, is_from_dashboard=True) + + assert tenant.name == custom_tenant_name + assert tenant.plan == "basic" + assert tenant.status == "normal" + assert tenant.encrypt_public_key is not None + + def test_create_tenant_member_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tenant member creation. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Create tenant member + tenant_member = TenantService.create_tenant_member(tenant, account, role="admin") + + assert tenant_member.tenant_id == tenant.id + assert tenant_member.account_id == account.id + assert tenant_member.role == "admin" + + def test_create_tenant_member_duplicate_owner(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test creating duplicate owner for a tenant (should fail). + """ + fake = Faker() + tenant_name = fake.company() + email1 = fake.email() + name1 = fake.name() + password1 = fake.password(length=12) + email2 = fake.email() + name2 = fake.name() + password2 = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + account1 = AccountService.create_account( + email=email1, + name=name1, + interface_language="en-US", + password=password1, + ) + account2 = AccountService.create_account( + email=email2, + name=name2, + interface_language="en-US", + password=password2, + ) + + # Create first owner + TenantService.create_tenant_member(tenant, account1, role="owner") + + # Try to create second owner (should fail) + with pytest.raises(Exception, match="Tenant already has an owner"): + TenantService.create_tenant_member(tenant, account2, role="owner") + + def test_create_tenant_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating role for existing tenant member. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Create member with initial role + tenant_member1 = TenantService.create_tenant_member(tenant, account, role="normal") + assert tenant_member1.role == "normal" + + # Update member role + tenant_member2 = TenantService.create_tenant_member(tenant, account, role="editor") + assert tenant_member2.tenant_id == tenant_member1.tenant_id + assert tenant_member2.account_id == tenant_member1.account_id + assert tenant_member2.role == "editor" + + def test_get_join_tenants_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting join tenants for an account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant1_name = fake.company() + tenant2_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account and tenants + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + tenant1 = TenantService.create_tenant(name=tenant1_name) + tenant2 = TenantService.create_tenant(name=tenant2_name) + + # Add account to both tenants + TenantService.create_tenant_member(tenant1, account, role="normal") + TenantService.create_tenant_member(tenant2, account, role="admin") + + # Get join tenants + join_tenants = TenantService.get_join_tenants(account) + + assert len(join_tenants) == 2 + tenant_names = [tenant.name for tenant in join_tenants] + assert tenant1_name in tenant_names + assert tenant2_name in tenant_names + + def test_get_current_tenant_by_account_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting current tenant by account successfully. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account and tenant + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + tenant = TenantService.create_tenant(name=tenant_name) + + # Add account to tenant and set as current + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + from extensions.ext_database import db + + db.session.commit() + + # Get current tenant + current_tenant = TenantService.get_current_tenant_by_account(account) + + assert current_tenant.id == tenant.id + assert current_tenant.name == tenant.name + assert current_tenant.role == "owner" + + def test_get_current_tenant_by_account_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting current tenant when account has no current tenant. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account without setting current tenant + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Try to get current tenant (should fail) + with pytest.raises(AttributeError): + TenantService.get_current_tenant_by_account(account) + + def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tenant switching. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant1_name = fake.company() + tenant2_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account and tenants + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + tenant1 = TenantService.create_tenant(name=tenant1_name) + tenant2 = TenantService.create_tenant(name=tenant2_name) + + # Add account to both tenants + TenantService.create_tenant_member(tenant1, account, role="owner") + TenantService.create_tenant_member(tenant2, account, role="admin") + + # Set initial current tenant + account.current_tenant = tenant1 + from extensions.ext_database import db + + db.session.commit() + + # Switch to second tenant + TenantService.switch_tenant(account, tenant2.id) + + # Verify tenant was switched + db.session.refresh(account) + assert account.current_tenant_id == tenant2.id + + def test_switch_tenant_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tenant switching without providing tenant ID. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Try to switch tenant without providing tenant ID + with pytest.raises(ValueError, match="Tenant ID must be provided"): + TenantService.switch_tenant(account, None) + + def test_switch_tenant_account_not_member(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test switching to a tenant where account is not a member. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account and tenant + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + tenant = TenantService.create_tenant(name=tenant_name) + + # Try to switch to tenant where account is not a member + with pytest.raises(Exception, match="Tenant not found or account is not a member of the tenant"): + TenantService.switch_tenant(account, tenant.id) + + def test_has_roles_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test checking if tenant has specific roles. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + admin_email = fake.email() + admin_name = fake.name() + admin_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + admin_account = AccountService.create_account( + email=admin_email, + name=admin_name, + interface_language="en-US", + password=admin_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, admin_account, role="admin") + + # Check if tenant has owner role + from models.account import TenantAccountRole + + has_owner = TenantService.has_roles(tenant, [TenantAccountRole.OWNER]) + assert has_owner is True + + # Check if tenant has admin role + has_admin = TenantService.has_roles(tenant, [TenantAccountRole.ADMIN]) + assert has_admin is True + + # Check if tenant has normal role (should be False) + has_normal = TenantService.has_roles(tenant, [TenantAccountRole.NORMAL]) + assert has_normal is False + + def test_has_roles_invalid_role_type(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test checking roles with invalid role type. + """ + fake = Faker() + tenant_name = fake.company() + invalid_role = fake.word() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant + tenant = TenantService.create_tenant(name=tenant_name) + + # Try to check roles with invalid role type + with pytest.raises(ValueError, match="all roles must be TenantAccountRole"): + TenantService.has_roles(tenant, [invalid_role]) + + def test_get_user_role_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting user role in a tenant. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Add account to tenant with specific role + TenantService.create_tenant_member(tenant, account, role="editor") + + # Get user role + user_role = TenantService.get_user_role(account, tenant) + + assert user_role == "editor" + + def test_check_member_permission_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test checking member permission successfully. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="normal") + + # Check owner permission to add member (should succeed) + TenantService.check_member_permission(tenant, owner_account, member_account, "add") + + def test_check_member_permission_invalid_action( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test checking member permission with invalid action. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + invalid_action = fake.word() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Add account to tenant + TenantService.create_tenant_member(tenant, account, role="owner") + + # Try to check permission with invalid action + with pytest.raises(Exception, match="Invalid action"): + TenantService.check_member_permission(tenant, account, None, invalid_action) + + def test_check_member_permission_operate_self(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test checking member permission when trying to operate self. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Add account to tenant + TenantService.create_tenant_member(tenant, account, role="owner") + + # Try to check permission to operate self + with pytest.raises(Exception, match="Cannot operate self"): + TenantService.check_member_permission(tenant, account, account, "remove") + + def test_remove_member_from_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful member removal from tenant. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="normal") + + # Remove member + TenantService.remove_member_from_tenant(tenant, member_account, owner_account) + + # Verify member was removed + from extensions.ext_database import db + from models.account import TenantAccountJoin + + member_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + ) + assert member_join is None + + def test_remove_member_from_tenant_operate_self( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test removing member when trying to operate self. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Add account to tenant + TenantService.create_tenant_member(tenant, account, role="owner") + + # Try to remove self + with pytest.raises(Exception, match="Cannot operate self"): + TenantService.remove_member_from_tenant(tenant, account, account) + + def test_remove_member_from_tenant_not_member(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test removing member who is not in the tenant. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + non_member_email = fake.email() + non_member_name = fake.name() + non_member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + non_member_account = AccountService.create_account( + email=non_member_email, + name=non_member_name, + interface_language="en-US", + password=non_member_password, + ) + + # Add only owner to tenant + TenantService.create_tenant_member(tenant, owner_account, role="owner") + + # Try to remove non-member + with pytest.raises(Exception, match="Member not in tenant"): + TenantService.remove_member_from_tenant(tenant, non_member_account, owner_account) + + def test_update_member_role_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful member role update. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="normal") + + # Update member role + TenantService.update_member_role(tenant, member_account, "admin", owner_account) + + # Verify role was updated + from extensions.ext_database import db + from models.account import TenantAccountJoin + + member_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + ) + assert member_join.role == "admin" + + def test_update_member_role_to_owner(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating member role to owner (should change current owner to admin). + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="admin") + + # Update member role to owner + TenantService.update_member_role(tenant, member_account, "owner", owner_account) + + # Verify roles were updated correctly + from extensions.ext_database import db + from models.account import TenantAccountJoin + + owner_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=owner_account.id).first() + ) + member_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + ) + assert owner_join.role == "admin" + assert member_join.role == "owner" + + def test_update_member_role_already_assigned(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating member role to already assigned role. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="admin") + + # Try to update member role to already assigned role + with pytest.raises(Exception, match="The provided role is already assigned to the member"): + TenantService.update_member_role(tenant, member_account, "admin", owner_account) + + def test_get_tenant_count_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting tenant count successfully. + """ + fake = Faker() + tenant1_name = fake.company() + tenant2_name = fake.company() + tenant3_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create multiple tenants + tenant1 = TenantService.create_tenant(name=tenant1_name) + tenant2 = TenantService.create_tenant(name=tenant2_name) + tenant3 = TenantService.create_tenant(name=tenant3_name) + + # Get tenant count + tenant_count = TenantService.get_tenant_count() + + # Should have at least 3 tenants (may be more from other tests) + assert tenant_count >= 3 + + def test_create_owner_tenant_if_not_exist_new_user( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating owner tenant for new user without existing tenants. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + workspace_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Create owner tenant + TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) + + # Verify tenant was created and linked + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + assert account.current_tenant is not None + assert account.current_tenant.name == workspace_name + + def test_create_owner_tenant_if_not_exist_existing_tenant( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating owner tenant when user already has a tenant. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + existing_tenant_name = fake.company() + new_workspace_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + + # Create account and existing tenant + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + existing_tenant = TenantService.create_tenant(name=existing_tenant_name) + TenantService.create_tenant_member(existing_tenant, account, role="owner") + account.current_tenant = existing_tenant + from extensions.ext_database import db + + db.session.commit() + + # Try to create owner tenant again (should not create new one) + TenantService.create_owner_tenant_if_not_exist(account, name=new_workspace_name) + + # Verify no new tenant was created + tenant_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).all() + assert len(tenant_joins) == 1 + assert account.current_tenant.id == existing_tenant.id + + def test_create_owner_tenant_if_not_exist_workspace_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating owner tenant when workspace creation is disabled. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + workspace_name = fake.company() + # Setup mocks to disable workspace creation + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Try to create owner tenant (should fail) + with pytest.raises(WorkSpaceNotAllowedCreateError): # WorkSpaceNotAllowedCreateError exception + TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) + + def test_get_tenant_members_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting tenant members successfully. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + admin_email = fake.email() + admin_name = fake.name() + admin_password = fake.password(length=12) + normal_email = fake.email() + normal_name = fake.name() + normal_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + admin_account = AccountService.create_account( + email=admin_email, + name=admin_name, + interface_language="en-US", + password=admin_password, + ) + normal_account = AccountService.create_account( + email=normal_email, + name=normal_name, + interface_language="en-US", + password=normal_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, admin_account, role="admin") + TenantService.create_tenant_member(tenant, normal_account, role="normal") + + # Get tenant members + members = TenantService.get_tenant_members(tenant) + + assert len(members) == 3 + member_emails = [member.email for member in members] + assert owner_email in member_emails + assert admin_email in member_emails + assert normal_email in member_emails + + # Verify roles are set correctly + for member in members: + if member.email == owner_email: + assert member.role == "owner" + elif member.email == admin_email: + assert member.role == "admin" + elif member.email == normal_email: + assert member.role == "normal" + + def test_get_dataset_operator_members_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting dataset operator members successfully. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + operator_email = fake.email() + operator_name = fake.name() + operator_password = fake.password(length=12) + normal_email = fake.email() + normal_name = fake.name() + normal_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + dataset_operator_account = AccountService.create_account( + email=operator_email, + name=operator_name, + interface_language="en-US", + password=operator_password, + ) + normal_account = AccountService.create_account( + email=normal_email, + name=normal_name, + interface_language="en-US", + password=normal_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, dataset_operator_account, role="dataset_operator") + TenantService.create_tenant_member(tenant, normal_account, role="normal") + + # Get dataset operator members + dataset_operators = TenantService.get_dataset_operator_members(tenant) + + assert len(dataset_operators) == 1 + assert dataset_operators[0].email == operator_email + assert dataset_operators[0].role == "dataset_operator" + + def test_get_custom_config_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting custom config successfully. + """ + fake = Faker() + tenant_name = fake.company() + theme = fake.random_element(elements=("dark", "light")) + language = fake.random_element(elements=("zh-CN", "en-US")) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant with custom config + tenant = TenantService.create_tenant(name=tenant_name) + + # Set custom config + custom_config = {"theme": theme, "language": language, "feature_flags": {"beta": True}} + tenant.custom_config_dict = custom_config + from extensions.ext_database import db + + db.session.commit() + + # Get custom config + retrieved_config = TenantService.get_custom_config(tenant.id) + + assert retrieved_config == custom_config + assert retrieved_config["theme"] == theme + assert retrieved_config["language"] == language + assert retrieved_config["feature_flags"]["beta"] is True + + +class TestRegisterService: + """Integration tests for RegisterService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.account_service.BillingService") as mock_billing_service, + patch("services.account_service.PassportService") as mock_passport_service, + ): + # Setup default mock returns + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True + mock_feature_service.get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_billing_service.is_email_in_freeze.return_value = False + mock_passport_service.return_value.issue.return_value = "mock_jwt_token" + + yield { + "feature_service": mock_feature_service, + "billing_service": mock_billing_service, + "passport_service": mock_passport_service, + } + + def test_setup_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful system setup with account creation and tenant setup. + """ + fake = Faker() + admin_email = fake.email() + admin_name = fake.name() + admin_password = fake.password(length=12) + ip_address = fake.ipv4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute setup + RegisterService.setup( + email=admin_email, + name=admin_name, + password=admin_password, + ip_address=ip_address, + ) + + # Verify account was created + from extensions.ext_database import db + from models.account import Account + from models.model import DifySetup + + account = db.session.query(Account).filter_by(email=admin_email).first() + assert account is not None + assert account.name == admin_name + assert account.last_login_ip == ip_address + assert account.initialized_at is not None + assert account.status == "active" + + # Verify DifySetup was created + dify_setup = db.session.query(DifySetup).first() + assert dify_setup is not None + + # Verify tenant was created and linked + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + + def test_setup_failure_rollback(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test setup failure with proper rollback of all created entities. + """ + fake = Faker() + admin_email = fake.email() + admin_name = fake.name() + admin_password = fake.password(length=12) + ip_address = fake.ipv4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Mock AccountService.create_account to raise exception + with patch("services.account_service.AccountService.create_account") as mock_create_account: + mock_create_account.side_effect = Exception("Database error") + + # Execute setup and verify exception + with pytest.raises(ValueError, match="Setup failed: Database error"): + RegisterService.setup( + email=admin_email, + name=admin_name, + password=admin_password, + ip_address=ip_address, + ) + + # Verify no entities were created (rollback worked) + from extensions.ext_database import db + from models.account import Account, Tenant, TenantAccountJoin + from models.model import DifySetup + + account = db.session.query(Account).filter_by(email=admin_email).first() + tenant_count = db.session.query(Tenant).count() + tenant_join_count = db.session.query(TenantAccountJoin).count() + dify_setup_count = db.session.query(DifySetup).count() + + assert account is None + assert tenant_count == 0 + assert tenant_join_count == 0 + assert dify_setup_count == 0 + + def test_register_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful account registration with workspace creation. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute registration + account = RegisterService.register( + email=email, + name=name, + password=password, + language=language, + ) + + # Verify account was created + assert account.email == email + assert account.name == name + assert account.status == "active" + assert account.initialized_at is not None + + # Verify tenant was created and linked + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + assert account.current_tenant is not None + assert account.current_tenant.name == f"{name}'s Workspace" + + def test_register_with_oauth(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account registration with OAuth integration. + """ + fake = Faker() + email = fake.email() + name = fake.name() + open_id = fake.uuid4() + provider = fake.random_element(elements=("google", "github", "microsoft")) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute registration with OAuth + account = RegisterService.register( + email=email, + name=name, + password=None, + open_id=open_id, + provider=provider, + language=language, + ) + + # Verify account was created + assert account.email == email + assert account.name == name + assert account.status == "active" + assert account.initialized_at is not None + + # Verify OAuth integration was created + from extensions.ext_database import db + from models.account import AccountIntegrate + + integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() + assert integration is not None + assert integration.open_id == open_id + + def test_register_with_pending_status(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account registration with pending status. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute registration with pending status + from models.account import AccountStatus + + account = RegisterService.register( + email=email, + name=name, + password=password, + language=language, + status=AccountStatus.PENDING, + ) + + # Verify account was created with pending status + assert account.email == email + assert account.name == name + assert account.status == "pending" + assert account.initialized_at is not None + + # Verify tenant was created and linked + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + + def test_register_workspace_creation_disabled(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account registration when workspace creation is disabled. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # with pytest.raises(AccountRegisterError, match="Workspace is not allowed to create."): + account = RegisterService.register( + email=email, + name=name, + password=password, + language=language, + ) + + # Verify account was created with no tenant + assert account.email == email + assert account.name == name + assert account.status == "active" + assert account.initialized_at is not None + + # Verify tenant was created and linked + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is None + + def test_register_workspace_limit_exceeded(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account registration when workspace limit is exceeded. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = False + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # with pytest.raises(AccountRegisterError, match="Workspace is not allowed to create."): + account = RegisterService.register( + email=email, + name=name, + password=password, + language=language, + ) + + # Verify account was created with no tenant + assert account.email == email + assert account.name == name + assert account.status == "active" + assert account.initialized_at is not None + + # Verify tenant was created and linked + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is None + + def test_register_without_workspace(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account registration without workspace creation. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute registration without workspace creation + account = RegisterService.register( + email=email, + name=name, + password=password, + language=language, + create_workspace_required=False, + ) + + # Verify account was created + assert account.email == email + assert account.name == name + assert account.status == "active" + assert account.initialized_at is not None + + # Verify no tenant was created + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is None + + def test_invite_new_member_new_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test inviting a new member who doesn't have an account yet. + """ + fake = Faker() + tenant_name = fake.company() + inviter_email = fake.email() + inviter_name = fake.name() + inviter_password = fake.password(length=12) + new_member_email = fake.email() + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and inviter account + tenant = TenantService.create_tenant(name=tenant_name) + inviter = AccountService.create_account( + email=inviter_email, + name=inviter_name, + interface_language="en-US", + password=inviter_password, + ) + TenantService.create_tenant_member(tenant, inviter, role="owner") + + # Mock the email task + with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail: + mock_send_mail.delay.return_value = None + + # Execute invitation + token = RegisterService.invite_new_member( + tenant=tenant, + email=new_member_email, + language=language, + role="normal", + inviter=inviter, + ) + + # Verify token was generated + assert token is not None + assert len(token) > 0 + + # Verify email task was called + mock_send_mail.delay.assert_called_once() + + # Verify new account was created with pending status + from extensions.ext_database import db + from models.account import Account, TenantAccountJoin + + new_account = db.session.query(Account).filter_by(email=new_member_email).first() + assert new_account is not None + assert new_account.name == new_member_email.split("@")[0] # Default name from email + assert new_account.status == "pending" + + # Verify tenant member was created + tenant_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=new_account.id).first() + ) + assert tenant_join is not None + assert tenant_join.role == "normal" + + def test_invite_new_member_existing_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test inviting an existing member who is not in the tenant yet. + """ + fake = Faker() + tenant_name = fake.company() + inviter_email = fake.email() + inviter_name = fake.name() + inviter_password = fake.password(length=12) + existing_member_email = fake.email() + existing_member_name = fake.name() + existing_member_password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and inviter account + tenant = TenantService.create_tenant(name=tenant_name) + inviter = AccountService.create_account( + email=inviter_email, + name=inviter_name, + interface_language="en-US", + password=inviter_password, + ) + TenantService.create_tenant_member(tenant, inviter, role="owner") + + # Create existing account + existing_account = AccountService.create_account( + email=existing_member_email, + name=existing_member_name, + interface_language="en-US", + password=existing_member_password, + ) + + # Mock the email task + with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail: + mock_send_mail.delay.return_value = None + with pytest.raises(AccountAlreadyInTenantError, match="Account already in tenant."): + # Execute invitation + token = RegisterService.invite_new_member( + tenant=tenant, + email=existing_member_email, + language=language, + role="admin", + inviter=inviter, + ) + + # Verify email task was not called + mock_send_mail.delay.assert_not_called() + + # Verify tenant member was created for existing account + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=existing_account.id).first() + ) + assert tenant_join is not None + assert tenant_join.role == "admin" + + def test_invite_new_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test inviting a member who is already in the tenant with pending status. + """ + fake = Faker() + tenant_name = fake.company() + inviter_email = fake.email() + inviter_name = fake.name() + inviter_password = fake.password(length=12) + existing_pending_member_email = fake.email() + existing_pending_member_name = fake.name() + existing_pending_member_password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and inviter account + tenant = TenantService.create_tenant(name=tenant_name) + inviter = AccountService.create_account( + email=inviter_email, + name=inviter_name, + interface_language="en-US", + password=inviter_password, + ) + TenantService.create_tenant_member(tenant, inviter, role="owner") + + # Create existing account with pending status + existing_account = AccountService.create_account( + email=existing_pending_member_email, + name=existing_pending_member_name, + interface_language="en-US", + password=existing_pending_member_password, + ) + existing_account.status = "pending" + from extensions.ext_database import db + + db.session.commit() + + # Add existing account to tenant + TenantService.create_tenant_member(tenant, existing_account, role="normal") + + # Mock the email task + with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail: + mock_send_mail.delay.return_value = None + + # Execute invitation (should resend email for pending member) + token = RegisterService.invite_new_member( + tenant=tenant, + email=existing_pending_member_email, + language=language, + role="normal", + inviter=inviter, + ) + + # Verify token was generated + assert token is not None + assert len(token) > 0 + + # Verify email task was called + mock_send_mail.delay.assert_called_once() + + def test_invite_new_member_no_inviter(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test inviting a member without providing an inviter. + """ + fake = Faker() + tenant_name = fake.company() + new_member_email = fake.email() + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant + tenant = TenantService.create_tenant(name=tenant_name) + + # Execute invitation without inviter (should fail) + with pytest.raises(ValueError, match="Inviter is required"): + RegisterService.invite_new_member( + tenant=tenant, + email=new_member_email, + language=language, + role="normal", + inviter=None, + ) + + def test_invite_new_member_account_already_in_tenant( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test inviting a member who is already in the tenant with active status. + """ + fake = Faker() + tenant_name = fake.company() + inviter_email = fake.email() + inviter_name = fake.name() + inviter_password = fake.password(length=12) + already_in_tenant_email = fake.email() + already_in_tenant_name = fake.name() + already_in_tenant_password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and inviter account + tenant = TenantService.create_tenant(name=tenant_name) + inviter = AccountService.create_account( + email=inviter_email, + name=inviter_name, + interface_language="en-US", + password=inviter_password, + ) + TenantService.create_tenant_member(tenant, inviter, role="owner") + + # Create existing account with active status + existing_account = AccountService.create_account( + email=already_in_tenant_email, + name=already_in_tenant_name, + interface_language="en-US", + password=already_in_tenant_password, + ) + existing_account.status = "active" + from extensions.ext_database import db + + db.session.commit() + + # Add existing account to tenant + TenantService.create_tenant_member(tenant, existing_account, role="normal") + + # Execute invitation (should fail for active member) + with pytest.raises(AccountAlreadyInTenantError, match="Account already in tenant."): + RegisterService.invite_new_member( + tenant=tenant, + email=already_in_tenant_email, + language=language, + role="normal", + inviter=inviter, + ) + + def test_generate_invite_token_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful generation of invite token. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Execute token generation + token = RegisterService.generate_invite_token(tenant, account) + + # Verify token was generated + assert token is not None + assert len(token) > 0 + + # Verify token was stored in Redis + from extensions.ext_redis import redis_client + + token_key = RegisterService._get_invitation_token_key(token) + stored_data = redis_client.get(token_key) + assert stored_data is not None + + # Verify stored data contains correct information + import json + + invitation_data = json.loads(stored_data.decode("utf-8")) + assert invitation_data["account_id"] == str(account.id) + assert invitation_data["email"] == account.email + assert invitation_data["workspace_id"] == tenant.id + + def test_is_valid_invite_token_valid(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation of valid invite token. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate a real token + token = RegisterService.generate_invite_token(tenant, account) + + # Execute validation + is_valid = RegisterService.is_valid_invite_token(token) + + # Verify token is valid + assert is_valid is True + + def test_is_valid_invite_token_invalid(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation of invalid invite token. + """ + fake = Faker() + invalid_token = fake.uuid4() + # Execute validation with non-existent token + is_valid = RegisterService.is_valid_invite_token(invalid_token) + + # Verify token is invalid + assert is_valid is False + + def test_revoke_token_with_workspace_and_email( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test revoking token with workspace ID and email. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate a real token + token = RegisterService.generate_invite_token(tenant, account) + + # Verify token exists in Redis before revocation + from extensions.ext_redis import redis_client + + token_key = RegisterService._get_invitation_token_key(token) + assert redis_client.get(token_key) is not None + + # Execute token revocation + RegisterService.revoke_token( + workspace_id=tenant.id, + email=account.email, + token=token, + ) + + # Verify token was not deleted from Redis + assert redis_client.get(token_key) is not None + + def test_revoke_token_without_workspace_and_email( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test revoking token without workspace ID and email. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate a real token + token = RegisterService.generate_invite_token(tenant, account) + + # Verify token exists in Redis before revocation + from extensions.ext_redis import redis_client + + token_key = RegisterService._get_invitation_token_key(token) + assert redis_client.get(token_key) is not None + + # Execute token revocation without workspace and email + RegisterService.revoke_token( + workspace_id="", + email="", + token=token, + ) + + # Verify token was deleted from Redis + assert redis_client.get(token_key) is None + + def test_get_invitation_if_token_valid_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with valid token. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + TenantService.create_tenant_member(tenant, account, role="normal") + + # Generate a real token + token = RegisterService.generate_invite_token(tenant, account) + + email_hash = sha256(account.email.encode()).hexdigest() + cache_key = f"member_invite_token:{tenant.id}, {email_hash}:{token}" + from extensions.ext_redis import redis_client + + redis_client.setex(cache_key, 24 * 60 * 60, account.id) + + # Execute invitation retrieval + result = RegisterService.get_invitation_if_token_valid( + workspace_id=tenant.id, + email=account.email, + token=token, + ) + + # Verify result contains expected data + assert result is not None + assert result["account"].id == account.id + assert result["tenant"].id == tenant.id + assert result["data"]["account_id"] == str(account.id) + assert result["data"]["email"] == account.email + assert result["data"]["workspace_id"] == tenant.id + + def test_get_invitation_if_token_valid_invalid_token( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with invalid token. + """ + fake = Faker() + workspace_id = fake.uuid4() + email = fake.email() + invalid_token = fake.uuid4() + # Execute invitation retrieval with invalid token + result = RegisterService.get_invitation_if_token_valid( + workspace_id=workspace_id, + email=email, + token=invalid_token, + ) + + # Verify result is None + assert result is None + + def test_get_invitation_if_token_valid_invalid_tenant( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with invalid tenant. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + invalid_tenant_id = fake.uuid4() + token = fake.uuid4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Create a real token but with non-existent tenant ID + from extensions.ext_redis import redis_client + + invitation_data = { + "account_id": str(account.id), + "email": account.email, + "workspace_id": invalid_tenant_id, + } + token_key = RegisterService._get_invitation_token_key(token) + import json + + redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) + + # Execute invitation retrieval + result = RegisterService.get_invitation_if_token_valid( + workspace_id=invalid_tenant_id, + email=account.email, + token=token, + ) + + # Verify result is None (tenant not found) + assert result is None + + # Clean up + redis_client.delete(token_key) + + def test_get_invitation_if_token_valid_account_mismatch( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with account ID mismatch. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + token = fake.uuid4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + TenantService.create_tenant_member(tenant, account, role="normal") + + # Create a real token but with mismatched account ID + from extensions.ext_redis import redis_client + + invitation_data = { + "account_id": "different-account-id", # Different from actual account ID + "email": account.email, + "workspace_id": tenant.id, + } + token_key = RegisterService._get_invitation_token_key(token) + redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) + + # Execute invitation retrieval + result = RegisterService.get_invitation_if_token_valid( + workspace_id=tenant.id, + email=account.email, + token=token, + ) + + # Verify result is None (account ID mismatch) + assert result is None + + # Clean up + redis_client.delete(token_key) + + def test_get_invitation_if_token_valid_tenant_not_normal( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with tenant not in normal status. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + token = fake.uuid4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + TenantService.create_tenant_member(tenant, account, role="normal") + + # Change tenant status to non-normal + tenant.status = "suspended" + from extensions.ext_database import db + + db.session.commit() + + # Create a real token + from extensions.ext_redis import redis_client + + invitation_data = { + "account_id": str(account.id), + "email": account.email, + "workspace_id": tenant.id, + } + token_key = RegisterService._get_invitation_token_key(token) + import json + + redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) + + # Execute invitation retrieval + result = RegisterService.get_invitation_if_token_valid( + workspace_id=tenant.id, + email=account.email, + token=token, + ) + + # Verify result is None (tenant not in normal status) + assert result is None + + # Clean up + redis_client.delete(token_key) + + def test_get_invitation_by_token_with_workspace_and_email( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation by token with workspace ID and email. + """ + fake = Faker() + token = fake.uuid4() + workspace_id = fake.uuid4() + email = fake.email() + + # Create the cache key as the service does + from hashlib import sha256 + + from extensions.ext_redis import redis_client + + email_hash = sha256(email.encode()).hexdigest() + cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" + + # Store account ID in Redis + account_id = fake.uuid4() + redis_client.setex(cache_key, 24 * 60 * 60, account_id) + + # Execute invitation retrieval + result = RegisterService._get_invitation_by_token( + token=token, + workspace_id=workspace_id, + email=email, + ) + + # Verify result contains expected data + assert result is not None + assert result["account_id"] == account_id + assert result["email"] == email + assert result["workspace_id"] == workspace_id + + # Clean up + redis_client.delete(cache_key) + + def test_get_invitation_by_token_without_workspace_and_email( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation by token without workspace ID and email. + """ + fake = Faker() + token = fake.uuid4() + invitation_data = { + "account_id": fake.uuid4(), + "email": fake.email(), + "workspace_id": fake.uuid4(), + } + + # Store invitation data in Redis using standard token key + from extensions.ext_redis import redis_client + + token_key = RegisterService._get_invitation_token_key(token) + import json + + redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) + + # Execute invitation retrieval + result = RegisterService._get_invitation_by_token(token=token) + + # Verify result contains expected data + assert result is not None + assert result["account_id"] == invitation_data["account_id"] + assert result["email"] == invitation_data["email"] + assert result["workspace_id"] == invitation_data["workspace_id"] + + # Clean up + redis_client.delete(token_key) + + def test_get_invitation_token_key(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting invitation token key. + """ + fake = Faker() + token = fake.uuid4() + # Execute token key generation + token_key = RegisterService._get_invitation_token_key(token) + + # Verify token key format + assert token_key == f"member_invite:token:{token}" diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py new file mode 100644 index 0000000000..85a9355c79 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -0,0 +1,739 @@ +import pytest +from faker import Faker + +from core.variables.segments import StringSegment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from models import App, Workflow +from models.enums import DraftVariableType +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import ( + UpdateNotSupportedError, + WorkflowDraftVariableService, +) + + +class TestWorkflowDraftVariableService: + """ + Comprehensive integration tests for WorkflowDraftVariableService using testcontainers. + + This test class covers all major functionality of the WorkflowDraftVariableService: + - CRUD operations for workflow draft variables (Create, Read, Update, Delete) + - Variable listing and filtering by type (conversation, system, node) + - Variable updates and resets with proper validation + - Variable deletion operations at different scopes + - Special functionality like prefill and conversation ID retrieval + - Error handling for various edge cases and invalid operations + + All tests use the testcontainers infrastructure to ensure proper database isolation + and realistic testing environment with actual database interactions. + """ + + @pytest.fixture + def mock_external_service_dependencies(self): + """ + Mock setup for external service dependencies. + + WorkflowDraftVariableService doesn't have external dependencies that need mocking, + so this fixture returns an empty dictionary to maintain consistency with other test classes. + This ensures the test structure remains consistent across different service test files. + """ + # WorkflowDraftVariableService doesn't have external dependencies that need mocking + return {} + + def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None): + """ + Helper method to create a test app with realistic data for testing. + + This method creates a complete App instance with all required fields populated + using Faker for generating realistic test data. The app is configured for + workflow mode to support workflow draft variable testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies (unused in this service) + fake: Faker instance for generating test data, creates new instance if not provided + + Returns: + App: Created test app instance with all required fields populated + """ + fake = fake or Faker() + app = App() + app.id = fake.uuid4() + app.tenant_id = fake.uuid4() + app.name = fake.company() + app.description = fake.text() + app.mode = "workflow" + app.icon_type = "emoji" + app.icon = "🤖" + app.icon_background = "#FFEAD5" + app.enable_site = True + app.enable_api = True + app.created_by = fake.uuid4() + app.updated_by = app.created_by + + from extensions.ext_database import db + + db.session.add(app) + db.session.commit() + return app + + def _create_test_workflow(self, db_session_with_containers, app, fake=None): + """ + Helper method to create a test workflow associated with an app. + + This method creates a Workflow instance using the proper factory method + to ensure all required fields are set correctly. The workflow is configured + as a draft version with basic graph structure for testing workflow variables. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app: The app to associate the workflow with + fake: Faker instance for generating test data, creates new instance if not provided + + Returns: + Workflow: Created test workflow instance with proper configuration + """ + fake = fake or Faker() + workflow = Workflow.new( + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + version="draft", + graph='{"nodes": [], "edges": []}', + features="{}", + created_by=app.created_by, + environment_variables=[], + conversation_variables=[], + ) + from extensions.ext_database import db + + db.session.add(workflow) + db.session.commit() + return workflow + + def _create_test_variable( + self, db_session_with_containers, app_id, node_id, name, value, variable_type="conversation", fake=None + ): + """ + Helper method to create a test workflow draft variable with proper configuration. + + This method creates different types of variables (conversation, system, node) using + the appropriate factory methods to ensure proper initialization. Each variable type + has specific requirements and this method handles the creation logic for all types. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app_id: ID of the app to associate the variable with + node_id: ID of the node (or special constants like CONVERSATION_VARIABLE_NODE_ID) + name: Name of the variable for identification + value: StringSegment value for the variable content + variable_type: Type of variable ("conversation", "system", "node") determining creation method + fake: Faker instance for generating test data, creates new instance if not provided + + Returns: + WorkflowDraftVariable: Created test variable instance with proper type configuration + """ + fake = fake or Faker() + if variable_type == "conversation": + # Create conversation variable using the appropriate factory method + variable = WorkflowDraftVariable.new_conversation_variable( + app_id=app_id, + name=name, + value=value, + description=fake.text(max_nb_chars=20), + ) + elif variable_type == "system": + # Create system variable with editable flag and execution context + variable = WorkflowDraftVariable.new_sys_variable( + app_id=app_id, + name=name, + value=value, + node_execution_id=fake.uuid4(), + editable=True, + ) + else: # node variable + # Create node variable with visibility and editability settings + variable = WorkflowDraftVariable.new_node_variable( + app_id=app_id, + node_id=node_id, + name=name, + value=value, + node_execution_id=fake.uuid4(), + visible=True, + editable=True, + ) + from extensions.ext_database import db + + db.session.add(variable) + db.session.commit() + return variable + + def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting a single variable by ID successfully. + + This test verifies that the service can retrieve a specific variable + by its ID and that the returned variable contains the correct data. + It ensures the basic CRUD read operation works correctly for workflow draft variables. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + test_value = StringSegment(value=fake.word()) + variable = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_variable = service.get_variable(variable.id) + assert retrieved_variable is not None + assert retrieved_variable.id == variable.id + assert retrieved_variable.name == "test_var" + assert retrieved_variable.app_id == app.id + assert retrieved_variable.get_value().value == test_value.value + + def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting a variable that doesn't exist. + + This test verifies that the service returns None when trying to + retrieve a variable with a non-existent ID. This ensures proper + handling of missing data scenarios. + """ + fake = Faker() + non_existent_id = fake.uuid4() + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_variable = service.get_variable(non_existent_id) + assert retrieved_variable is None + + def test_get_draft_variables_by_selectors_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting variables by selectors successfully. + + This test verifies that the service can retrieve multiple variables + using selector pairs (node_id, variable_name) and returns the correct + variables for each selector. This is useful for bulk variable retrieval + operations in workflow execution contexts. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + var1_value = StringSegment(value=fake.word()) + var2_value = StringSegment(value=fake.word()) + var3_value = StringSegment(value=fake.word()) + var1 = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var1", var1_value, fake=fake + ) + var2 = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var2", var2_value, fake=fake + ) + var3 = self._create_test_variable( + db_session_with_containers, app.id, "test_node_1", "var3", var3_value, "node", fake=fake + ) + selectors = [ + [CONVERSATION_VARIABLE_NODE_ID, "var1"], + [CONVERSATION_VARIABLE_NODE_ID, "var2"], + ["test_node_1", "var3"], + ] + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors) + assert len(retrieved_variables) == 3 + var_names = [var.name for var in retrieved_variables] + assert "var1" in var_names + assert "var2" in var_names + assert "var3" in var_names + for var in retrieved_variables: + if var.name == "var1": + assert var.get_value().value == var1_value.value + elif var.name == "var2": + assert var.get_value().value == var2_value.value + elif var.name == "var3": + assert var.get_value().value == var3_value.value + + def test_list_variables_without_values_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test listing variables without values successfully with pagination. + + This test verifies that the service can list variables with pagination + and that the returned variables don't include their values (for performance). + This is important for scenarios where only variable metadata is needed + without loading the actual content. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + for i in range(5): + test_value = StringSegment(value=fake.numerify("value##")) + self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + result = service.list_variables_without_values(app.id, page=1, limit=3) + assert result.total == 5 + assert len(result.variables) == 3 + assert result.variables[0].created_at >= result.variables[1].created_at + assert result.variables[1].created_at >= result.variables[2].created_at + for var in result.variables: + assert var.name is not None + assert var.app_id == app.id + + def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test listing variables for a specific node successfully. + + This test verifies that the service can filter and return only + variables associated with a specific node ID. This is crucial for + workflow execution where variables need to be scoped to specific nodes. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + node_id = fake.word() + var1_value = StringSegment(value=fake.word()) + var2_value = StringSegment(value=fake.word()) + var3_value = StringSegment(value=fake.word()) + self._create_test_variable(db_session_with_containers, app.id, node_id, "var1", var1_value, "node", fake=fake) + self._create_test_variable(db_session_with_containers, app.id, node_id, "var2", var3_value, "node", fake=fake) + self._create_test_variable( + db_session_with_containers, app.id, "other_node", "var3", var2_value, "node", fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + result = service.list_node_variables(app.id, node_id) + assert len(result.variables) == 2 + for var in result.variables: + assert var.node_id == node_id + assert var.app_id == app.id + var_names = [var.name for var in result.variables] + assert "var1" in var_names + assert "var2" in var_names + assert "var3" not in var_names + + def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test listing conversation variables successfully. + + This test verifies that the service can filter and return only + conversation variables, excluding system and node variables. + Conversation variables are user-facing variables that can be + modified during conversation flows. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + conv_var1_value = StringSegment(value=fake.word()) + conv_var2_value = StringSegment(value=fake.word()) + conv_var1 = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var1", conv_var1_value, fake=fake + ) + conv_var2 = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var2", conv_var2_value, fake=fake + ) + sys_var_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var", sys_var_value, "system", fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + result = service.list_conversation_variables(app.id) + assert len(result.variables) == 2 + for var in result.variables: + assert var.node_id == CONVERSATION_VARIABLE_NODE_ID + assert var.app_id == app.id + assert var.get_variable_type() == DraftVariableType.CONVERSATION + var_names = [var.name for var in result.variables] + assert "conv_var1" in var_names + assert "conv_var2" in var_names + assert "sys_var" not in var_names + + def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating a variable's name and value successfully. + + This test verifies that the service can update both the name and value + of an editable variable and that the changes are persisted correctly. + It also checks that the last_edited_at timestamp is updated appropriately. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + original_value = StringSegment(value=fake.word()) + new_value = StringSegment(value=fake.word()) + variable = self._create_test_variable( + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + "original_name", + original_value, + fake=fake, + ) + service = WorkflowDraftVariableService(db_session_with_containers) + updated_variable = service.update_variable(variable, name="new_name", value=new_value) + assert updated_variable.name == "new_name" + assert updated_variable.get_value().value == new_value.value + assert updated_variable.last_edited_at is not None + from extensions.ext_database import db + + db.session.refresh(variable) + assert variable.name == "new_name" + assert variable.get_value().value == new_value.value + assert variable.last_edited_at is not None + + def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test that updating a non-editable variable raises an exception. + + This test verifies that the service properly prevents updates to + variables that are not marked as editable. This is important for + maintaining data integrity and preventing unauthorized modifications + to system-controlled variables. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + original_value = StringSegment(value=fake.word()) + new_value = StringSegment(value=fake.word()) + variable = WorkflowDraftVariable.new_sys_variable( + app_id=app.id, + name=fake.word(), # This is typically not editable + value=original_value, + node_execution_id=fake.uuid4(), + editable=False, # Set as non-editable + ) + from extensions.ext_database import db + + db.session.add(variable) + db.session.commit() + service = WorkflowDraftVariableService(db_session_with_containers) + with pytest.raises(UpdateNotSupportedError) as exc_info: + service.update_variable(variable, name="new_name", value=new_value) + assert "variable not support updating" in str(exc_info.value) + assert variable.id in str(exc_info.value) + + def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test resetting conversation variable successfully. + + This test verifies that the service can reset a conversation variable + to its default value and clear the last_edited_at timestamp. + This functionality is useful for reverting user modifications + back to the original workflow configuration. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) + from core.variables.variables import StringVariable + + conv_var = StringVariable( + id=fake.uuid4(), + name="test_conv_var", + value="default_value", + selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"], + ) + workflow.conversation_variables = [conv_var] + from extensions.ext_database import db + + db.session.commit() + modified_value = StringSegment(value=fake.word()) + variable = self._create_test_variable( + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + "test_conv_var", + modified_value, + fake=fake, + ) + variable.last_edited_at = fake.date_time() + db.session.commit() + service = WorkflowDraftVariableService(db_session_with_containers) + reset_variable = service.reset_variable(workflow, variable) + assert reset_variable is not None + assert reset_variable.get_value().value == "default_value" + assert reset_variable.last_edited_at is None + db.session.refresh(variable) + assert variable.get_value().value == "default_value" + assert variable.last_edited_at is None + + def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deleting a single variable successfully. + + This test verifies that the service can delete a specific variable + and that it's properly removed from the database. It ensures that + the deletion operation is atomic and complete. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + test_value = StringSegment(value=fake.word()) + variable = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake + ) + from extensions.ext_database import db + + assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None + service = WorkflowDraftVariableService(db_session_with_containers) + service.delete_variable(variable) + assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None + + def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deleting all variables for a workflow successfully. + + This test verifies that the service can delete all variables + associated with a specific app/workflow. This is useful for + cleanup operations when workflows are deleted or reset. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + for i in range(3): + test_value = StringSegment(value=fake.numerify("value##")) + self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake + ) + other_app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + other_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, other_app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), other_value, fake=fake + ) + from extensions.ext_database import db + + app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() + other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + assert len(app_variables) == 3 + assert len(other_app_variables) == 1 + service = WorkflowDraftVariableService(db_session_with_containers) + service.delete_workflow_variables(app.id) + app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() + other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + assert len(app_variables_after) == 0 + assert len(other_app_variables_after) == 1 + + def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deleting all variables for a specific node successfully. + + This test verifies that the service can delete all variables + associated with a specific node while preserving variables + for other nodes and conversation variables. This is important + for node-specific cleanup operations in workflow management. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + node_id = fake.word() + for i in range(2): + test_value = StringSegment(value=fake.numerify("node_value##")) + self._create_test_variable( + db_session_with_containers, app.id, node_id, fake.word(), test_value, "node", fake=fake + ) + other_node_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, app.id, "other_node", fake.word(), other_node_value, "node", fake=fake + ) + conv_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), conv_value, fake=fake + ) + from extensions.ext_database import db + + target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + other_node_variables = ( + db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() + ) + conv_variables = ( + db.session.query(WorkflowDraftVariable) + .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) + .all() + ) + assert len(target_node_variables) == 2 + assert len(other_node_variables) == 1 + assert len(conv_variables) == 1 + service = WorkflowDraftVariableService(db_session_with_containers) + service.delete_node_variables(app.id, node_id) + target_node_variables_after = ( + db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + ) + other_node_variables_after = ( + db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() + ) + conv_variables_after = ( + db.session.query(WorkflowDraftVariable) + .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) + .all() + ) + assert len(target_node_variables_after) == 0 + assert len(other_node_variables_after) == 1 + assert len(conv_variables_after) == 1 + + def test_prefill_conversation_variable_default_values_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test prefill conversation variable default values successfully. + + This test verifies that the service can automatically create + conversation variables with default values based on the workflow + configuration when none exist. This is important for initializing + workflow variables with proper defaults from the workflow definition. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) + from core.variables.variables import StringVariable + + conv_var1 = StringVariable( + id=fake.uuid4(), + name="conv_var1", + value="default_value1", + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var1"], + ) + conv_var2 = StringVariable( + id=fake.uuid4(), + name="conv_var2", + value="default_value2", + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"], + ) + workflow.conversation_variables = [conv_var1, conv_var2] + from extensions.ext_database import db + + db.session.commit() + service = WorkflowDraftVariableService(db_session_with_containers) + service.prefill_conversation_variable_default_values(workflow) + draft_variables = ( + db.session.query(WorkflowDraftVariable) + .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) + .all() + ) + assert len(draft_variables) == 2 + var_names = [var.name for var in draft_variables] + assert "conv_var1" in var_names + assert "conv_var2" in var_names + for var in draft_variables: + assert var.app_id == app.id + assert var.node_id == CONVERSATION_VARIABLE_NODE_ID + assert var.editable is True + assert var.get_variable_type() == DraftVariableType.CONVERSATION + + def test_get_conversation_id_from_draft_variable_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting conversation ID from draft variable successfully. + + This test verifies that the service can extract the conversation ID + from a system variable named "conversation_id". This is important + for maintaining conversation context across workflow executions. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + conversation_id = fake.uuid4() + conv_id_value = StringSegment(value=conversation_id) + self._create_test_variable( + db_session_with_containers, + app.id, + SYSTEM_VARIABLE_NODE_ID, + "conversation_id", + conv_id_value, + "system", + fake=fake, + ) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id) + assert retrieved_conv_id == conversation_id + + def test_get_conversation_id_from_draft_variable_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting conversation ID when it doesn't exist. + + This test verifies that the service returns None when no + conversation_id variable exists for the app. This ensures + proper handling of missing conversation context scenarios. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id) + assert retrieved_conv_id is None + + def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test listing system variables successfully. + + This test verifies that the service can filter and return only + system variables, excluding conversation and node variables. + System variables are internal variables used by the workflow + engine for maintaining state and context. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + sys_var1_value = StringSegment(value=fake.word()) + sys_var2_value = StringSegment(value=fake.word()) + sys_var1 = self._create_test_variable( + db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var1", sys_var1_value, "system", fake=fake + ) + sys_var2 = self._create_test_variable( + db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var2", sys_var2_value, "system", fake=fake + ) + conv_var_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var", conv_var_value, fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + result = service.list_system_variables(app.id) + assert len(result.variables) == 2 + for var in result.variables: + assert var.node_id == SYSTEM_VARIABLE_NODE_ID + assert var.app_id == app.id + assert var.get_variable_type() == DraftVariableType.SYS + var_names = [var.name for var in result.variables] + assert "sys_var1" in var_names + assert "sys_var2" in var_names + assert "conv_var" not in var_names + + def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting variables by name successfully for different types. + + This test verifies that the service can retrieve variables by name + for different variable types (conversation, system, node). This + functionality is important for variable lookup operations during + workflow execution and user interactions. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + test_value = StringSegment(value=fake.word()) + conv_var = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_conv_var", test_value, fake=fake + ) + sys_var = self._create_test_variable( + db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "test_sys_var", test_value, "system", fake=fake + ) + node_var = self._create_test_variable( + db_session_with_containers, app.id, "test_node", "test_node_var", test_value, "node", fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var") + assert retrieved_conv_var is not None + assert retrieved_conv_var.name == "test_conv_var" + assert retrieved_conv_var.node_id == CONVERSATION_VARIABLE_NODE_ID + retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var") + assert retrieved_sys_var is not None + assert retrieved_sys_var.name == "test_sys_var" + assert retrieved_sys_var.node_id == SYSTEM_VARIABLE_NODE_ID + retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var") + assert retrieved_node_var is not None + assert retrieved_node_var.name == "test_node_var" + assert retrieved_node_var.node_id == "test_node" + + def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting variables by name when they don't exist. + + This test verifies that the service returns None when trying to + retrieve variables by name that don't exist. This ensures proper + handling of missing variable scenarios for all variable types. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var") + assert retrieved_conv_var is None + retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var") + assert retrieved_sys_var is None + retrieved_node_var = service.get_node_variable(app.id, "test_node", "non_existent_node_var") + assert retrieved_node_var is None diff --git a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py new file mode 100644 index 0000000000..dd2bc21814 --- /dev/null +++ b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py @@ -0,0 +1,168 @@ +import datetime +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.orm import Session + +from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs + + +class TestClearFreePlanTenantExpiredLogs: + """Unit tests for ClearFreePlanTenantExpiredLogs._clear_message_related_tables method.""" + + @pytest.fixture + def mock_session(self): + """Create a mock database session.""" + session = Mock(spec=Session) + session.query.return_value.filter.return_value.all.return_value = [] + session.query.return_value.filter.return_value.delete.return_value = 0 + return session + + @pytest.fixture + def mock_storage(self): + """Create a mock storage object.""" + storage = Mock() + storage.save.return_value = None + return storage + + @pytest.fixture + def sample_message_ids(self): + """Sample message IDs for testing.""" + return ["msg-1", "msg-2", "msg-3"] + + @pytest.fixture + def sample_records(self): + """Sample records for testing.""" + records = [] + for i in range(3): + record = Mock() + record.id = f"record-{i}" + record.to_dict.return_value = { + "id": f"record-{i}", + "message_id": f"msg-{i}", + "created_at": datetime.datetime.now().isoformat(), + } + records.append(record) + return records + + def test_clear_message_related_tables_empty_message_ids(self, mock_session): + """Test that method returns early when message_ids is empty.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", []) + + # Should not call any database operations + mock_session.query.assert_not_called() + mock_storage.save.assert_not_called() + + def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids): + """Test when no related records are found.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_session.query.return_value.filter.return_value.all.return_value = [] + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should call query for each related table but find no records + assert mock_session.query.call_count > 0 + mock_storage.save.assert_not_called() + + def test_clear_message_related_tables_with_records_and_to_dict( + self, mock_session, sample_message_ids, sample_records + ): + """Test when records are found and have to_dict method.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_session.query.return_value.filter.return_value.all.return_value = sample_records + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should call to_dict on each record (called once per table, so 7 times total) + for record in sample_records: + assert record.to_dict.call_count == 7 + + # Should save backup data + assert mock_storage.save.call_count > 0 + + def test_clear_message_related_tables_with_records_no_to_dict(self, mock_session, sample_message_ids): + """Test when records are found but don't have to_dict method.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + # Create records without to_dict method + records = [] + for i in range(2): + record = Mock() + mock_table = Mock() + mock_id_column = Mock() + mock_id_column.name = "id" + mock_message_id_column = Mock() + mock_message_id_column.name = "message_id" + mock_table.columns = [mock_id_column, mock_message_id_column] + record.__table__ = mock_table + record.id = f"record-{i}" + record.message_id = f"msg-{i}" + del record.to_dict + records.append(record) + + # Mock records for first table only, empty for others + mock_session.query.return_value.filter.return_value.all.side_effect = [ + records, + [], + [], + [], + [], + [], + [], + ] + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should save backup data even without to_dict + assert mock_storage.save.call_count > 0 + + def test_clear_message_related_tables_storage_error_continues( + self, mock_session, sample_message_ids, sample_records + ): + """Test that method continues even when storage.save fails.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_storage.save.side_effect = Exception("Storage error") + + mock_session.query.return_value.filter.return_value.all.return_value = sample_records + + # Should not raise exception + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should still delete records even if backup fails + assert mock_session.query.return_value.filter.return_value.delete.called + + def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids): + """Test that method continues even when record serialization fails.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + record = Mock() + record.id = "record-1" + record.to_dict.side_effect = Exception("Serialization error") + + mock_session.query.return_value.filter.return_value.all.return_value = [record] + + # Should not raise exception + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should still delete records even if serialization fails + assert mock_session.query.return_value.filter.return_value.delete.called + + def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records): + """Test that deletion is called for found records.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_session.query.return_value.filter.return_value.all.return_value = sample_records + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should call delete for each table that has records + assert mock_session.query.return_value.filter.return_value.delete.called + + def test_clear_message_related_tables_logging_output( + self, mock_session, sample_message_ids, sample_records, capsys + ): + """Test that logging output is generated.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_session.query.return_value.filter.return_value.all.return_value = sample_records + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + pass diff --git a/api/uv.lock b/api/uv.lock index 4dced728ac..b00e7564f0 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11, <3.13" resolution-markers = [ "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'", @@ -1265,6 +1265,8 @@ dependencies = [ { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-instrumentation-celery" }, { name = "opentelemetry-instrumentation-flask" }, + { name = "opentelemetry-instrumentation-redis" }, + { name = "opentelemetry-instrumentation-requests" }, { name = "opentelemetry-instrumentation-sqlalchemy" }, { name = "opentelemetry-propagator-b3" }, { name = "opentelemetry-proto" }, @@ -1448,6 +1450,8 @@ requires-dist = [ { name = "opentelemetry-instrumentation", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-celery", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-flask", specifier = "==0.48b0" }, + { name = "opentelemetry-instrumentation-redis", specifier = "==0.48b0" }, + { name = "opentelemetry-instrumentation-requests", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.48b0" }, { name = "opentelemetry-propagator-b3", specifier = "==1.27.0" }, { name = "opentelemetry-proto", specifier = "==1.27.0" }, @@ -3670,6 +3674,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/3d/fcde4f8f0bf9fa1ee73a12304fa538076fb83fe0a2ae966ab0f0b7da5109/opentelemetry_instrumentation_flask-0.48b0-py3-none-any.whl", hash = "sha256:26b045420b9d76e85493b1c23fcf27517972423480dc6cf78fd6924248ba5808", size = 14588, upload-time = "2024-08-28T21:26:58.504Z" }, ] +[[package]] +name = "opentelemetry-instrumentation-redis" +version = "0.48b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/70/be/92e98e4c7f275be3d373899a41b0a7d4df64266657d985dbbdb9a54de0d5/opentelemetry_instrumentation_redis-0.48b0.tar.gz", hash = "sha256:61e33e984b4120e1b980d9fba6e9f7ca0c8d972f9970654d8f6e9f27fa115a8c", size = 10511, upload-time = "2024-08-28T21:28:15.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/40/892f30d400091106309cc047fd3f6d76a828fedd984a953fd5386b78a2fb/opentelemetry_instrumentation_redis-0.48b0-py3-none-any.whl", hash = "sha256:48c7f2e25cbb30bde749dc0d8b9c74c404c851f554af832956b9630b27f5bcb7", size = 11610, upload-time = "2024-08-28T21:27:18.759Z" }, +] + +[[package]] +name = "opentelemetry-instrumentation-requests" +version = "0.48b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "opentelemetry-util-http" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/ac/5eb78efde21ff21d0ad5dc8c6cc6a0f8ae482ce8a46293c2f45a628b6166/opentelemetry_instrumentation_requests-0.48b0.tar.gz", hash = "sha256:67ab9bd877a0352ee0db4616c8b4ae59736ddd700c598ed907482d44f4c9a2b3", size = 14120, upload-time = "2024-08-28T21:28:16.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/df/0df9226d1b14f29d23c07e6194b9fd5ad50e7d987b7fd13df7dcf718aeb1/opentelemetry_instrumentation_requests-0.48b0-py3-none-any.whl", hash = "sha256:d4f01852121d0bd4c22f14f429654a735611d4f7bf3cf93f244bdf1489b2233d", size = 12366, upload-time = "2024-08-28T21:27:20.771Z" }, +] + [[package]] name = "opentelemetry-instrumentation-sqlalchemy" version = "0.48b0" diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index fe8e4602b7..b5ae4a425c 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -538,7 +538,7 @@ services: milvus-standalone: container_name: milvus-standalone - image: milvusdb/milvus:v2.5.0-beta + image: milvusdb/milvus:v2.5.15 profiles: - milvus command: [ 'milvus', 'run', 'standalone' ] diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 690dccb1a8..19910cca6f 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -1087,7 +1087,7 @@ services: milvus-standalone: container_name: milvus-standalone - image: milvusdb/milvus:v2.5.0-beta + image: milvusdb/milvus:v2.5.15 profiles: - milvus command: [ 'milvus', 'run', 'standalone' ] diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx new file mode 100644 index 0000000000..a3281be8eb --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx @@ -0,0 +1,156 @@ +import React from 'react' +import { render } from '@testing-library/react' +import '@testing-library/jest-dom' +import { OpikIconBig } from '@/app/components/base/icons/src/public/tracing' + +// Mock dependencies to isolate the SVG rendering issue +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('SVG Attribute Error Reproduction', () => { + // Capture console errors + const originalError = console.error + let errorMessages: string[] = [] + + beforeEach(() => { + errorMessages = [] + console.error = jest.fn((message) => { + errorMessages.push(message) + originalError(message) + }) + }) + + afterEach(() => { + console.error = originalError + }) + + it('should reproduce inkscape attribute errors when rendering OpikIconBig', () => { + console.log('\n=== TESTING OpikIconBig SVG ATTRIBUTE ERRORS ===') + + // Test multiple renders to check for inconsistency + for (let i = 0; i < 5; i++) { + console.log(`\nRender attempt ${i + 1}:`) + + const { unmount } = render() + + // Check for specific inkscape attribute errors + const inkscapeErrors = errorMessages.filter(msg => + typeof msg === 'string' && msg.includes('inkscape'), + ) + + if (inkscapeErrors.length > 0) { + console.log(`Found ${inkscapeErrors.length} inkscape errors:`) + inkscapeErrors.forEach((error, index) => { + console.log(` ${index + 1}. ${error.substring(0, 100)}...`) + }) + } + else { + console.log('No inkscape errors found in this render') + } + + unmount() + + // Clear errors for next iteration + errorMessages = [] + } + }) + + it('should analyze the SVG structure causing the errors', () => { + console.log('\n=== ANALYZING SVG STRUCTURE ===') + + // Import the JSON data directly + const iconData = require('@/app/components/base/icons/src/public/tracing/OpikIconBig.json') + + console.log('Icon structure analysis:') + console.log('- Root element:', iconData.icon.name) + console.log('- Children count:', iconData.icon.children?.length || 0) + + // Find problematic elements + const findProblematicElements = (node: any, path = '') => { + const problematicElements: any[] = [] + + if (node.name && (node.name.includes(':') || node.name.startsWith('sodipodi'))) { + problematicElements.push({ + path, + name: node.name, + attributes: Object.keys(node.attributes || {}), + }) + } + + // Check attributes for inkscape/sodipodi properties + if (node.attributes) { + const problematicAttrs = Object.keys(node.attributes).filter(attr => + attr.startsWith('inkscape:') || attr.startsWith('sodipodi:'), + ) + + if (problematicAttrs.length > 0) { + problematicElements.push({ + path, + name: node.name, + problematicAttributes: problematicAttrs, + }) + } + } + + if (node.children) { + node.children.forEach((child: any, index: number) => { + problematicElements.push( + ...findProblematicElements(child, `${path}/${node.name}[${index}]`), + ) + }) + } + + return problematicElements + } + + const problematicElements = findProblematicElements(iconData.icon, 'root') + + console.log(`\n🚨 Found ${problematicElements.length} problematic elements:`) + problematicElements.forEach((element, index) => { + console.log(`\n${index + 1}. Element: ${element.name}`) + console.log(` Path: ${element.path}`) + if (element.problematicAttributes) + console.log(` Problematic attributes: ${element.problematicAttributes.join(', ')}`) + }) + }) + + it('should test the normalizeAttrs function behavior', () => { + console.log('\n=== TESTING normalizeAttrs FUNCTION ===') + + const { normalizeAttrs } = require('@/app/components/base/icons/utils') + + const testAttributes = { + 'inkscape:showpageshadow': '2', + 'inkscape:pageopacity': '0.0', + 'inkscape:pagecheckerboard': '0', + 'inkscape:deskcolor': '#d1d1d1', + 'sodipodi:docname': 'opik-icon-big.svg', + 'xmlns:inkscape': 'https://www.inkscape.org/namespaces/inkscape', + 'xmlns:sodipodi': 'https://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd', + 'xmlns:svg': 'https://www.w3.org/2000/svg', + 'data-name': 'Layer 1', + 'normal-attr': 'value', + 'class': 'test-class', + } + + console.log('Input attributes:', Object.keys(testAttributes)) + + const normalized = normalizeAttrs(testAttributes) + + console.log('Normalized attributes:', Object.keys(normalized)) + console.log('Normalized values:', normalized) + + // Check if problematic attributes are still present + const problematicKeys = Object.keys(normalized).filter(key => + key.toLowerCase().includes('inkscape') || key.toLowerCase().includes('sodipodi'), + ) + + if (problematicKeys.length > 0) + console.log(`🚨 PROBLEM: Still found problematic attributes: ${problematicKeys.join(', ')}`) + else + console.log('✅ No problematic attributes found after normalization') + }) +}) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx index 3d05575127..1ab40e31bf 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx @@ -1,12 +1,9 @@ 'use client' import type { FC } from 'react' -import React, { useCallback, useEffect, useRef, useState } from 'react' -import { - RiEqualizer2Line, -} from '@remixicon/react' +import React, { useCallback, useRef, useState } from 'react' + import type { PopupProps } from './config-popup' import ConfigPopup from './config-popup' -import cn from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, @@ -17,13 +14,13 @@ type Props = { readOnly: boolean className?: string hasConfigured: boolean - controlShowPopup?: number + children?: React.ReactNode } & PopupProps const ConfigBtn: FC = ({ className, hasConfigured, - controlShowPopup, + children, ...popupProps }) => { const [open, doSetOpen] = useState(false) @@ -37,13 +34,6 @@ const ConfigBtn: FC = ({ setOpen(!openRef.current) }, [setOpen]) - useEffect(() => { - if (controlShowPopup) - // setOpen(!openRef.current) - setOpen(true) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [controlShowPopup]) - if (popupProps.readOnly && !hasConfigured) return null @@ -52,14 +42,11 @@ const ConfigBtn: FC = ({ open={open} onOpenChange={setOpen} placement='bottom-end' - offset={{ - mainAxis: 12, - crossAxis: hasConfigured ? 8 : 49, - }} + offset={12} > -
- +
+ {children}
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index d082523222..7564a0f3c8 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -1,8 +1,9 @@ 'use client' import type { FC } from 'react' -import React, { useCallback, useEffect, useState } from 'react' +import React, { useEffect, useState } from 'react' import { RiArrowDownDoubleLine, + RiEqualizer2Line, } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { usePathname } from 'next/navigation' @@ -180,10 +181,6 @@ const Panel: FC = () => { })() }, []) - const [controlShowPopup, setControlShowPopup] = useState(0) - const showPopup = useCallback(() => { - setControlShowPopup(Date.now()) - }, [setControlShowPopup]) if (!isLoaded) { return (
@@ -196,46 +193,66 @@ const Panel: FC = () => { return (
-
- {!inUseTracingProvider && ( - <> + {!inUseTracingProvider && ( + +
{t(`${I18N_PREFIX}.title`)}
-
e.stopPropagation()}> - +
+
- - )} - {hasConfiguredTracing && ( - <> +
+ + )} + {hasConfiguredTracing && ( + +
@@ -243,33 +260,14 @@ const Panel: FC = () => {
{InUseProviderIcon && } - -
e.stopPropagation()}> - +
+
- - )} -
-
+ +
+
+ )} +
) } export default React.memo(Panel) diff --git a/web/app/(commonLayout)/datasets/create/page.tsx b/web/app/(commonLayout)/datasets/create/page.tsx index 663a830665..50fd1f5a19 100644 --- a/web/app/(commonLayout)/datasets/create/page.tsx +++ b/web/app/(commonLayout)/datasets/create/page.tsx @@ -1,9 +1,7 @@ import React from 'react' import DatasetUpdateForm from '@/app/components/datasets/create' -type Props = {} - -const DatasetCreation = async (props: Props) => { +const DatasetCreation = async () => { return ( ) diff --git a/web/app/components/app/annotation/header-opts/index.tsx b/web/app/components/app/annotation/header-opts/index.tsx index 463ae58ac2..7347caa2f9 100644 --- a/web/app/components/app/annotation/header-opts/index.tsx +++ b/web/app/components/app/annotation/header-opts/index.tsx @@ -88,7 +88,8 @@ const HeaderOptions: FC = ({ await clearAllAnnotations(appId) onAdded() } - catch (_) { + catch (e) { + console.error(`failed to clear all annotations, ${e}`) } finally { setShowClearConfirm(false) diff --git a/web/app/components/apps/footer.tsx b/web/app/components/apps/footer.tsx index 1646474876..c5efb2b8b4 100644 --- a/web/app/components/apps/footer.tsx +++ b/web/app/components/apps/footer.tsx @@ -39,10 +39,10 @@ const Footer = () => {

{t('app.join')}

{t('app.communityIntro')}

diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index e88d28879b..0f437c82b7 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -115,8 +115,11 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { }, []) useEffect(() => { - if (appData?.site.default_language) - changeLanguage(appData.site.default_language) + const setLocaleFromProps = async () => { + if (appData?.site.default_language) + await changeLanguage(appData.site.default_language) + } + setLocaleFromProps() }, [appData]) const [sidebarCollapseState, setSidebarCollapseState] = useState(false) diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index 4e86ad50e4..d7983dc599 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -101,15 +101,15 @@ export const useEmbeddedChatbot = () => { if (localeParam) { // If locale parameter exists in URL, use it instead of default - changeLanguage(localeParam) + await changeLanguage(localeParam) } else if (localeFromSysVar) { // If locale is set as a system variable, use that - changeLanguage(localeFromSysVar) + await changeLanguage(localeFromSysVar) } else if (appInfo?.site.default_language) { // Otherwise use the default from app config - changeLanguage(appInfo.site.default_language) + await changeLanguage(appInfo.site.default_language) } } diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts index 8e1b2148c5..d3c79a9f45 100644 --- a/web/app/components/base/file-uploader/hooks.ts +++ b/web/app/components/base/file-uploader/hooks.ts @@ -68,6 +68,7 @@ export const useFile = (fileConfig: FileUpload) => { } return true } + case SupportUploadFileTypes.custom: case SupportUploadFileTypes.document: { if (fileSize > docSizeLimit) { notify({ @@ -107,19 +108,6 @@ export const useFile = (fileConfig: FileUpload) => { } return true } - case SupportUploadFileTypes.custom: { - if (fileSize > docSizeLimit) { - notify({ - type: 'error', - message: t('common.fileUploader.uploadFromComputerLimit', { - type: SupportUploadFileTypes.document, - size: formatFileSize(docSizeLimit), - }), - }) - return false - } - return true - } default: { return true } @@ -231,7 +219,7 @@ export const useFile = (fileConfig: FileUpload) => { url: res.url, } if (!isAllowedFileExtension(res.name, res.mime_type, fileConfig.allowed_file_types || [], fileConfig.allowed_file_extensions || [])) { - notify({ type: 'error', message: `${t('common.fileUploader.fileExtensionNotSupport')} ${file.type}` }) + notify({ type: 'error', message: `${t('common.fileUploader.fileExtensionNotSupport')} ${newFile.type}` }) handleRemoveFile(uploadingFile.id) } if (!checkSizeLimit(newFile.supportFileType, newFile.size)) diff --git a/web/app/components/base/icons/utils.ts b/web/app/components/base/icons/utils.ts index 90d075f01c..632e362075 100644 --- a/web/app/components/base/icons/utils.ts +++ b/web/app/components/base/icons/utils.ts @@ -14,9 +14,26 @@ export type Attrs = { export function normalizeAttrs(attrs: Attrs = {}): Attrs { return Object.keys(attrs).reduce((acc: Attrs, key) => { + // Filter out editor metadata attributes before processing + if (key.startsWith('inkscape:') + || key.startsWith('sodipodi:') + || key.startsWith('xmlns:inkscape') + || key.startsWith('xmlns:sodipodi') + || key.startsWith('xmlns:svg') + || key === 'data-name') + return acc + const val = attrs[key] key = key.replace(/([-]\w)/g, (g: string) => g[1].toUpperCase()) key = key.replace(/([:]\w)/g, (g: string) => g[1].toUpperCase()) + + // Additional filter after camelCase conversion + if (key === 'xmlnsInkscape' + || key === 'xmlnsSodipodi' + || key === 'xmlnsSvg' + || key === 'dataName') + return acc + switch (key) { case 'class': acc.className = val diff --git a/web/app/components/base/tag-management/filter.tsx b/web/app/components/base/tag-management/filter.tsx index ecc159b2fc..4cf01fdc26 100644 --- a/web/app/components/base/tag-management/filter.tsx +++ b/web/app/components/base/tag-management/filter.tsx @@ -139,7 +139,10 @@ const TagFilter: FC = ({
-
setShowTagManagementModal(true)}> +
{ + setShowTagManagementModal(true) + setOpen(false) + }}>
{t('common.tag.manageTags')} diff --git a/web/app/components/datasets/list/doc.tsx b/web/app/components/datasets/list/doc.tsx index cc737caaf9..5a8262788b 100644 --- a/web/app/components/datasets/list/doc.tsx +++ b/web/app/components/datasets/list/doc.tsx @@ -87,7 +87,7 @@ const Doc = ({ apiBaseUrl }: DocProps) => {
{isTocExpanded ? ( -
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
diff --git a/web/app/components/datasets/list/template/template.zh.mdx b/web/app/components/datasets/list/template/template.zh.mdx index c21ce3bf5f..b7ea889a46 100644 --- a/web/app/components/datasets/list/template/template.zh.mdx +++ b/web/app/components/datasets/list/template/template.zh.mdx @@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
@@ -1915,7 +1915,7 @@ ___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
diff --git a/web/app/components/develop/doc.tsx b/web/app/components/develop/doc.tsx index 65e6d4aec0..806ee72725 100644 --- a/web/app/components/develop/doc.tsx +++ b/web/app/components/develop/doc.tsx @@ -87,7 +87,7 @@ const Doc = ({ appDetail }: IDocProps) => {
{isTocExpanded ? ( -