From 19cc6ea9930092d3049bde8de8906664a93db86c Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Fri, 17 Oct 2025 10:10:16 +0900 Subject: [PATCH] fix 27003 (#27005) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/controllers/console/explore/workflow.py | 4 ++- api/controllers/console/wraps.py | 7 +++++- api/libs/login.py | 27 +++++++++++++++------ api/models/model.py | 2 +- api/services/datasource_provider_service.py | 21 ++++++++++------ 5 files changed, 43 insertions(+), 18 deletions(-) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index e32f2814eb..aeea446c6e 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -22,7 +22,7 @@ from core.errors.error import ( from core.model_runtime.errors.invoke import InvokeError from core.workflow.graph_engine.manager import GraphEngineManager from libs import helper -from libs.login import current_user +from libs.login import current_user as current_user_ from models.model import AppMode, InstalledApp from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -31,6 +31,8 @@ from .. import console_ns logger = logging.getLogger(__name__) +current_user = current_user_._get_current_object() # type: ignore + @console_ns.route("/installed-apps//workflows/run") class InstalledAppWorkflowRunApi(InstalledAppResource): diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 2fa28711c3..8572a6dc9b 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -303,7 +303,12 @@ def edit_permission_required(f: Callable[P, R]): def decorated_function(*args: P.args, **kwargs: P.kwargs): from werkzeug.exceptions import Forbidden - current_user, _ = current_account_with_tenant() + from libs.login import current_user + from models import Account + + user = current_user._get_current_object() # type: ignore + if not isinstance(user, Account): + raise Forbidden() if not current_user.has_edit_permission: raise Forbidden() return f(*args, **kwargs) diff --git a/api/libs/login.py b/api/libs/login.py index 2c75ef9297..d0e81a3441 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Union, cast +from typing import Any from flask import current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS # type: ignore @@ -10,16 +10,21 @@ from configs import dify_config from models import Account from models.model import EndUser -#: A proxy for the current user. If no user is logged in, this will be an -#: anonymous user -current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) - def current_account_with_tenant(): - if not isinstance(current_user, Account): + """ + Resolve the underlying account for the current user proxy and ensure tenant context exists. + Allows tests to supply plain Account mocks without the LocalProxy helper. + """ + user_proxy = current_user + + get_current_object = getattr(user_proxy, "_get_current_object", None) + user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore + + if not isinstance(user, Account): raise ValueError("current_user must be an Account instance") - assert current_user.current_tenant_id is not None, "The tenant information should be loaded." - return current_user, current_user.current_tenant_id + assert user.current_tenant_id is not None, "The tenant information should be loaded." + return user, user.current_tenant_id from typing import ParamSpec, TypeVar @@ -81,3 +86,9 @@ def _get_user() -> EndUser | Account | None: return g._login_user # type: ignore return None + + +#: A proxy for the current user. If no user is logged in, this will be an +#: anonymous user +# NOTE: Any here, but use _get_current_object to check the fields +current_user: Any = LocalProxy(lambda: _get_user()) diff --git a/api/models/model.py b/api/models/model.py index 2373421e7d..af22ab9538 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1479,7 +1479,7 @@ class EndUser(Base, UserMixin): sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=True) type: Mapped[str] = mapped_column(String(255), nullable=False) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index fcb6ab1d40..1b690e2266 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -17,7 +17,6 @@ from core.tools.entities.tool_entities import CredentialType from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client -from libs.login import current_account_with_tenant from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService @@ -25,6 +24,16 @@ from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) +def get_current_user(): + from libs.login import current_user + from models.account import Account + from models.model import EndUser + + if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore + raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}") + return current_user + + class DatasourceProviderService: """ Model Provider Service @@ -93,8 +102,6 @@ class DatasourceProviderService: """ get credential by id """ - current_user, _ = current_account_with_tenant() - with Session(db.engine) as session: if credential_id: datasource_provider = ( @@ -111,6 +118,7 @@ class DatasourceProviderService: return {} # refresh the credentials if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()): + current_user = get_current_user() decrypted_credentials = self.decrypt_datasource_provider_credentials( tenant_id=tenant_id, datasource_provider=datasource_provider, @@ -159,8 +167,6 @@ class DatasourceProviderService: """ get all datasource credentials by provider """ - current_user, _ = current_account_with_tenant() - with Session(db.engine) as session: datasource_providers = ( session.query(DatasourceProvider) @@ -170,6 +176,7 @@ class DatasourceProviderService: ) if not datasource_providers: return [] + current_user = get_current_user() # refresh the credentials real_credentials_list = [] for datasource_provider in datasource_providers: @@ -608,7 +615,6 @@ class DatasourceProviderService: """ provider_name = provider_id.provider_name plugin_id = provider_id.plugin_id - current_user, _ = current_account_with_tenant() with Session(db.engine) as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}" @@ -630,6 +636,7 @@ class DatasourceProviderService: raise ValueError("Authorization name is already exists") try: + current_user = get_current_user() self.provider_manager.validate_provider_credentials( tenant_id=tenant_id, user_id=current_user.id, @@ -907,7 +914,6 @@ class DatasourceProviderService: """ update datasource credentials. """ - current_user, _ = current_account_with_tenant() with Session(db.engine) as session: datasource_provider = ( @@ -944,6 +950,7 @@ class DatasourceProviderService: for key, value in credentials.items() } try: + current_user = get_current_user() self.provider_manager.validate_provider_credentials( tenant_id=tenant_id, user_id=current_user.id,