diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index b1e3813f33..efab1bc03c 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,4 +1,5 @@ import flask_restx +from flask import Response from flask_restx import Resource, fields, marshal_with from flask_restx._http import HTTPStatus from sqlalchemy import select @@ -7,8 +8,7 @@ from werkzeug.exceptions import Forbidden from extensions.ext_database import db from libs.helper import TimestampField -from libs.login import current_user, login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset from models.model import ApiToken, App @@ -57,9 +57,9 @@ class BaseApiKeyListResource(Resource): def get(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None - _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) + _, current_tenant_id = current_account_with_tenant() + + _get_resource(resource_id, current_tenant_id, self.resource_model) keys = db.session.scalars( select(ApiToken).where( ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id @@ -71,9 +71,8 @@ class BaseApiKeyListResource(Resource): def post(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None - _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) + current_user, current_tenant_id = current_account_with_tenant() + _get_resource(resource_id, current_tenant_id, self.resource_model) if not current_user.has_edit_permission: raise Forbidden() @@ -93,7 +92,7 @@ class BaseApiKeyListResource(Resource): key = ApiToken.generate_api_key(self.token_prefix or "", 24) api_token = ApiToken() setattr(api_token, self.resource_id_field, resource_id) - api_token.tenant_id = current_user.current_tenant_id + api_token.tenant_id = current_tenant_id api_token.token = key api_token.type = self.resource_type db.session.add(api_token) @@ -112,9 +111,8 @@ class BaseApiKeyResource(Resource): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) api_key_id = str(api_key_id) - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None - _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) + current_user, current_tenant_id = current_account_with_tenant() + _get_resource(resource_id, current_tenant_id, self.resource_model) # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: @@ -158,7 +156,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): """Create a new API key for an app""" return super().post(resource_id) - def after_request(self, resp): + def after_request(self, resp: Response): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" return resp @@ -208,7 +206,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): """Create a new API key for a dataset""" return super().post(resource_id) - def after_request(self, resp): + def after_request(self, resp: Response): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" return resp @@ -229,7 +227,7 @@ class DatasetApiKeyResource(BaseApiKeyResource): """Delete an API key for a dataset""" return super().delete(resource_id, api_key_id) - def after_request(self, resp): + def after_request(self, resp: Response): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" return resp diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index d0ee11fe75..589955738c 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,7 +1,6 @@ from typing import Literal from flask import request -from flask_login import current_user from flask_restx import Resource, fields, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden @@ -17,7 +16,7 @@ from fields.annotation_fields import ( annotation_fields, annotation_hit_history_fields, ) -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from services.annotation_service import AppAnnotationService @@ -43,7 +42,9 @@ class AnnotationReplyActionApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") def post(self, app_id, action: Literal["enable", "disable"]): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -70,7 +71,9 @@ class AppAnnotationSettingDetailApi(Resource): @login_required @account_initialization_required def get(self, app_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -99,7 +102,9 @@ class AppAnnotationSettingUpdateApi(Resource): @login_required @account_initialization_required def post(self, app_id, annotation_setting_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -125,7 +130,9 @@ class AnnotationReplyActionStatusApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id, action): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() job_id = str(job_id) @@ -160,7 +167,9 @@ class AnnotationApi(Resource): @login_required @account_initialization_required def get(self, app_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() page = request.args.get("page", default=1, type=int) @@ -199,7 +208,9 @@ class AnnotationApi(Resource): @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -214,7 +225,9 @@ class AnnotationApi(Resource): @login_required @account_initialization_required def delete(self, app_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -250,7 +263,9 @@ class AnnotationExportApi(Resource): @login_required @account_initialization_required def get(self, app_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -273,7 +288,9 @@ class AnnotationUpdateDeleteApi(Resource): @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id, annotation_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -289,7 +306,9 @@ class AnnotationUpdateDeleteApi(Resource): @login_required @account_initialization_required def delete(self, app_id, annotation_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -311,7 +330,9 @@ class AnnotationBatchImportApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") def post(self, app_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -342,7 +363,9 @@ class AnnotationBatchImportStatusApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() job_id = str(job_id) @@ -377,7 +400,9 @@ class AnnotationHitHistoryListApi(Resource): @login_required @account_initialization_required def get(self, app_id, annotation_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() page = request.args.get("page", default=1, type=int) diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 037561cfed..5751ff1f86 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,6 +1,3 @@ -from typing import cast - -from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -13,8 +10,7 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from fields.app_fields import app_import_check_dependencies_fields, app_import_fields -from libs.login import login_required -from models import Account +from libs.login import current_account_with_tenant, login_required from models.model import App from services.app_dsl_service import AppDslService, ImportStatus from services.enterprise.enterprise_service import EnterpriseService @@ -32,7 +28,8 @@ class AppImportApi(Resource): @cloud_edition_billing_resource_check("apps") def post(self): # Check user role first - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -51,7 +48,7 @@ class AppImportApi(Resource): with Session(db.engine) as session: import_service = AppDslService(session) # Import app - account = cast(Account, current_user) + account = current_user result = import_service.import_app( account=account, import_mode=args["mode"], @@ -85,14 +82,15 @@ class AppImportConfirmApi(Resource): @marshal_with(app_import_fields) def post(self, import_id): # Check user role first - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() # Create service with session with Session(db.engine) as session: import_service = AppDslService(session) # Confirm import - account = cast(Account, current_user) + account = current_user result = import_service.confirm_import(import_id=import_id, account=account) session.commit() @@ -110,7 +108,8 @@ class AppImportCheckDependenciesApi(Resource): @account_initialization_required @marshal_with(app_import_check_dependencies_fields) def get(self, app_model: App): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() with Session(db.engine) as session: diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 230ccdca15..4a9b6e7801 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,6 +1,5 @@ from collections.abc import Sequence -from flask_login import current_user from flask_restx import Resource, fields, reqparse from controllers.console import api, console_ns @@ -17,7 +16,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models import App from services.workflow_service import WorkflowService @@ -48,11 +47,11 @@ class RuleGenerateApi(Resource): parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") args = parser.parse_args() + _, current_tenant_id = current_account_with_tenant() - account = current_user try: rules = LLMGenerator.generate_rule_config( - tenant_id=account.current_tenant_id, + tenant_id=current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], no_variable=args["no_variable"], @@ -99,11 +98,11 @@ class RuleCodeGenerateApi(Resource): parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") parser.add_argument("code_language", type=str, required=False, default="javascript", location="json") args = parser.parse_args() + _, current_tenant_id = current_account_with_tenant() - account = current_user try: code_result = LLMGenerator.generate_code( - tenant_id=account.current_tenant_id, + tenant_id=current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], code_language=args["code_language"], @@ -144,11 +143,11 @@ class RuleStructuredOutputGenerateApi(Resource): parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() + _, current_tenant_id = current_account_with_tenant() - account = current_user try: structured_output = LLMGenerator.generate_structured_output( - tenant_id=account.current_tenant_id, + tenant_id=current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], ) @@ -198,6 +197,7 @@ class InstructionGenerateApi(Resource): parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") parser.add_argument("ideal_output", type=str, required=False, default="", location="json") args = parser.parse_args() + _, current_tenant_id = current_account_with_tenant() code_template = ( Python3CodeProvider.get_default_code() if args["language"] == "python" @@ -222,21 +222,21 @@ class InstructionGenerateApi(Resource): match node_type: case "llm": return LLMGenerator.generate_rule_config( - current_user.current_tenant_id, + current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], no_variable=True, ) case "agent": return LLMGenerator.generate_rule_config( - current_user.current_tenant_id, + current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], no_variable=True, ) case "code": return LLMGenerator.generate_code( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], code_language=args["language"], @@ -245,7 +245,7 @@ class InstructionGenerateApi(Resource): return {"error": f"invalid node type: {node_type}"} if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow return LLMGenerator.instruction_modify_legacy( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, flow_id=args["flow_id"], current=args["current"], instruction=args["instruction"], @@ -254,7 +254,7 @@ class InstructionGenerateApi(Resource): ) if args["node_id"] != "" and args["current"] != "": # For workflow node return LLMGenerator.instruction_modify_workflow( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, flow_id=args["flow_id"], node_id=args["node_id"], current=args["current"], diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index fa6e3f8738..72ce8a7ddf 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -2,7 +2,6 @@ import json from typing import cast from flask import request -from flask_login import current_user from flask_restx import Resource, fields from werkzeug.exceptions import Forbidden @@ -15,8 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from libs.login import login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from models.model import AppMode, AppModelConfig from services.app_model_config_service import AppModelConfigService @@ -54,16 +52,14 @@ class ModelConfigResource(Resource): @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model): """Modify app model config""" - if not isinstance(current_user, Account): - raise Forbidden() + current_user, current_tenant_id = current_account_with_tenant() if not current_user.has_edit_permission: raise Forbidden() - assert current_user.current_tenant_id is not None, "The tenant information should be loaded." # validate config model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, config=cast(dict, request.json), app_mode=AppMode.value_of(app_model.mode), ) @@ -95,12 +91,12 @@ class ModelConfigResource(Resource): # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, ) manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, @@ -134,7 +130,7 @@ class ModelConfigResource(Resource): else: try: tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, ) @@ -142,7 +138,7 @@ class ModelConfigResource(Resource): continue manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 95befc5df9..6537ea5a50 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound @@ -9,7 +8,7 @@ from controllers.console.wraps import account_initialization_required, setup_req from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.datetime_utils import naive_utc_now -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models import Account, Site @@ -76,9 +75,10 @@ class AppSite(Resource): @marshal_with(app_site_fields) def post(self, app_model): args = parse_app_site_args() + current_user, _ = current_account_with_tenant() # The role of the current user in the ta table must be editor, admin, or owner - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() site = db.session.query(Site).where(Site.app_id == app_model.id).first() @@ -131,6 +131,8 @@ class AppSiteAccessTokenReset(Resource): @marshal_with(app_site_fields) def post(self, app_model): # The role of the current user in the ta table must be admin or owner + current_user, _ = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 207303b212..d9ab7de29b 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,10 +1,9 @@ -from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import console_ns from controllers.console.auth.error import ApiKeyAuthFailedError -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from services.auth.api_key_auth_service import ApiKeyAuthService from ..wraps import account_initialization_required, setup_required @@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource): @login_required @account_initialization_required def get(self): - data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) + _, current_tenant_id = current_account_with_tenant() + data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id) if data_source_api_key_bindings: return { "sources": [ @@ -41,6 +41,8 @@ class ApiKeyAuthDataSourceBinding(Resource): @account_initialization_required def post(self): # The role of the current user in the table must be admin or owner + current_user, current_tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() @@ -50,7 +52,7 @@ class ApiKeyAuthDataSourceBinding(Resource): args = parser.parse_args() ApiKeyAuthService.validate_api_key_auth_args(args) try: - ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) + ApiKeyAuthService.create_provider_auth(current_tenant_id, args) except Exception as e: raise ApiKeyAuthFailedError(str(e)) return {"result": "success"}, 200 @@ -63,9 +65,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): @account_initialization_required def delete(self, binding_id): # The role of the current user in the table must be admin or owner + current_user, current_tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() - ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) + ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id) return {"result": "success"}, 204 diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 6d9d675e87..058ef4408a 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -3,7 +3,6 @@ from collections.abc import Generator from typing import cast from flask import request -from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session @@ -20,7 +19,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from libs.datetime_utils import naive_utc_now -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models import DataSourceOauthBinding, Document from services.dataset_service import DatasetService, DocumentService from services.datasource_provider_service import DatasourceProviderService @@ -37,10 +36,12 @@ class DataSourceApi(Resource): @account_initialization_required @marshal_with(integrate_list_fields) def get(self): + _, current_tenant_id = current_account_with_tenant() + # get workspace data source integrates data_source_integrates = db.session.scalars( select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.tenant_id == current_tenant_id, DataSourceOauthBinding.disabled == False, ) ).all() @@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource): @account_initialization_required @marshal_with(integrate_notion_info_list_fields) def get(self): + current_user, current_tenant_id = current_account_with_tenant() + dataset_id = request.args.get("dataset_id", default=None, type=str) credential_id = request.args.get("credential_id", default=None, type=str) if not credential_id: raise ValueError("Credential id is required.") datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, credential_id=credential_id, provider="notion_datasource", plugin_id="langgenius/notion_datasource", @@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource): documents = session.scalars( select(Document).filter_by( dataset_id=dataset_id, - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, data_source_type="notion_import", enabled=True, ) @@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource): datasource_runtime = DatasourceManager.get_datasource_runtime( provider_id="langgenius/notion_datasource/notion_datasource", datasource_name="notion_datasource", - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, ) datasource_provider_service = DatasourceProviderService() @@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource): @login_required @account_initialization_required def get(self, workspace_id, page_id, page_type): + _, current_tenant_id = current_account_with_tenant() + credential_id = request.args.get("credential_id", default=None, type=str) if not credential_id: raise ValueError("Credential id is required.") datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, credential_id=credential_id, provider="notion_datasource", plugin_id="langgenius/notion_datasource", @@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource): notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=credential.get("integration_secret"), - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, ) text_docs = extractor.extract() @@ -239,6 +244,8 @@ class DataSourceNotionApi(Resource): @login_required @account_initialization_required def post(self): + _, current_tenant_id = current_account_with_tenant() + parser = reqparse.RequestParser() parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json") parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") @@ -263,7 +270,7 @@ class DataSourceNotionApi(Resource): "notion_workspace_id": workspace_id, "notion_obj_id": page["page_id"], "notion_page_type": page["type"], - "tenant_id": current_user.current_tenant_id, + "tenant_id": current_tenant_id, } ), document_model=args["doc_form"], @@ -271,7 +278,7 @@ class DataSourceNotionApi(Resource): extract_settings.append(extract_setting) indexing_runner = IndexingRunner() response = indexing_runner.indexing_estimate( - current_user.current_tenant_id, + current_tenant_id, extract_settings, args["process_rule"], args["doc_form"], diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index d6bd02483d..d4d484a2e2 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,7 +1,6 @@ import uuid from flask import request -from flask_login import current_user from flask_restx import Resource, marshal, reqparse from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound @@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile from services.dataset_service import DatasetService, DocumentService, SegmentService @@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource): @login_required @account_initialization_required def get(self, dataset_id, document_id): + current_user, current_tenant_id = current_account_with_tenant() + dataset_id = str(dataset_id) document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) @@ -79,7 +80,7 @@ class DatasetDocumentSegmentListApi(Resource): select(DocumentSegment) .where( DocumentSegment.document_id == str(document_id), - DocumentSegment.tenant_id == current_user.current_tenant_id, + DocumentSegment.tenant_id == current_tenant_id, ) .order_by(DocumentSegment.position.asc()) ) @@ -115,6 +116,8 @@ class DatasetDocumentSegmentListApi(Resource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id): + current_user, _ = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -148,6 +151,8 @@ class DatasetDocumentSegmentApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, action): + current_user, current_tenant_id = current_account_with_tenant() + dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: @@ -171,7 +176,7 @@ class DatasetDocumentSegmentApi(Resource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -204,6 +209,8 @@ class DatasetDocumentSegmentAddApi(Resource): @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id, document_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -221,7 +228,7 @@ class DatasetDocumentSegmentAddApi(Resource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -255,6 +262,8 @@ class DatasetDocumentSegmentUpdateApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, segment_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -272,7 +281,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -287,7 +296,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -317,6 +326,8 @@ class DatasetDocumentSegmentUpdateApi(Resource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id, segment_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -333,7 +344,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -361,6 +372,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource): @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id, document_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -396,7 +409,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): upload_file_id, dataset_id, document_id, - current_user.current_tenant_id, + current_tenant_id, current_user.id, ) except Exception as e: @@ -427,6 +440,8 @@ class ChildChunkAddApi(Resource): @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id, document_id, segment_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -441,7 +456,7 @@ class ChildChunkAddApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -453,7 +468,7 @@ class ChildChunkAddApi(Resource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -483,6 +498,8 @@ class ChildChunkAddApi(Resource): @login_required @account_initialization_required def get(self, dataset_id, document_id, segment_id): + _, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -499,7 +516,7 @@ class ChildChunkAddApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -530,6 +547,8 @@ class ChildChunkAddApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, segment_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -546,7 +565,7 @@ class ChildChunkAddApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -580,6 +599,8 @@ class ChildChunkUpdateApi(Resource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id, segment_id, child_chunk_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -596,7 +617,7 @@ class ChildChunkUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -607,7 +628,7 @@ class ChildChunkUpdateApi(Resource): db.session.query(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), - ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) @@ -634,6 +655,8 @@ class ChildChunkUpdateApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, segment_id, child_chunk_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -650,7 +673,7 @@ class ChildChunkUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -661,7 +684,7 @@ class ChildChunkUpdateApi(Resource): db.session.query(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), - ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 53b5a0d965..4c34df7eff 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -1,5 +1,4 @@ from flask import make_response, redirect, request -from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden, NotFound @@ -13,7 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.oauth import OAuthHandler from libs.helper import StrLen -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService @@ -25,9 +24,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): @login_required @account_initialization_required def get(self, provider_id: str): - user = current_user - tenant_id = user.current_tenant_id - if not current_user.is_editor: + current_user, current_tenant_id = current_account_with_tenant() + + tenant_id = current_tenant_id + if not current_user.has_edit_permission: raise Forbidden() credential_id = request.args.get("credential_id") @@ -52,7 +52,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback" authorization_url_response = oauth_handler.get_authorization_url( tenant_id=tenant_id, - user_id=user.id, + user_id=current_user.id, plugin_id=plugin_id, provider=provider_name, redirect_uri=redirect_uri, @@ -131,7 +131,9 @@ class DatasourceAuth(Resource): @login_required @account_initialization_required def post(self, provider_id: str): - if not current_user.is_editor: + current_user, current_tenant_id = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -145,7 +147,7 @@ class DatasourceAuth(Resource): try: datasource_provider_service.add_datasource_api_key_provider( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider_id=datasource_provider_id, credentials=args["credentials"], name=args["name"], @@ -160,8 +162,10 @@ class DatasourceAuth(Resource): def get(self, provider_id: str): datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() + _, current_tenant_id = current_account_with_tenant() + datasources = datasource_provider_service.list_datasource_credentials( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id, ) @@ -174,17 +178,19 @@ class DatasourceAuthDeleteApi(Resource): @login_required @account_initialization_required def post(self, provider_id: str): + current_user, current_tenant_id = current_account_with_tenant() + datasource_provider_id = DatasourceProviderID(provider_id) plugin_id = datasource_provider_id.plugin_id provider_name = datasource_provider_id.provider_name - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_datasource_credentials( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, auth_id=args["credential_id"], provider=provider_name, plugin_id=plugin_id, @@ -198,17 +204,19 @@ class DatasourceAuthUpdateApi(Resource): @login_required @account_initialization_required def post(self, provider_id: str): + current_user, current_tenant_id = current_account_with_tenant() + datasource_provider_id = DatasourceProviderID(provider_id) parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json") parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() datasource_provider_service = DatasourceProviderService() datasource_provider_service.update_datasource_credentials( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, auth_id=args["credential_id"], provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id, @@ -224,10 +232,10 @@ class DatasourceAuthListApi(Resource): @login_required @account_initialization_required def get(self): + _, current_tenant_id = current_account_with_tenant() + datasource_provider_service = DatasourceProviderService() - datasources = datasource_provider_service.get_all_datasource_credentials( - tenant_id=current_user.current_tenant_id - ) + datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id) return {"result": jsonable_encoder(datasources)}, 200 @@ -237,10 +245,10 @@ class DatasourceHardCodeAuthListApi(Resource): @login_required @account_initialization_required def get(self): + _, current_tenant_id = current_account_with_tenant() + datasource_provider_service = DatasourceProviderService() - datasources = datasource_provider_service.get_hard_code_datasource_credentials( - tenant_id=current_user.current_tenant_id - ) + datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id) return {"result": jsonable_encoder(datasources)}, 200 @@ -250,7 +258,9 @@ class DatasourceAuthOauthCustomClient(Resource): @login_required @account_initialization_required def post(self, provider_id: str): - if not current_user.is_editor: + current_user, current_tenant_id = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") @@ -259,7 +269,7 @@ class DatasourceAuthOauthCustomClient(Resource): datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.setup_oauth_custom_client_params( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, client_params=args.get("client_params", {}), enabled=args.get("enable_oauth_custom_client", False), @@ -270,10 +280,12 @@ class DatasourceAuthOauthCustomClient(Resource): @login_required @account_initialization_required def delete(self, provider_id: str): + _, current_tenant_id = current_account_with_tenant() + datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_oauth_custom_client_params( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, ) return {"result": "success"}, 200 @@ -285,7 +297,9 @@ class DatasourceAuthDefaultApi(Resource): @login_required @account_initialization_required def post(self, provider_id: str): - if not current_user.is_editor: + current_user, current_tenant_id = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("id", type=str, required=True, nullable=False, location="json") @@ -293,7 +307,7 @@ class DatasourceAuthDefaultApi(Resource): datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.set_default_datasource_provider( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, credential_id=args["id"], ) @@ -306,7 +320,9 @@ class DatasourceUpdateProviderNameApi(Resource): @login_required @account_initialization_required def post(self, provider_id: str): - if not current_user.is_editor: + current_user, current_tenant_id = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json") @@ -315,7 +331,7 @@ class DatasourceUpdateProviderNameApi(Resource): datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.update_datasource_provider_name( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, name=args["name"], credential_id=args["credential_id"], diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index 404aa42073..b394887783 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import Resource, marshal, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -13,7 +12,7 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from fields.dataset_fields import dataset_detail_fields -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models.dataset import DatasetPermissionEnum from services.dataset_service import DatasetPermissionService, DatasetService from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity @@ -38,7 +37,7 @@ class CreateRagPipelineDatasetApi(Resource): ) args = parser.parse_args() - + current_user, current_tenant_id = current_account_with_tenant() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() @@ -58,12 +57,12 @@ class CreateRagPipelineDatasetApi(Resource): with Session(db.engine) as session: rag_pipeline_dsl_service = RagPipelineDslService(session) import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, ) if rag_pipeline_dataset_create_entity.permission == "partial_members": DatasetPermissionService.update_partial_member_list( - current_user.current_tenant_id, + current_tenant_id, import_info["dataset_id"], rag_pipeline_dataset_create_entity.partial_member_list, ) @@ -81,10 +80,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource): @cloud_edition_billing_rate_limit_check("knowledge") def post(self): # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator + current_user, current_tenant_id = current_account_with_tenant() + if not current_user.is_dataset_editor: raise Forbidden() dataset = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity( name="", description="", diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index c86c243c9b..7ead93a1b6 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -12,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields from libs.datetime_utils import naive_utc_now -from libs.login import current_user, login_required -from models import Account, App, InstalledApp, RecommendedApp +from libs.login import current_account_with_tenant, login_required +from models import App, InstalledApp, RecommendedApp from services.account_service import TenantService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService @@ -29,9 +29,7 @@ class InstalledAppsListApi(Resource): @marshal_with(installed_app_list_fields) def get(self): app_id = request.args.get("app_id", default=None, type=str) - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") - current_tenant_id = current_user.current_tenant_id + current_user, current_tenant_id = current_account_with_tenant() if app_id: installed_apps = db.session.scalars( @@ -121,9 +119,8 @@ class InstalledAppsListApi(Resource): if recommended_app is None: raise NotFound("App not found") - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") - current_tenant_id = current_user.current_tenant_id + _, current_tenant_id = current_account_with_tenant() + app = db.session.query(App).where(App.id == args["app_id"]).first() if app is None: @@ -163,9 +160,8 @@ class InstalledAppApi(InstalledAppResource): """ def delete(self, installed_app): - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") - if installed_app.app_owner_tenant_id == current_user.current_tenant_id: + _, current_tenant_id = current_account_with_tenant() + if installed_app.app_owner_tenant_id == current_tenant_id: raise BadRequest("You can't uninstall an app owned by the current tenant") db.session.delete(installed_app) diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index c6b3cf7515..dd618307e9 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -4,7 +4,7 @@ from constants import HIDDEN_VALUE from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields -from libs.login import current_user, login_required +from libs.login import current_account_with_tenant, current_user, login_required from models.account import Account from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService @@ -47,9 +47,7 @@ class APIBasedExtensionAPI(Resource): @account_initialization_required @marshal_with(api_based_extension_fields) def get(self): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) @api.doc("create_api_based_extension") @@ -77,9 +75,10 @@ class APIBasedExtensionAPI(Resource): parser.add_argument("api_endpoint", type=str, required=True, location="json") parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() + _, current_tenant_id = current_account_with_tenant() extension_data = APIBasedExtension( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, name=args["name"], api_endpoint=args["api_endpoint"], api_key=args["api_key"], @@ -102,7 +101,7 @@ class APIBasedExtensionDetailAPI(Resource): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None api_based_extension_id = str(id) - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) @@ -128,9 +127,9 @@ class APIBasedExtensionDetailAPI(Resource): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None api_based_extension_id = str(id) - tenant_id = current_user.current_tenant_id + _, current_tenant_id = current_account_with_tenant() - extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") @@ -157,9 +156,9 @@ class APIBasedExtensionDetailAPI(Resource): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None api_based_extension_id = str(id) - tenant_id = current_user.current_tenant_id + _, current_tenant_id = current_account_with_tenant() - extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) APIBasedExtensionService.delete(extension_data_from_db) diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 80847b8fef..39bcf3424c 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,7 +1,6 @@ from flask_restx import Resource, fields -from libs.login import current_user, login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from services.feature_service import FeatureService from . import api, console_ns @@ -23,9 +22,9 @@ class FeatureApi(Resource): @cloud_utm_record def get(self): """Get feature configuration for current tenant""" - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None - return FeatureService.get_features(current_user.current_tenant_id).model_dump() + _, current_tenant_id = current_account_with_tenant() + + return FeatureService.get_features(current_tenant_id).model_dump() @console_ns.route("/system-features") diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 4d4bb5d779..b053f222df 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -14,8 +14,7 @@ from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields -from libs.login import current_user -from models.account import Account +from libs.login import current_account_with_tenant from services.file_service import FileService from . import console_ns @@ -64,8 +63,7 @@ class RemoteFileUploadApi(Resource): content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content try: - assert isinstance(current_user, Account) - user = current_user + user, _ = current_account_with_tenant() upload_file = FileService(db.engine).upload_file( filename=file_info.filename, content=content, diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 782bd72565..b31011b4a3 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -5,18 +5,10 @@ from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginPermissionDeniedError -from libs.login import current_user, login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService -def _current_account_with_tenant() -> tuple[Account, str]: - assert isinstance(current_user, Account) - tenant_id = current_user.current_tenant_id - assert tenant_id is not None - return current_user, tenant_id - - @console_ns.route("/workspaces/current/endpoints/create") class EndpointCreateApi(Resource): @api.doc("create_endpoint") @@ -41,7 +33,7 @@ class EndpointCreateApi(Resource): @login_required @account_initialization_required def post(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() @@ -87,7 +79,7 @@ class EndpointListApi(Resource): @login_required @account_initialization_required def get(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=True, location="args") @@ -130,7 +122,7 @@ class EndpointListForSinglePluginApi(Resource): @login_required @account_initialization_required def get(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=True, location="args") @@ -172,7 +164,7 @@ class EndpointDeleteApi(Resource): @login_required @account_initialization_required def post(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) @@ -212,7 +204,7 @@ class EndpointUpdateApi(Resource): @login_required @account_initialization_required def post(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) @@ -255,7 +247,7 @@ class EndpointEnableApi(Resource): @login_required @account_initialization_required def post(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) @@ -288,7 +280,7 @@ class EndpointDisableApi(Resource): @login_required @account_initialization_required def post(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index dd6a878d87..4f080708cc 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -25,7 +25,7 @@ from controllers.console.wraps import ( from extensions.ext_database import db from fields.member_fields import account_with_role_list_fields from libs.helper import extract_remote_ip -from libs.login import current_user, login_required +from libs.login import current_account_with_tenant, login_required from models.account import Account, TenantAccountRole from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError @@ -41,8 +41,7 @@ class MemberListApi(Resource): @account_initialization_required @marshal_with(account_with_role_list_fields) def get(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) @@ -69,9 +68,7 @@ class MemberInviteEmailApi(Resource): interface_language = args["language"] if not TenantAccountRole.is_non_owner_role(invitee_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 - - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() inviter = current_user if not inviter.current_tenant: raise ValueError("No current tenant") @@ -120,8 +117,7 @@ class MemberCancelInviteApi(Resource): @login_required @account_initialization_required def delete(self, member_id): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") member = db.session.query(Account).where(Account.id == str(member_id)).first() @@ -160,9 +156,7 @@ class MemberUpdateRoleApi(Resource): if not TenantAccountRole.is_valid_role(new_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 - - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") member = db.session.get(Account, str(member_id)) @@ -189,8 +183,7 @@ class DatasetOperatorMemberListApi(Resource): @account_initialization_required @marshal_with(account_with_role_list_fields) def get(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_dataset_operator_members(current_user.current_tenant) @@ -212,10 +205,8 @@ class SendOwnerTransferEmailApi(Resource): ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - + current_user, _ = current_account_with_tenant() # check if the current user is the owner of the workspace - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") if not current_user.current_tenant: raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): @@ -250,8 +241,7 @@ class OwnerTransferCheckApi(Resource): parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() # check if the current user is the owner of the workspace - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): @@ -296,8 +286,7 @@ class OwnerTransfer(Resource): args = parser.parse_args() # check if the current user is the owner of the workspace - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 7012580362..acdd467b30 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,7 +1,6 @@ import io from flask import send_file -from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden @@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from libs.helper import StrLen, uuid_value -from libs.login import login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService from services.model_provider_service import ModelProviderService @@ -23,11 +21,8 @@ class ModelProviderListApi(Resource): @login_required @account_initialization_required def get(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") - if not current_user.current_tenant_id: - raise ValueError("No current tenant") - tenant_id = current_user.current_tenant_id + _, current_tenant_id = current_account_with_tenant() + tenant_id = current_tenant_id parser = reqparse.RequestParser() parser.add_argument( @@ -52,11 +47,8 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def get(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") - if not current_user.current_tenant_id: - raise ValueError("No current tenant") - tenant_id = current_user.current_tenant_id + _, current_tenant_id = current_account_with_tenant() + tenant_id = current_tenant_id # if credential_id is not provided, return current used credential parser = reqparse.RequestParser() parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") @@ -73,8 +65,7 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() @@ -85,11 +76,9 @@ class ModelProviderCredentialApi(Resource): model_provider_service = ModelProviderService() - if not current_user.current_tenant_id: - raise ValueError("No current tenant") try: model_provider_service.create_provider_credential( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=provider, credentials=args["credentials"], credential_name=args["name"], @@ -103,8 +92,7 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def put(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() @@ -116,11 +104,9 @@ class ModelProviderCredentialApi(Resource): model_provider_service = ModelProviderService() - if not current_user.current_tenant_id: - raise ValueError("No current tenant") try: model_provider_service.update_provider_credential( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=provider, credentials=args["credentials"], credential_id=args["credential_id"], @@ -135,19 +121,16 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def delete(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") args = parser.parse_args() - if not current_user.current_tenant_id: - raise ValueError("No current tenant") model_provider_service = ModelProviderService() model_provider_service.remove_provider_credential( - tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] + tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"] ) return {"result": "success"}, 204 @@ -159,19 +142,16 @@ class ModelProviderCredentialSwitchApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - if not current_user.current_tenant_id: - raise ValueError("No current tenant") service = ModelProviderService() service.switch_active_provider_credential( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"], ) @@ -184,15 +164,12 @@ class ModelProviderValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + _, current_tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() - if not current_user.current_tenant_id: - raise ValueError("No current tenant") - tenant_id = current_user.current_tenant_id + tenant_id = current_tenant_id model_provider_service = ModelProviderService() @@ -240,14 +217,11 @@ class PreferredProviderTypeUpdateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() - if not current_user.current_tenant_id: - raise ValueError("No current tenant") - tenant_id = current_user.current_tenant_id + tenant_id = current_tenant_id parser = reqparse.RequestParser() parser.add_argument( @@ -276,14 +250,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): def get(self, provider: str): if provider != "anthropic": raise ValueError(f"provider name {provider} is invalid") - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() BillingService.is_tenant_owner_or_admin(current_user) - if not current_user.current_tenant_id: - raise ValueError("No current tenant") data = BillingService.get_model_provider_payment_link( provider_name=provider, - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, account_id=current_user.id, prefilled_email=current_user.email, ) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index d38bb16ea7..d5d1aed00e 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,6 +1,5 @@ import logging -from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden @@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from libs.helper import StrLen, uuid_value -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService @@ -23,6 +22,8 @@ class DefaultModelApi(Resource): @login_required @account_initialization_required def get(self): + _, tenant_id = current_account_with_tenant() + parser = reqparse.RequestParser() parser.add_argument( "model_type", @@ -34,8 +35,6 @@ class DefaultModelApi(Resource): ) args = parser.parse_args() - tenant_id = current_user.current_tenant_id - model_provider_service = ModelProviderService() default_model_entity = model_provider_service.get_default_model_of_model_type( tenant_id=tenant_id, model_type=args["model_type"] @@ -47,15 +46,14 @@ class DefaultModelApi(Resource): @login_required @account_initialization_required def post(self): + current_user, tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") args = parser.parse_args() - - tenant_id = current_user.current_tenant_id - model_provider_service = ModelProviderService() model_settings = args["model_settings"] for model_setting in model_settings: @@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource): @login_required @account_initialization_required def get(self, provider): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) @@ -104,11 +102,11 @@ class ModelProviderModelApi(Resource): @account_initialization_required def post(self, provider: str): # To save the model's load balance configs + current_user, tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() - tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument( @@ -129,7 +127,7 @@ class ModelProviderModelApi(Resource): raise ValueError("credential_id is required when configuring a custom-model") service = ModelProviderService() service.switch_active_custom_model_credential( - tenant_id=current_user.current_tenant_id, + tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"], @@ -164,11 +162,11 @@ class ModelProviderModelApi(Resource): @login_required @account_initialization_required def delete(self, provider: str): + current_user, tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() - tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument( @@ -195,7 +193,7 @@ class ModelProviderModelCredentialApi(Resource): @login_required @account_initialization_required def get(self, provider: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="args") @@ -257,6 +255,8 @@ class ModelProviderModelCredentialApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + current_user, tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() @@ -274,7 +274,6 @@ class ModelProviderModelCredentialApi(Resource): parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() - tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() try: @@ -301,6 +300,8 @@ class ModelProviderModelCredentialApi(Resource): @login_required @account_initialization_required def put(self, provider: str): + current_user, current_tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() @@ -323,7 +324,7 @@ class ModelProviderModelCredentialApi(Resource): try: model_provider_service.update_model_credential( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=provider, model_type=args["model_type"], model=args["model"], @@ -340,6 +341,8 @@ class ModelProviderModelCredentialApi(Resource): @login_required @account_initialization_required def delete(self, provider: str): + current_user, current_tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() @@ -357,7 +360,7 @@ class ModelProviderModelCredentialApi(Resource): model_provider_service = ModelProviderService() model_provider_service.remove_model_credential( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=provider, model_type=args["model_type"], model=args["model"], @@ -373,6 +376,8 @@ class ModelProviderModelCredentialSwitchApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + current_user, current_tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() @@ -390,7 +395,7 @@ class ModelProviderModelCredentialSwitchApi(Resource): service = ModelProviderService() service.add_model_credential_to_model_list( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=provider, model_type=args["model_type"], model=args["model"], @@ -407,7 +412,7 @@ class ModelProviderModelEnableApi(Resource): @login_required @account_initialization_required def patch(self, provider: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") @@ -437,7 +442,7 @@ class ModelProviderModelDisableApi(Resource): @login_required @account_initialization_required def patch(self, provider: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") @@ -465,7 +470,7 @@ class ModelProviderModelValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") @@ -514,8 +519,7 @@ class ModelProviderModelParameterRuleApi(Resource): parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() model_provider_service = ModelProviderService() parameter_rules = model_provider_service.get_model_parameter_rules( @@ -531,8 +535,7 @@ class ModelProviderAvailableModelApi(Resource): @login_required @account_initialization_required def get(self, model_type): - tenant_id = current_user.current_tenant_id - + _, tenant_id = current_account_with_tenant() model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index 7c70fb8aa0..ed5426376f 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -1,7 +1,6 @@ import io from flask import request, send_file -from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden @@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginDaemonClientSideError -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_parameter_service import PluginParameterService @@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource): @account_initialization_required @plugin_permission_required(debug_required=True) def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() try: return { @@ -44,7 +43,7 @@ class PluginListApi(Resource): @login_required @account_initialization_required def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=False, location="args", default=1) parser.add_argument("page_size", type=int, required=False, location="args", default=256) @@ -81,7 +80,7 @@ class PluginListInstallationsFromIdsApi(Resource): @login_required @account_initialization_required def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("plugin_ids", type=list, required=True, location="json") @@ -120,7 +119,7 @@ class PluginUploadFromPkgApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() file = request.files["pkg"] @@ -144,7 +143,7 @@ class PluginUploadFromGithubApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("repo", type=str, required=True, location="json") @@ -167,7 +166,7 @@ class PluginUploadFromBundleApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() file = request.files["bundle"] @@ -191,7 +190,7 @@ class PluginInstallFromPkgApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") @@ -217,7 +216,7 @@ class PluginInstallFromGithubApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("repo", type=str, required=True, location="json") @@ -247,7 +246,7 @@ class PluginInstallFromMarketplaceApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") @@ -273,7 +272,7 @@ class PluginFetchMarketplacePkgApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") @@ -299,7 +298,7 @@ class PluginFetchManifestApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") @@ -324,7 +323,7 @@ class PluginFetchInstallTasksApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=True, location="args") @@ -346,7 +345,7 @@ class PluginFetchInstallTaskApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def get(self, task_id: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() try: return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)}) @@ -361,7 +360,7 @@ class PluginDeleteInstallTaskApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self, task_id: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() try: return {"success": PluginService.delete_install_task(tenant_id, task_id)} @@ -376,7 +375,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() try: return {"success": PluginService.delete_all_install_task_items(tenant_id)} @@ -391,7 +390,7 @@ class PluginDeleteInstallTaskItemApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self, task_id: str, identifier: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() try: return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)} @@ -406,7 +405,7 @@ class PluginUpgradeFromMarketplaceApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") @@ -430,7 +429,7 @@ class PluginUpgradeFromGithubApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") @@ -466,7 +465,7 @@ class PluginUninstallApi(Resource): req.add_argument("plugin_installation_id", type=str, required=True, location="json") args = req.parse_args() - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() try: return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])} @@ -480,6 +479,7 @@ class PluginChangePermissionApi(Resource): @login_required @account_initialization_required def post(self): + current_user, current_tenant_id = current_account_with_tenant() user = current_user if not user.is_admin_or_owner: raise Forbidden() @@ -492,7 +492,7 @@ class PluginChangePermissionApi(Resource): install_permission = TenantPluginPermission.InstallPermission(args["install_permission"]) debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"]) - tenant_id = user.current_tenant_id + tenant_id = current_tenant_id return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} @@ -503,7 +503,7 @@ class PluginFetchPermissionApi(Resource): @login_required @account_initialization_required def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() permission = PluginPermissionService.get_permission(tenant_id) if not permission: @@ -529,10 +529,10 @@ class PluginFetchDynamicSelectOptionsApi(Resource): @account_initialization_required def get(self): # check if the user is admin or owner + current_user, tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() - tenant_id = current_user.current_tenant_id user_id = current_user.id parser = reqparse.RequestParser() @@ -565,7 +565,7 @@ class PluginChangePreferencesApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() @@ -574,8 +574,6 @@ class PluginChangePreferencesApi(Resource): req.add_argument("auto_upgrade", type=dict, required=True, location="json") args = req.parse_args() - tenant_id = user.current_tenant_id - permission = args["permission"] install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone")) @@ -621,7 +619,7 @@ class PluginFetchPreferencesApi(Resource): @login_required @account_initialization_required def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() permission = PluginPermissionService.get_permission(tenant_id) permission_dict = { @@ -661,7 +659,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource): @account_initialization_required def post(self): # exclude one single plugin - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() req = reqparse.RequestParser() req.add_argument("plugin_id", type=str, required=True, location="json") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 9285577f72..17a935ade7 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -2,7 +2,6 @@ import io from urllib.parse import urlparse from flask import make_response, redirect, request, send_file -from flask_login import current_user from flask_restx import ( Resource, reqparse, @@ -24,7 +23,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import CredentialType from libs.helper import StrLen, alphanumeric, uuid_value -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID from services.plugin.oauth_service import OAuthProxyService from services.tools.api_tools_manage_service import ApiToolManageService @@ -53,10 +52,9 @@ class ToolProviderListApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id req = reqparse.RequestParser() req.add_argument( @@ -78,9 +76,7 @@ class ToolBuiltinProviderListToolsApi(Resource): @login_required @account_initialization_required def get(self, provider): - user = current_user - - tenant_id = user.current_tenant_id + _, tenant_id = current_account_with_tenant() return jsonable_encoder( BuiltinToolManageService.list_builtin_tool_provider_tools( @@ -96,9 +92,7 @@ class ToolBuiltinProviderInfoApi(Resource): @login_required @account_initialization_required def get(self, provider): - user = current_user - - tenant_id = user.current_tenant_id + _, tenant_id = current_account_with_tenant() return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) @@ -109,11 +103,10 @@ class ToolBuiltinProviderDeleteApi(Resource): @login_required @account_initialization_required def post(self, provider): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() - tenant_id = user.current_tenant_id req = reqparse.RequestParser() req.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = req.parse_args() @@ -131,10 +124,9 @@ class ToolBuiltinProviderAddApi(Resource): @login_required @account_initialization_required def post(self, provider): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") @@ -161,13 +153,12 @@ class ToolBuiltinProviderUpdateApi(Resource): @login_required @account_initialization_required def post(self, provider): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") @@ -193,7 +184,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): @login_required @account_initialization_required def get(self, provider): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() return jsonable_encoder( BuiltinToolManageService.get_builtin_tool_provider_credentials( @@ -218,13 +209,12 @@ class ToolApiProviderAddApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") @@ -258,10 +248,9 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() @@ -282,10 +271,9 @@ class ToolApiProviderListToolsApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() @@ -308,13 +296,12 @@ class ToolApiProviderUpdateApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") @@ -350,13 +337,12 @@ class ToolApiProviderDeleteApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() @@ -377,10 +363,9 @@ class ToolApiProviderGetApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() @@ -401,8 +386,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @login_required @account_initialization_required def get(self, provider, credential_type): - user = current_user - tenant_id = user.current_tenant_id + _, tenant_id = current_account_with_tenant() return jsonable_encoder( BuiltinToolManageService.list_builtin_provider_credentials_schema( @@ -444,9 +428,9 @@ class ToolApiProviderPreviousTestApi(Resource): parser.add_argument("schema", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - + _, current_tenant_id = current_account_with_tenant() return ApiToolManageService.test_api_tool_preview( - current_user.current_tenant_id, + current_tenant_id, args["provider_name"] or "", args["tool_name"], args["credentials"], @@ -462,13 +446,12 @@ class ToolWorkflowProviderCreateApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id reqparser = reqparse.RequestParser() reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") @@ -502,13 +485,12 @@ class ToolWorkflowProviderUpdateApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id reqparser = reqparse.RequestParser() reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") @@ -545,13 +527,12 @@ class ToolWorkflowProviderDeleteApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id reqparser = reqparse.RequestParser() reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") @@ -571,10 +552,9 @@ class ToolWorkflowProviderGetApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") @@ -606,10 +586,9 @@ class ToolWorkflowProviderListToolApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") @@ -631,10 +610,9 @@ class ToolBuiltinListApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id return jsonable_encoder( [ @@ -653,8 +631,7 @@ class ToolApiListApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user - tenant_id = user.current_tenant_id + _, tenant_id = current_account_with_tenant() return jsonable_encoder( [ @@ -672,10 +649,9 @@ class ToolWorkflowListApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id return jsonable_encoder( [ @@ -709,19 +685,18 @@ class ToolPluginOAuthApi(Resource): provider_name = tool_provider.provider_name # todo check permission - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() - tenant_id = user.current_tenant_id oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider) if oauth_client_params is None: raise Forbidden("no oauth available client config found for this tool provider") oauth_handler = OAuthHandler() context_id = OAuthProxyService.create_proxy_context( - user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name + user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name ) redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" authorization_url_response = oauth_handler.get_authorization_url( @@ -800,11 +775,12 @@ class ToolBuiltinProviderSetDefaultApi(Resource): @login_required @account_initialization_required def post(self, provider): + current_user, current_tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return BuiltinToolManageService.set_default_provider( - tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] ) @@ -819,13 +795,13 @@ class ToolOAuthCustomClient(Resource): parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") args = parser.parse_args() - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() return BuiltinToolManageService.save_custom_oauth_client_params( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, provider=provider, client_params=args.get("client_params", {}), enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), @@ -835,20 +811,18 @@ class ToolOAuthCustomClient(Resource): @login_required @account_initialization_required def get(self, provider): + _, current_tenant_id = current_account_with_tenant() return jsonable_encoder( - BuiltinToolManageService.get_custom_oauth_client_params( - tenant_id=current_user.current_tenant_id, provider=provider - ) + BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider) ) @setup_required @login_required @account_initialization_required def delete(self, provider): + _, current_tenant_id = current_account_with_tenant() return jsonable_encoder( - BuiltinToolManageService.delete_custom_oauth_client_params( - tenant_id=current_user.current_tenant_id, provider=provider - ) + BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider) ) @@ -858,9 +832,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): @login_required @account_initialization_required def get(self, provider): + _, current_tenant_id = current_account_with_tenant() return jsonable_encoder( BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema( - tenant_id=current_user.current_tenant_id, provider_name=provider + tenant_id=current_tenant_id, provider_name=provider ) ) @@ -871,7 +846,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource): @login_required @account_initialization_required def get(self, provider): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() return jsonable_encoder( BuiltinToolManageService.get_builtin_tool_provider_credential_info( @@ -900,12 +875,12 @@ class ToolProviderMCPApi(Resource): ) parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) args = parser.parse_args() - user = current_user + user, tenant_id = current_account_with_tenant() if not is_valid_url(args["server_url"]): raise ValueError("Server URL is not valid.") return jsonable_encoder( MCPToolManageService.create_mcp_provider( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, server_url=args["server_url"], name=args["name"], icon=args["icon"], @@ -940,8 +915,9 @@ class ToolProviderMCPApi(Resource): pass else: raise ValueError("Server URL is not valid.") + _, current_tenant_id = current_account_with_tenant() MCPToolManageService.update_mcp_provider( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider_id=args["provider_id"], server_url=args["server_url"], name=args["name"], @@ -962,7 +938,8 @@ class ToolProviderMCPApi(Resource): parser = reqparse.RequestParser() parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"]) + _, current_tenant_id = current_account_with_tenant() + MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"]) return {"result": "success"} @@ -977,7 +954,7 @@ class ToolMCPAuthApi(Resource): parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json") args = parser.parse_args() provider_id = args["provider_id"] - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) if not provider: raise ValueError("provider not found") @@ -1018,8 +995,8 @@ class ToolMCPDetailApi(Resource): @login_required @account_initialization_required def get(self, provider_id): - user = current_user - provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id) + _, tenant_id = current_account_with_tenant() + provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) @@ -1029,8 +1006,7 @@ class ToolMCPListAllApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user - tenant_id = user.current_tenant_id + _, tenant_id = current_account_with_tenant() tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id) @@ -1043,7 +1019,7 @@ class ToolMCPUpdateApi(Resource): @login_required @account_initialization_required def get(self, provider_id): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() tools = MCPToolManageService.list_mcp_tool_from_remote_server( tenant_id=tenant_id, provider_id=provider_id, diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 9e903d9286..4158f0524f 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -12,8 +12,8 @@ from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError from extensions.ext_database import db from extensions.ext_redis import redis_client -from libs.login import current_user -from models.account import Account, AccountStatus +from libs.login import current_account_with_tenant +from models.account import AccountStatus from models.dataset import RateLimitLog from models.model import DifySetup from services.feature_service import FeatureService, LicenseStatus @@ -25,16 +25,13 @@ P = ParamSpec("P") R = TypeVar("R") -def _current_account() -> Account: - assert isinstance(current_user, Account) - return current_user - - def account_initialization_required(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): # check account initialization - account = _current_account() + current_user, _ = current_account_with_tenant() + + account = current_user if account.status == AccountStatus.UNINITIALIZED: raise AccountNotInitializedError() @@ -80,9 +77,8 @@ def only_edition_self_hosted(view: Callable[P, R]): def cloud_edition_billing_enabled(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - account = _current_account() - assert account.current_tenant_id is not None - features = FeatureService.get_features(account.current_tenant_id) + _, current_tenant_id = current_account_with_tenant() + features = FeatureService.get_features(current_tenant_id) if not features.billing.enabled: abort(403, "Billing feature is not enabled.") return view(*args, **kwargs) @@ -94,10 +90,8 @@ def cloud_edition_billing_resource_check(resource: str): def interceptor(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - account = _current_account() - assert account.current_tenant_id is not None - tenant_id = account.current_tenant_id - features = FeatureService.get_features(tenant_id) + _, current_tenant_id = current_account_with_tenant() + features = FeatureService.get_features(current_tenant_id) if features.billing.enabled: members = features.members apps = features.apps @@ -138,9 +132,8 @@ def cloud_edition_billing_knowledge_limit_check(resource: str): def interceptor(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - account = _current_account() - assert account.current_tenant_id is not None - features = FeatureService.get_features(account.current_tenant_id) + _, current_tenant_id = current_account_with_tenant() + features = FeatureService.get_features(current_tenant_id) if features.billing.enabled: if resource == "add_segment": if features.billing.subscription.plan == "sandbox": @@ -163,13 +156,11 @@ def cloud_edition_billing_rate_limit_check(resource: str): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): if resource == "knowledge": - account = _current_account() - assert account.current_tenant_id is not None - tenant_id = account.current_tenant_id - knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id) + _, current_tenant_id = current_account_with_tenant() + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_tenant_id) if knowledge_rate_limit.enabled: current_time = int(time.time() * 1000) - key = f"rate_limit_{tenant_id}" + key = f"rate_limit_{current_tenant_id}" redis_client.zadd(key, {current_time: current_time}) @@ -180,7 +171,7 @@ def cloud_edition_billing_rate_limit_check(resource: str): if request_count > knowledge_rate_limit.limit: # add ratelimit record rate_limit_log = RateLimitLog( - tenant_id=tenant_id, + tenant_id=current_tenant_id, subscription_plan=knowledge_rate_limit.subscription_plan, operation="knowledge", ) @@ -200,17 +191,15 @@ def cloud_utm_record(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): with contextlib.suppress(Exception): - account = _current_account() - assert account.current_tenant_id is not None - tenant_id = account.current_tenant_id - features = FeatureService.get_features(tenant_id) + _, current_tenant_id = current_account_with_tenant() + features = FeatureService.get_features(current_tenant_id) if features.billing.enabled: utm_info = request.cookies.get("utm_info") if utm_info: utm_info_dict: dict = json.loads(utm_info) - OperationService.record_utm(tenant_id, utm_info_dict) + OperationService.record_utm(current_tenant_id, utm_info_dict) return view(*args, **kwargs) @@ -289,9 +278,8 @@ def enable_change_email(view: Callable[P, R]): def is_allow_transfer_owner(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - account = _current_account() - assert account.current_tenant_id is not None - features = FeatureService.get_features(account.current_tenant_id) + _, current_tenant_id = current_account_with_tenant() + features = FeatureService.get_features(current_tenant_id) if features.is_allow_transfer_workspace: return view(*args, **kwargs) @@ -301,12 +289,11 @@ def is_allow_transfer_owner(view: Callable[P, R]): return decorated -def knowledge_pipeline_publish_enabled(view): +def knowledge_pipeline_publish_enabled(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): - account = _current_account() - assert account.current_tenant_id is not None - features = FeatureService.get_features(account.current_tenant_id) + def decorated(*args: P.args, **kwargs: P.kwargs): + _, current_tenant_id = current_account_with_tenant() + features = FeatureService.get_features(current_tenant_id) if features.knowledge_pipeline.publish_enabled: return view(*args, **kwargs) abort(403) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index d674c7467d..acbbf4531b 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,5 +1,4 @@ from flask import request -from flask_login import current_user from flask_restx import marshal, reqparse from werkzeug.exceptions import NotFound @@ -16,6 +15,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields +from libs.login import current_account_with_tenant from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs @@ -66,6 +66,7 @@ class SegmentApi(DatasetApiResource): @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id: str, dataset_id: str, document_id: str): + _, current_tenant_id = current_account_with_tenant() """Create single segment.""" # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -84,7 +85,7 @@ class SegmentApi(DatasetApiResource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -117,6 +118,7 @@ class SegmentApi(DatasetApiResource): } ) def get(self, tenant_id: str, dataset_id: str, document_id: str): + _, current_tenant_id = current_account_with_tenant() """Get segments.""" # check dataset page = request.args.get("page", default=1, type=int) @@ -133,7 +135,7 @@ class SegmentApi(DatasetApiResource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -149,7 +151,7 @@ class SegmentApi(DatasetApiResource): segments, total = SegmentService.get_segments( document_id=document_id, - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, status_list=args["status"], keyword=args["keyword"], page=page, @@ -184,6 +186,7 @@ class DatasetSegmentApi(DatasetApiResource): ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + _, current_tenant_id = current_account_with_tenant() # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: @@ -195,7 +198,7 @@ class DatasetSegmentApi(DatasetApiResource): if not document: raise NotFound("Document not found.") # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) @@ -217,6 +220,7 @@ class DatasetSegmentApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + _, current_tenant_id = current_account_with_tenant() # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: @@ -232,7 +236,7 @@ class DatasetSegmentApi(DatasetApiResource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -244,7 +248,7 @@ class DatasetSegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -266,6 +270,7 @@ class DatasetSegmentApi(DatasetApiResource): } ) def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + _, current_tenant_id = current_account_with_tenant() # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: @@ -277,7 +282,7 @@ class DatasetSegmentApi(DatasetApiResource): if not document: raise NotFound("Document not found.") # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -307,6 +312,7 @@ class ChildChunkApi(DatasetApiResource): @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + _, current_tenant_id = current_account_with_tenant() """Create child chunk.""" # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -319,7 +325,7 @@ class ChildChunkApi(DatasetApiResource): raise NotFound("Document not found.") # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -328,7 +334,7 @@ class ChildChunkApi(DatasetApiResource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -364,6 +370,7 @@ class ChildChunkApi(DatasetApiResource): } ) def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + _, current_tenant_id = current_account_with_tenant() """Get child chunks.""" # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -376,7 +383,7 @@ class ChildChunkApi(DatasetApiResource): raise NotFound("Document not found.") # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -423,6 +430,7 @@ class DatasetChildChunkApi(DatasetApiResource): @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): + _, current_tenant_id = current_account_with_tenant() """Delete child chunk.""" # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -435,7 +443,7 @@ class DatasetChildChunkApi(DatasetApiResource): raise NotFound("Document not found.") # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -444,9 +452,7 @@ class DatasetChildChunkApi(DatasetApiResource): raise NotFound("Document not found.") # check child chunk - child_chunk = SegmentService.get_child_chunk_by_id( - child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id - ) + child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id) if not child_chunk: raise NotFound("Child chunk not found.") @@ -483,6 +489,7 @@ class DatasetChildChunkApi(DatasetApiResource): @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): + _, current_tenant_id = current_account_with_tenant() """Update child chunk.""" # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -495,7 +502,7 @@ class DatasetChildChunkApi(DatasetApiResource): raise NotFound("Document not found.") # get segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -504,9 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource): raise NotFound("Segment not found.") # get child chunk - child_chunk = SegmentService.get_child_chunk_by_id( - child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id - ) + child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id) if not child_chunk: raise NotFound("Child chunk not found.") diff --git a/api/libs/login.py b/api/libs/login.py index 0535f52ea1..24b8c4011a 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -13,6 +13,15 @@ from models.model import EndUser #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) + + +def current_account_with_tenant(): + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + assert current_user.current_tenant_id is not None, "The tenant information should be loaded." + return current_user, current_user.current_tenant_id + + from typing import ParamSpec, TypeVar P = ParamSpec("P") diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 9feca7337f..c0d26cdd27 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -8,8 +8,7 @@ from werkzeug.exceptions import NotFound from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now -from libs.login import current_user -from models.account import Account +from libs.login import current_account_with_tenant from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation from services.feature_service import FeatureService from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task @@ -24,10 +23,10 @@ class AppAnnotationService: @classmethod def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - assert isinstance(current_user, Account) + current_user, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -63,12 +62,12 @@ class AppAnnotationService: db.session.commit() # if annotation reply is enabled , add annotation to index annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() - assert current_user.current_tenant_id is not None + assert current_tenant_id is not None if annotation_setting: add_annotation_to_index_task.delay( annotation.id, args["question"], - current_user.current_tenant_id, + current_tenant_id, app_id, annotation_setting.collection_binding_id, ) @@ -86,13 +85,12 @@ class AppAnnotationService: enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" # send batch add segments task redis_client.setnx(enable_app_annotation_job_key, "waiting") - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, current_tenant_id = current_account_with_tenant() enable_annotation_reply_task.delay( str(job_id), app_id, current_user.id, - current_user.current_tenant_id, + current_tenant_id, args["score_threshold"], args["embedding_provider_name"], args["embedding_model_name"], @@ -101,8 +99,7 @@ class AppAnnotationService: @classmethod def disable_app_annotation(cls, app_id: str): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" cache_result = redis_client.get(disable_app_annotation_key) if cache_result is not None: @@ -113,17 +110,16 @@ class AppAnnotationService: disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}" # send batch add segments task redis_client.setnx(disable_app_annotation_job_key, "waiting") - disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id) + disable_annotation_reply_task.delay(str(job_id), app_id, current_tenant_id) return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -153,11 +149,10 @@ class AppAnnotationService: @classmethod def export_annotation_list_by_app_id(cls, app_id: str): # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -174,11 +169,10 @@ class AppAnnotationService: @classmethod def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -196,7 +190,7 @@ class AppAnnotationService: add_annotation_to_index_task.delay( annotation.id, args["question"], - current_user.current_tenant_id, + current_tenant_id, app_id, annotation_setting.collection_binding_id, ) @@ -205,11 +199,10 @@ class AppAnnotationService: @classmethod def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -234,7 +227,7 @@ class AppAnnotationService: update_annotation_to_index_task.delay( annotation.id, annotation.question, - current_user.current_tenant_id, + current_tenant_id, app_id, app_annotation_setting.collection_binding_id, ) @@ -244,11 +237,10 @@ class AppAnnotationService: @classmethod def delete_app_annotation(cls, app_id: str, annotation_id: str): # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -277,17 +269,16 @@ class AppAnnotationService: if app_annotation_setting: delete_annotation_index_task.delay( - annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id + annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id ) @classmethod def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]): # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -317,7 +308,7 @@ class AppAnnotationService: for annotation, annotation_setting in annotations_to_delete: if annotation_setting: delete_annotation_index_task.delay( - annotation.id, app_id, current_user.current_tenant_id, annotation_setting.collection_binding_id + annotation.id, app_id, current_tenant_id, annotation_setting.collection_binding_id ) # Step 4: Bulk delete annotations in a single query @@ -333,11 +324,10 @@ class AppAnnotationService: @classmethod def batch_import_app_annotations(cls, app_id, file: FileStorage): # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -354,7 +344,7 @@ class AppAnnotationService: if len(result) == 0: raise ValueError("The CSV file is empty.") # check annotation limit - features = FeatureService.get_features(current_user.current_tenant_id) + features = FeatureService.get_features(current_tenant_id) if features.billing.enabled: annotation_quota_limit = features.annotation_quota_limit if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size: @@ -364,21 +354,18 @@ class AppAnnotationService: indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" # send batch add segments task redis_client.setnx(indexing_cache_key, "waiting") - batch_import_annotations_task.delay( - str(job_id), result, app_id, current_user.current_tenant_id, current_user.id - ) + batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id) except Exception as e: return {"error_msg": str(e)} return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() # get app info app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -445,12 +432,11 @@ class AppAnnotationService: @classmethod def get_app_annotation_setting_by_app_id(cls, app_id: str): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() # get app info app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -481,12 +467,11 @@ class AppAnnotationService: @classmethod def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, current_tenant_id = current_account_with_tenant() # get app info app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -531,11 +516,10 @@ class AppAnnotationService: @classmethod def clear_all_annotations(cls, app_id: str): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -558,7 +542,7 @@ class AppAnnotationService: # if annotation reply is enabled, delete annotation index if app_annotation_setting: delete_annotation_index_task.delay( - annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id + annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id ) db.session.delete(annotation) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 36b7084973..fcb6ab1d40 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -3,7 +3,6 @@ import time from collections.abc import Mapping from typing import Any -from flask_login import current_user from sqlalchemy.orm import Session from configs import dify_config @@ -18,6 +17,7 @@ from core.tools.entities.tool_entities import CredentialType from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.login import current_account_with_tenant from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService @@ -93,6 +93,8 @@ class DatasourceProviderService: """ get credential by id """ + current_user, _ = current_account_with_tenant() + with Session(db.engine) as session: if credential_id: datasource_provider = ( @@ -157,6 +159,8 @@ class DatasourceProviderService: """ get all datasource credentials by provider """ + current_user, _ = current_account_with_tenant() + with Session(db.engine) as session: datasource_providers = ( session.query(DatasourceProvider) @@ -604,6 +608,8 @@ class DatasourceProviderService: """ provider_name = provider_id.provider_name plugin_id = provider_id.plugin_id + current_user, _ = current_account_with_tenant() + with Session(db.engine) as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}" with redis_client.lock(lock, timeout=20): @@ -901,6 +907,8 @@ class DatasourceProviderService: """ update datasource credentials. """ + current_user, _ = current_account_with_tenant() + with Session(db.engine) as session: datasource_provider = ( session.query(DatasourceProvider) diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 3cb7424df8..4768b981cc 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -25,9 +25,7 @@ class TestAnnotationService: patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, - patch( - "services.annotation_service.current_user", create_autospec(Account, instance=True) - ) as mock_current_user, + patch("services.annotation_service.current_account_with_tenant") as mock_current_account_with_tenant, ): # Setup default mock returns mock_account_feature_service.get_features.return_value.billing.enabled = False @@ -38,6 +36,9 @@ class TestAnnotationService: mock_disable_task.delay.return_value = None mock_batch_import_task.delay.return_value = None + # Create mock user that will be returned by current_account_with_tenant + mock_user = create_autospec(Account, instance=True) + yield { "account_feature_service": mock_account_feature_service, "feature_service": mock_feature_service, @@ -47,7 +48,8 @@ class TestAnnotationService: "enable_task": mock_enable_task, "disable_task": mock_disable_task, "batch_import_task": mock_batch_import_task, - "current_user": mock_current_user, + "current_account_with_tenant": mock_current_account_with_tenant, + "current_user": mock_user, } def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): @@ -107,6 +109,11 @@ class TestAnnotationService: """ mock_external_service_dependencies["current_user"].id = account_id mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id + # Configure current_account_with_tenant to return (user, tenant_id) + mock_external_service_dependencies["current_account_with_tenant"].return_value = ( + mock_external_service_dependencies["current_user"], + tenant_id, + ) def _create_test_conversation(self, app, account, fake): """ diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index 5d132cb787..6777077de8 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -60,7 +60,7 @@ class TestAccountInitialization: return "success" # Act - with patch("controllers.console.wraps._current_account", return_value=mock_user): + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")): result = protected_view() # Assert @@ -77,7 +77,7 @@ class TestAccountInitialization: return "success" # Act & Assert - with patch("controllers.console.wraps._current_account", return_value=mock_user): + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")): with pytest.raises(AccountNotInitializedError): protected_view() @@ -163,7 +163,9 @@ class TestBillingResourceLimits: return "member_added" # Act - with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): + with patch( + "controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123") + ): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): result = add_member() @@ -185,7 +187,10 @@ class TestBillingResourceLimits: # Act & Assert with app.test_request_context(): - with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): + with patch( + "controllers.console.wraps.current_account_with_tenant", + return_value=(MockUser("test_user"), "tenant123"), + ): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): with pytest.raises(Exception) as exc_info: add_member() @@ -207,7 +212,10 @@ class TestBillingResourceLimits: # Test 1: Should reject when source is datasets with app.test_request_context("/?source=datasets"): - with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): + with patch( + "controllers.console.wraps.current_account_with_tenant", + return_value=(MockUser("test_user"), "tenant123"), + ): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): with pytest.raises(Exception) as exc_info: upload_document() @@ -215,7 +223,10 @@ class TestBillingResourceLimits: # Test 2: Should allow when source is not datasets with app.test_request_context("/?source=other"): - with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): + with patch( + "controllers.console.wraps.current_account_with_tenant", + return_value=(MockUser("test_user"), "tenant123"), + ): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): result = upload_document() assert result == "document_uploaded" @@ -239,7 +250,9 @@ class TestRateLimiting: return "knowledge_success" # Act - with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): + with patch( + "controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123") + ): with patch( "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit ): @@ -271,7 +284,10 @@ class TestRateLimiting: # Act & Assert with app.test_request_context(): - with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): + with patch( + "controllers.console.wraps.current_account_with_tenant", + return_value=(MockUser("test_user"), "tenant123"), + ): with patch( "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit ):